From 2321ab295d1e3c6e09d6647c246fe6f59655fc98 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 22 Jun 2026 20:26:30 +0000 Subject: [PATCH] feat: add graalvm community distribution support --- README.md | 2 + .../distributors/graalvm-installer.test.ts | 115 ++++- dist/cleanup/index.js | 415 +++++++++++++----- dist/setup/index.js | 415 +++++++++++++----- docs/advanced-usage.md | 16 + src/distributions/distribution-factory.ts | 8 +- src/distributions/graalvm/installer.ts | 254 +++++++++-- 7 files changed, 971 insertions(+), 254 deletions(-) diff --git a/README.md b/README.md index 50147b5a..db413d96 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ Currently, the following distributions are supported: | `dragonwell` | [Alibaba Dragonwell JDK](https://dragonwell-jdk.io/) | [`dragonwell` license](https://www.aliyun.com/product/dragonwell/) | `sapmachine` | [SAP SapMachine JDK/JRE](https://sapmachine.io/) | [`sapmachine` license](https://github.com/SAP/SapMachine/blob/sapmachine/LICENSE) | `graalvm` | [Oracle GraalVM](https://www.graalvm.org/) | [`graalvm` license](https://www.oracle.com/downloads/licenses/graal-free-license.html) +| `graalvm-community` | [GraalVM Community](https://github.com/graalvm/graalvm-ce-builds/releases) | [`graalvm-community` license](https://github.com/oracle/graal/blob/master/LICENSE) | `jetbrains` | [JetBrains Runtime](https://github.com/JetBrains/JetBrainsRuntime/) | [`jetbrains` license](https://github.com/JetBrains/JetBrainsRuntime/blob/main/LICENSE) > [!NOTE] @@ -126,6 +127,7 @@ Currently, the following distributions are supported: > - AdoptOpenJDK got moved to Eclipse Temurin and won't be updated anymore. It is highly recommended to migrate workflows from `adopt` and `adopt-openj9`, to `temurin` and `semeru` respectively, to keep receiving software and security updates. See more details in the [Good-bye AdoptOpenJDK post](https://blog.adoptopenjdk.net/2021/08/goodbye-adoptopenjdk-hello-adoptium/). > - For Azul Zulu OpenJDK architectures x64 and arm64 are mapped to x86 / arm with proper hw_bitness. > - To comply with the GraalVM Free Terms and Conditions (GFTC) license, it is recommended to use GraalVM JDK 17 version 17.0.12, as this is the only version of GraalVM JDK 17 available under the GFTC license. Additionally, it is encouraged to consider upgrading to GraalVM JDK 21, which offers the latest features and improvements. +> - GraalVM Community is available as `distribution: 'graalvm-community'` for stable JDK 17 and later releases published on GitHub. **NOTE:** Oracle JDK 17 licensing varies by patch level. As shown on the [JDK 17 Archive](https://www.oracle.com/java/technologies/javase/jdk17-archive-downloads.html) (versions up to 17.0.12 are under the [NFTC](https://www.oracle.com/downloads/licenses/no-fee-license.html) license) and the [JDK 17.0.13+ Archive](https://www.oracle.com/java/technologies/javase/jdk17-0-13-later-archive-downloads.html) (versions 17.0.13 and later are under the [OTN](https://www.oracle.com/downloads/licenses/javase-license1.html) license). To stay on the free NFTC license, use `distribution: 'oracle'` with `java-version: '17.0.12'` (or earlier) instead of the floating `'17'`. Alternatively, upgrade to Oracle JDK 21+, which remains under the NFTC license. diff --git a/__tests__/distributors/graalvm-installer.test.ts b/__tests__/distributors/graalvm-installer.test.ts index 23f90b88..924141dc 100644 --- a/__tests__/distributors/graalvm-installer.test.ts +++ b/__tests__/distributors/graalvm-installer.test.ts @@ -3,7 +3,11 @@ import * as tc from '@actions/tool-cache'; import * as http from '@actions/http-client'; import fs from 'fs'; import path from 'path'; -import {GraalVMDistribution} from '../../src/distributions/graalvm/installer'; +import { + GraalVMCommunityDistribution, + GraalVMDistribution +} from '../../src/distributions/graalvm/installer'; +import {getJavaDistribution} from '../../src/distributions/distribution-factory'; import {JavaInstallerOptions} from '../../src/distributions/base-models'; import * as util from '../../src/util'; @@ -41,6 +45,7 @@ beforeAll(() => { describe('GraalVMDistribution', () => { let distribution: GraalVMDistribution; + let communityDistribution: GraalVMCommunityDistribution; let mockHttpClient: jest.Mocked; let spyCoreError: jest.SpyInstance; @@ -55,9 +60,11 @@ describe('GraalVMDistribution', () => { jest.clearAllMocks(); distribution = new GraalVMDistribution(defaultOptions); + communityDistribution = new GraalVMCommunityDistribution(defaultOptions); mockHttpClient = new http.HttpClient() as jest.Mocked; (distribution as any).http = mockHttpClient; + (communityDistribution as any).http = mockHttpClient; (util.getDownloadArchiveExtension as jest.Mock).mockReturnValue('tar.gz'); @@ -242,6 +249,21 @@ describe('GraalVMDistribution', () => { path: '/cached/java/path' }); }); + + it('should use a dedicated toolcache folder for GraalVM Community', async () => { + const result = await (communityDistribution as any).downloadTool(javaRelease); + + expect(tc.cacheDir).toHaveBeenCalledWith( + path.join('/tmp/extracted', 'graalvm-jdk-17.0.5'), + 'Java_GraalVM_Community_jdk', + '17.0.5', + 'x64' + ); + expect(result).toEqual({ + version: '17.0.5', + path: '/cached/java/path' + }); + }); }); describe('findPackageForDownload', () => { @@ -948,5 +970,96 @@ describe('GraalVMDistribution', () => { configurable: true }); }); + + describe('GraalVMCommunityDistribution', () => { + beforeEach(() => { + jest.spyOn(communityDistribution, 'getPlatform').mockReturnValue('linux'); + }); + + it('should resolve an exact GraalVM Community version from GitHub releases', async () => { + mockHttpClient.getJson.mockResolvedValue({ + result: [ + { + draft: false, + prerelease: false, + assets: [ + { + name: 'graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz', + browser_download_url: + 'https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.2/graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz' + } + ] + } + ], + statusCode: 200, + headers: {} + }); + + const result = await (communityDistribution as any).findPackageForDownload( + '21.0.2' + ); + + expect(result).toEqual({ + url: 'https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.2/graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz', + version: '21.0.2' + }); + }); + + it('should resolve the latest GraalVM Community release for a major version', async () => { + mockHttpClient.getJson.mockResolvedValue({ + result: [ + { + draft: false, + prerelease: false, + assets: [ + { + name: 'graalvm-community-jdk-21.0.1_linux-x64_bin.tar.gz', + browser_download_url: + 'https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.1/graalvm-community-jdk-21.0.1_linux-x64_bin.tar.gz' + } + ] + }, + { + draft: false, + prerelease: false, + assets: [ + { + name: 'graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz', + browser_download_url: + 'https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.2/graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz' + } + ] + } + ], + statusCode: 200, + headers: {} + }); + + const result = await (communityDistribution as any).findPackageForDownload( + '21' + ); + + expect(result).toEqual({ + url: 'https://github.com/graalvm/graalvm-ce-builds/releases/download/jdk-21.0.2/graalvm-community-jdk-21.0.2_linux-x64_bin.tar.gz', + version: '21.0.2' + }); + }); + + it('should reject GraalVM Community early access requests', async () => { + (communityDistribution as any).stable = false; + + await expect( + (communityDistribution as any).findPackageForDownload('23') + ).rejects.toThrow('GraalVM Community does not provide early access builds'); + }); + }); + + describe('distribution factory', () => { + it('should map graalvm-community to the community installer', () => { + const community = getJavaDistribution('graalvm-community', defaultOptions); + + expect(community).toBeInstanceOf(GraalVMCommunityDistribution); + }); + }); }); }); diff --git a/dist/cleanup/index.js b/dist/cleanup/index.js index 5a1db963..0c327816 100644 --- a/dist/cleanup/index.js +++ b/dist/cleanup/index.js @@ -28337,8 +28337,6 @@ function defaultFactory (origin, opts) { class Agent extends DispatcherBase { constructor ({ factory = defaultFactory, maxRedirections = 0, connect, ...options } = {}) { - super() - if (typeof factory !== 'function') { throw new InvalidArgumentError('factory must be a function.') } @@ -28351,6 +28349,8 @@ class Agent extends DispatcherBase { throw new InvalidArgumentError('maxRedirections must be a positive number') } + super(options) + if (connect && typeof connect !== 'function') { connect = { ...connect } } @@ -28724,6 +28724,9 @@ const EMPTY_BUF = Buffer.alloc(0) const FastBuffer = Buffer[Symbol.species] const addListener = util.addListener const removeAllListeners = util.removeAllListeners +const kIdleSocketValidation = Symbol('kIdleSocketValidation') +const kIdleSocketValidationTimeout = Symbol('kIdleSocketValidationTimeout') +const kSocketUsed = Symbol('kSocketUsed') let extractBody @@ -28946,29 +28949,71 @@ class Parser { const offset = llhttp.llhttp_get_error_pos(this.ptr) - currentBufferPtr - if (ret === constants.ERROR.PAUSED_UPGRADE) { - this.onUpgrade(data.slice(offset)) - } else if (ret === constants.ERROR.PAUSED) { - this.paused = true - socket.unshift(data.slice(offset)) - } else if (ret !== constants.ERROR.OK) { - const ptr = llhttp.llhttp_get_error_reason(this.ptr) - let message = '' - /* istanbul ignore else: difficult to make a test case for */ - if (ptr) { - const len = new Uint8Array(llhttp.memory.buffer, ptr).indexOf(0) - message = - 'Response does not match the HTTP/1.1 protocol (' + - Buffer.from(llhttp.memory.buffer, ptr, len).toString() + - ')' + if (ret !== constants.ERROR.OK) { + const body = data.subarray(offset) + + if (ret === constants.ERROR.PAUSED_UPGRADE) { + this.onUpgrade(body) + } else if (ret === constants.ERROR.PAUSED) { + this.paused = true + socket.unshift(body) + } else { + throw this.createError(ret, body) } - throw new HTTPParserError(message, constants.ERROR[ret], data.slice(offset)) } } catch (err) { util.destroy(socket, err) } } + finish () { + assert(currentParser === null) + assert(this.ptr != null) + assert(!this.paused) + + const { llhttp } = this + + let ret + + try { + currentParser = this + ret = llhttp.llhttp_finish(this.ptr) + } finally { + currentParser = null + } + + if (ret === constants.ERROR.OK) { + return null + } + + if (ret === constants.ERROR.PAUSED || ret === constants.ERROR.PAUSED_UPGRADE) { + this.paused = true + return null + } + + return this.createError(ret, EMPTY_BUF) + } + + createError (ret, data) { + const { llhttp, contentLength, bytesRead } = this + + if (contentLength && bytesRead !== parseInt(contentLength, 10)) { + return new ResponseContentLengthMismatchError() + } + + const ptr = llhttp.llhttp_get_error_reason(this.ptr) + let message = '' + if (ptr) { + const len = new Uint8Array(llhttp.memory.buffer, ptr).indexOf(0) + message = + 'Response does not match the HTTP/1.1 protocol (' + + Buffer.from(llhttp.memory.buffer, ptr, len).toString() + + ')' + } + + return new HTTPParserError(message, constants.ERROR[ret], data) + } + destroy () { assert(this.ptr != null) assert(currentParser == null) @@ -28996,6 +29041,11 @@ class Parser { return -1 } + if (client[kRunning] === 0) { + util.destroy(socket, new SocketError('bad response', util.getSocketInfo(socket))) + return -1 + } + const request = client[kQueue][client[kRunningIdx]] if (!request) { return -1 @@ -29099,6 +29149,11 @@ class Parser { return -1 } + if (client[kRunning] === 0) { + util.destroy(socket, new SocketError('bad response', util.getSocketInfo(socket))) + return -1 + } + const request = client[kQueue][client[kRunningIdx]] /* istanbul ignore next: difficult to make a test case for */ @@ -29272,6 +29327,7 @@ class Parser { request.onComplete(headers) client[kQueue][client[kRunningIdx]++] = null + socket[kSocketUsed] = true if (socket[kWriting]) { assert(client[kRunning] === 0) @@ -29330,6 +29386,9 @@ async function connectH1 (client, socket) { socket[kWriting] = false socket[kReset] = false socket[kBlocking] = false + socket[kIdleSocketValidation] = 0 + socket[kIdleSocketValidationTimeout] = null + socket[kSocketUsed] = false socket[kParser] = new Parser(client, socket, llhttpInstance) addListener(socket, 'error', function (err) { @@ -29340,8 +29399,11 @@ async function connectH1 (client, socket) { // On Mac OS, we get an ECONNRESET even if there is a full body to be forwarded // to the user. if (err.code === 'ECONNRESET' && parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so for as a valid response. - parser.onMessageComplete() + const parserErr = parser.finish() + if (parserErr) { + this[kError] = parserErr + this[kClient][kOnError](parserErr) + } return } @@ -29360,8 +29422,10 @@ async function connectH1 (client, socket) { const parser = this[kParser] if (parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so far as a valid response. - parser.onMessageComplete() + const parserErr = parser.finish() + if (parserErr) { + util.destroy(this, parserErr) + } return } @@ -29371,10 +29435,11 @@ async function connectH1 (client, socket) { const client = this[kClient] const parser = this[kParser] + clearIdleSocketValidation(this) + if (parser) { if (!this[kError] && parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so far as a valid response. - parser.onMessageComplete() + this[kError] = parser.finish() || this[kError] } this[kParser].destroy() @@ -29437,7 +29502,7 @@ async function connectH1 (client, socket) { return socket.destroyed }, busy (request) { - if (socket[kWriting] || socket[kReset] || socket[kBlocking]) { + if (socket[kWriting] || socket[kReset] || socket[kBlocking] || socket[kIdleSocketValidation] === 1) { return true } @@ -29475,6 +29540,31 @@ async function connectH1 (client, socket) { } } +function clearIdleSocketValidation (socket) { + if (socket[kIdleSocketValidationTimeout]) { + clearTimeout(socket[kIdleSocketValidationTimeout]) + socket[kIdleSocketValidationTimeout] = null + } + + socket[kIdleSocketValidation] = 0 +} + +function scheduleIdleSocketValidation (client, socket) { + socket[kIdleSocketValidation] = 1 + socket[kIdleSocketValidationTimeout] = setTimeout(() => { + socket[kIdleSocketValidationTimeout] = null + socket[kIdleSocketValidation] = 2 + + if (client[kSocket] === socket && !socket.destroyed) { + client[kResume]() + } + }, 0) + socket[kIdleSocketValidationTimeout].unref?.() +} + +/** + * @param {import('./client.js')} client + */ function resumeH1 (client) { const socket = client[kSocket] @@ -29489,6 +29579,32 @@ function resumeH1 (client) { socket[kNoRef] = false } + if (client[kRunning] === 0 && client[kPending] > 0 && socket[kSocketUsed]) { + if (socket[kIdleSocketValidation] === 0) { + scheduleIdleSocketValidation(client, socket) + socket[kParser].readMore() + if (socket.destroyed) { + return + } + return + } + + if (socket[kIdleSocketValidation] === 1) { + socket[kParser].readMore() + if (socket.destroyed) { + return + } + return + } + } + + if (client[kRunning] === 0) { + socket[kParser].readMore() + if (socket.destroyed) { + return + } + } + if (client[kSize] === 0) { if (socket[kParser].timeoutType !== TIMEOUT_KEEP_ALIVE) { socket[kParser].setTimeout(client[kKeepAliveTimeoutValue], TIMEOUT_KEEP_ALIVE) @@ -29582,6 +29698,7 @@ function writeH1 (client, request) { } const socket = client[kSocket] + clearIdleSocketValidation(socket) const abort = (err) => { if (request.aborted || request.completed) { @@ -30903,9 +31020,10 @@ class Client extends DispatcherBase { autoSelectFamilyAttemptTimeout, // h2 maxConcurrentStreams, - allowH2 + allowH2, + webSocket } = {}) { - super() + super({ webSocket }) if (keepAlive !== undefined) { throw new InvalidArgumentError('unsupported keepAlive, use pipelining=0 instead') @@ -31438,15 +31556,24 @@ const { kDestroy, kClose, kClosed, kDestroyed, kDispatch, kInterceptors } = __nc const kOnDestroyed = Symbol('onDestroyed') const kOnClosed = Symbol('onClosed') const kInterceptedDispatch = Symbol('Intercepted Dispatch') +const kWebSocketOptions = Symbol('webSocketOptions') class DispatcherBase extends Dispatcher { - constructor () { + constructor (opts) { super() this[kDestroyed] = false this[kOnDestroyed] = null this[kClosed] = false this[kOnClosed] = [] + this[kWebSocketOptions] = opts?.webSocket ?? {} + } + + get webSocketOptions () { + return { + maxFragments: this[kWebSocketOptions].maxFragments ?? 131072, + maxPayloadSize: this[kWebSocketOptions].maxPayloadSize ?? 128 * 1024 * 1024 + } } get destroyed () { @@ -32010,8 +32137,8 @@ const kRemoveClient = Symbol('remove client') const kStats = Symbol('stats') class PoolBase extends DispatcherBase { - constructor () { - super() + constructor (opts) { + super(opts) this[kQueue] = new FixedQueue() this[kClients] = [] @@ -32271,8 +32398,6 @@ class Pool extends PoolBase { allowH2, ...options } = {}) { - super() - if (connections != null && (!Number.isFinite(connections) || connections < 0)) { throw new InvalidArgumentError('invalid connections') } @@ -32297,6 +32422,8 @@ class Pool extends PoolBase { }) } + super(options) + this[kInterceptors] = options.interceptors?.Pool && Array.isArray(options.interceptors.Pool) ? options.interceptors.Pool : [] @@ -37381,32 +37508,25 @@ function parseUnparsedAttributes (unparsedAttributes, cookieAttributeList = {}) // If the attribute-name case-insensitively matches the string // "SameSite", the user agent MUST process the cookie-av as follows: - // 1. Let enforcement be "Default". - let enforcement = 'Default' - const attributeValueLowercase = attributeValue.toLowerCase() - // 2. If cookie-av's attribute-value is a case-insensitive match for - // "None", set enforcement to "None". - if (attributeValueLowercase.includes('none')) { - enforcement = 'None' - } - // 3. If cookie-av's attribute-value is a case-insensitive match for - // "Strict", set enforcement to "Strict". - if (attributeValueLowercase.includes('strict')) { - enforcement = 'Strict' + // 1. If cookie-av's attribute-value is a case-insensitive match for + // "None", append an attribute to the cookie-attribute-list with an + // attribute-name of "SameSite" and an attribute-value of "None". + if (attributeValueLowercase === 'none') { + cookieAttributeList.sameSite = 'None' + } else if (attributeValueLowercase === 'strict') { + // 2. If cookie-av's attribute-value is a case-insensitive match for + // "Strict", append an attribute to the cookie-attribute-list with + // an attribute-name of "SameSite" and an attribute-value of + // "Strict". + cookieAttributeList.sameSite = 'Strict' + } else if (attributeValueLowercase === 'lax') { + // 3. If cookie-av's attribute-value is a case-insensitive match for + // "Lax", append an attribute to the cookie-attribute-list with an + // attribute-name of "SameSite" and an attribute-value of "Lax". + cookieAttributeList.sameSite = 'Lax' } - - // 4. If cookie-av's attribute-value is a case-insensitive match for - // "Lax", set enforcement to "Lax". - if (attributeValueLowercase.includes('lax')) { - enforcement = 'Lax' - } - - // 5. Append an attribute to the cookie-attribute-list with an - // attribute-name of "SameSite" and an attribute-value of - // enforcement. - cookieAttributeList.sameSite = enforcement } else { cookieAttributeList.unparsed ??= [] @@ -50112,40 +50232,35 @@ const tail = Buffer.from([0x00, 0x00, 0xff, 0xff]) const kBuffer = Symbol('kBuffer') const kLength = Symbol('kLength') -// Default maximum decompressed message size: 4 MB -const kDefaultMaxDecompressedSize = 4 * 1024 * 1024 - class PerMessageDeflate { /** @type {import('node:zlib').InflateRaw} */ #inflate #options = {} - /** @type {boolean} */ - #aborted = false - - /** @type {Function|null} */ - #currentCallback = null + #maxPayloadSize = 0 /** * @param {Map} extensions */ - constructor (extensions) { + constructor (extensions, options) { this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover') this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits') + + this.#maxPayloadSize = options.maxPayloadSize } + /** + * Decompress a compressed payload. + * @param {Buffer} chunk Compressed data + * @param {boolean} fin Final fragment flag + * @param {Function} callback Callback function + */ decompress (chunk, fin, callback) { // An endpoint uses the following algorithm to decompress a message. // 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the // payload of the message. // 2. Decompress the resulting data using DEFLATE. - - if (this.#aborted) { - callback(new MessageSizeExceededError()) - return - } - if (!this.#inflate) { let windowBits = Z_DEFAULT_WINDOWBITS @@ -50168,23 +50283,12 @@ class PerMessageDeflate { this.#inflate[kLength] = 0 this.#inflate.on('data', (data) => { - if (this.#aborted) { - return - } - this.#inflate[kLength] += data.length - if (this.#inflate[kLength] > kDefaultMaxDecompressedSize) { - this.#aborted = true + if (this.#maxPayloadSize > 0 && this.#inflate[kLength] > this.#maxPayloadSize) { + callback(new MessageSizeExceededError()) this.#inflate.removeAllListeners() - this.#inflate.destroy() this.#inflate = null - - if (this.#currentCallback) { - const cb = this.#currentCallback - this.#currentCallback = null - cb(new MessageSizeExceededError()) - } return } @@ -50197,14 +50301,13 @@ class PerMessageDeflate { }) } - this.#currentCallback = callback this.#inflate.write(chunk) if (fin) { this.#inflate.write(tail) } this.#inflate.flush(() => { - if (this.#aborted || !this.#inflate) { + if (!this.#inflate) { return } @@ -50212,7 +50315,6 @@ class PerMessageDeflate { this.#inflate[kBuffer].length = 0 this.#inflate[kLength] = 0 - this.#currentCallback = null callback(null, full) }) @@ -50248,6 +50350,12 @@ const { const { WebsocketFrameSend } = __nccwpck_require__(3264) const { closeWebSocketConnection } = __nccwpck_require__(86897) const { PerMessageDeflate } = __nccwpck_require__(19469) +const { MessageSizeExceededError } = __nccwpck_require__(68707) + +function failWebsocketConnectionWithCode (ws, code, reason) { + closeWebSocketConnection(ws, code, reason, Buffer.byteLength(reason)) + failWebsocketConnection(ws, reason) +} // This code was influenced by ws released under the MIT license. // Copyright (c) 2011 Einar Otto Stangvik @@ -50256,6 +50364,7 @@ const { PerMessageDeflate } = __nccwpck_require__(19469) class ByteParser extends Writable { #buffers = [] + #fragmentsBytes = 0 #byteOffset = 0 #loop = false @@ -50267,18 +50376,27 @@ class ByteParser extends Writable { /** @type {Map} */ #extensions + /** @type {number} */ + #maxFragments + + /** @type {number} */ + #maxPayloadSize + /** * @param {import('./websocket').WebSocket} ws * @param {Map|null} extensions + * @param {{ maxFragments?: number, maxPayloadSize?: number }} [options] */ - constructor (ws, extensions) { + constructor (ws, extensions, options = {}) { super() this.ws = ws this.#extensions = extensions == null ? new Map() : extensions + this.#maxFragments = options.maxFragments ?? 0 + this.#maxPayloadSize = options.maxPayloadSize ?? 0 if (this.#extensions.has('permessage-deflate')) { - this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions)) + this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions, options)) } } @@ -50294,6 +50412,19 @@ class ByteParser extends Writable { this.run(callback) } + #validatePayloadLength () { + if ( + this.#maxPayloadSize > 0 && + !isControlFrame(this.#info.opcode) && + this.#info.payloadLength + this.#fragmentsBytes > this.#maxPayloadSize + ) { + failWebsocketConnectionWithCode(this.ws, 1009, 'Payload size exceeds maximum allowed size') + return false + } + + return true + } + /** * Runs whenever a new chunk is received. * Callback is called whenever there are no more chunks buffering, @@ -50382,6 +50513,10 @@ class ByteParser extends Writable { if (payloadLength <= 125) { this.#info.payloadLength = payloadLength this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (payloadLength === 126) { this.#state = parserStates.PAYLOADLENGTH_16 } else if (payloadLength === 127) { @@ -50406,6 +50541,10 @@ class ByteParser extends Writable { this.#info.payloadLength = buffer.readUInt16BE(0) this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (this.#state === parserStates.PAYLOADLENGTH_64) { if (this.#byteOffset < 8) { return callback() @@ -50428,6 +50567,10 @@ class ByteParser extends Writable { this.#info.payloadLength = lower this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (this.#state === parserStates.READ_DATA) { if (this.#byteOffset < this.#info.payloadLength) { return callback() @@ -50440,42 +50583,58 @@ class ByteParser extends Writable { this.#state = parserStates.INFO } else { if (!this.#info.compressed) { - this.#fragments.push(body) + if (!this.writeFragments(body)) { + return + } + + if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) { + failWebsocketConnectionWithCode(this.ws, 1009, new MessageSizeExceededError().message) + return + } // If the frame is not fragmented, a message has been received. // If the frame is fragmented, it will terminate with a fin bit set // and an opcode of 0 (continuation), therefore we handle that when // parsing continuation frames, not here. if (!this.#info.fragmented && this.#info.fin) { - const fullMessage = Buffer.concat(this.#fragments) - websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage) - this.#fragments.length = 0 + websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments()) } this.#state = parserStates.INFO } else { - this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => { - if (error) { - failWebsocketConnection(this.ws, error.message) - return - } + this.#extensions.get('permessage-deflate').decompress( + body, + this.#info.fin, + (error, data) => { + if (error) { + const code = error instanceof MessageSizeExceededError ? 1009 : 1007 + failWebsocketConnectionWithCode(this.ws, code, error.message) + return + } - this.#fragments.push(data) + if (!this.writeFragments(data)) { + return + } + + if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) { + failWebsocketConnectionWithCode(this.ws, 1009, new MessageSizeExceededError().message) + return + } + + if (!this.#info.fin) { + this.#state = parserStates.INFO + this.#loop = true + this.run(callback) + return + } + + websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments()) - if (!this.#info.fin) { - this.#state = parserStates.INFO this.#loop = true + this.#state = parserStates.INFO this.run(callback) - return } - - websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments)) - - this.#loop = true - this.#state = parserStates.INFO - this.#fragments.length = 0 - this.run(callback) - }) + ) this.#loop = false break @@ -50527,6 +50686,35 @@ class ByteParser extends Writable { return buffer } + writeFragments (fragment) { + if ( + this.#maxFragments > 0 && + this.#fragments.length === this.#maxFragments + ) { + failWebsocketConnectionWithCode(this.ws, 1008, 'Too many message fragments') + return false + } + + this.#fragmentsBytes += fragment.length + this.#fragments.push(fragment) + return true + } + + consumeFragments () { + const fragments = this.#fragments + + if (fragments.length === 1) { + this.#fragmentsBytes = 0 + return fragments.shift() + } + + const output = Buffer.concat(fragments, this.#fragmentsBytes) + this.#fragments = [] + this.#fragmentsBytes = 0 + + return output + } + parseCloseBody (data) { assert(data.length !== 1) @@ -51562,7 +51750,14 @@ class WebSocket extends EventTarget { // once this happens, the connection is open this[kResponse] = response - const parser = new ByteParser(this, parsedExtensions) + const webSocketOptions = this[kController]?.dispatcher?.webSocketOptions + const maxFragments = webSocketOptions?.maxFragments + const maxPayloadSize = webSocketOptions?.maxPayloadSize + + const parser = new ByteParser(this, parsedExtensions, { + maxFragments, + maxPayloadSize + }) parser.on('drain', onParserDrain) parser.on('error', onParserError.bind(this)) diff --git a/dist/setup/index.js b/dist/setup/index.js index 43403904..33056027 100644 --- a/dist/setup/index.js +++ b/dist/setup/index.js @@ -54063,8 +54063,6 @@ function defaultFactory (origin, opts) { class Agent extends DispatcherBase { constructor ({ factory = defaultFactory, maxRedirections = 0, connect, ...options } = {}) { - super() - if (typeof factory !== 'function') { throw new InvalidArgumentError('factory must be a function.') } @@ -54077,6 +54075,8 @@ class Agent extends DispatcherBase { throw new InvalidArgumentError('maxRedirections must be a positive number') } + super(options) + if (connect && typeof connect !== 'function') { connect = { ...connect } } @@ -54450,6 +54450,9 @@ const EMPTY_BUF = Buffer.alloc(0) const FastBuffer = Buffer[Symbol.species] const addListener = util.addListener const removeAllListeners = util.removeAllListeners +const kIdleSocketValidation = Symbol('kIdleSocketValidation') +const kIdleSocketValidationTimeout = Symbol('kIdleSocketValidationTimeout') +const kSocketUsed = Symbol('kSocketUsed') let extractBody @@ -54672,29 +54675,71 @@ class Parser { const offset = llhttp.llhttp_get_error_pos(this.ptr) - currentBufferPtr - if (ret === constants.ERROR.PAUSED_UPGRADE) { - this.onUpgrade(data.slice(offset)) - } else if (ret === constants.ERROR.PAUSED) { - this.paused = true - socket.unshift(data.slice(offset)) - } else if (ret !== constants.ERROR.OK) { - const ptr = llhttp.llhttp_get_error_reason(this.ptr) - let message = '' - /* istanbul ignore else: difficult to make a test case for */ - if (ptr) { - const len = new Uint8Array(llhttp.memory.buffer, ptr).indexOf(0) - message = - 'Response does not match the HTTP/1.1 protocol (' + - Buffer.from(llhttp.memory.buffer, ptr, len).toString() + - ')' + if (ret !== constants.ERROR.OK) { + const body = data.subarray(offset) + + if (ret === constants.ERROR.PAUSED_UPGRADE) { + this.onUpgrade(body) + } else if (ret === constants.ERROR.PAUSED) { + this.paused = true + socket.unshift(body) + } else { + throw this.createError(ret, body) } - throw new HTTPParserError(message, constants.ERROR[ret], data.slice(offset)) } } catch (err) { util.destroy(socket, err) } } + finish () { + assert(currentParser === null) + assert(this.ptr != null) + assert(!this.paused) + + const { llhttp } = this + + let ret + + try { + currentParser = this + ret = llhttp.llhttp_finish(this.ptr) + } finally { + currentParser = null + } + + if (ret === constants.ERROR.OK) { + return null + } + + if (ret === constants.ERROR.PAUSED || ret === constants.ERROR.PAUSED_UPGRADE) { + this.paused = true + return null + } + + return this.createError(ret, EMPTY_BUF) + } + + createError (ret, data) { + const { llhttp, contentLength, bytesRead } = this + + if (contentLength && bytesRead !== parseInt(contentLength, 10)) { + return new ResponseContentLengthMismatchError() + } + + const ptr = llhttp.llhttp_get_error_reason(this.ptr) + let message = '' + if (ptr) { + const len = new Uint8Array(llhttp.memory.buffer, ptr).indexOf(0) + message = + 'Response does not match the HTTP/1.1 protocol (' + + Buffer.from(llhttp.memory.buffer, ptr, len).toString() + + ')' + } + + return new HTTPParserError(message, constants.ERROR[ret], data) + } + destroy () { assert(this.ptr != null) assert(currentParser == null) @@ -54722,6 +54767,11 @@ class Parser { return -1 } + if (client[kRunning] === 0) { + util.destroy(socket, new SocketError('bad response', util.getSocketInfo(socket))) + return -1 + } + const request = client[kQueue][client[kRunningIdx]] if (!request) { return -1 @@ -54825,6 +54875,11 @@ class Parser { return -1 } + if (client[kRunning] === 0) { + util.destroy(socket, new SocketError('bad response', util.getSocketInfo(socket))) + return -1 + } + const request = client[kQueue][client[kRunningIdx]] /* istanbul ignore next: difficult to make a test case for */ @@ -54998,6 +55053,7 @@ class Parser { request.onComplete(headers) client[kQueue][client[kRunningIdx]++] = null + socket[kSocketUsed] = true if (socket[kWriting]) { assert(client[kRunning] === 0) @@ -55056,6 +55112,9 @@ async function connectH1 (client, socket) { socket[kWriting] = false socket[kReset] = false socket[kBlocking] = false + socket[kIdleSocketValidation] = 0 + socket[kIdleSocketValidationTimeout] = null + socket[kSocketUsed] = false socket[kParser] = new Parser(client, socket, llhttpInstance) addListener(socket, 'error', function (err) { @@ -55066,8 +55125,11 @@ async function connectH1 (client, socket) { // On Mac OS, we get an ECONNRESET even if there is a full body to be forwarded // to the user. if (err.code === 'ECONNRESET' && parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so for as a valid response. - parser.onMessageComplete() + const parserErr = parser.finish() + if (parserErr) { + this[kError] = parserErr + this[kClient][kOnError](parserErr) + } return } @@ -55086,8 +55148,10 @@ async function connectH1 (client, socket) { const parser = this[kParser] if (parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so far as a valid response. - parser.onMessageComplete() + const parserErr = parser.finish() + if (parserErr) { + util.destroy(this, parserErr) + } return } @@ -55097,10 +55161,11 @@ async function connectH1 (client, socket) { const client = this[kClient] const parser = this[kParser] + clearIdleSocketValidation(this) + if (parser) { if (!this[kError] && parser.statusCode && !parser.shouldKeepAlive) { - // We treat all incoming data so far as a valid response. - parser.onMessageComplete() + this[kError] = parser.finish() || this[kError] } this[kParser].destroy() @@ -55163,7 +55228,7 @@ async function connectH1 (client, socket) { return socket.destroyed }, busy (request) { - if (socket[kWriting] || socket[kReset] || socket[kBlocking]) { + if (socket[kWriting] || socket[kReset] || socket[kBlocking] || socket[kIdleSocketValidation] === 1) { return true } @@ -55201,6 +55266,31 @@ async function connectH1 (client, socket) { } } +function clearIdleSocketValidation (socket) { + if (socket[kIdleSocketValidationTimeout]) { + clearTimeout(socket[kIdleSocketValidationTimeout]) + socket[kIdleSocketValidationTimeout] = null + } + + socket[kIdleSocketValidation] = 0 +} + +function scheduleIdleSocketValidation (client, socket) { + socket[kIdleSocketValidation] = 1 + socket[kIdleSocketValidationTimeout] = setTimeout(() => { + socket[kIdleSocketValidationTimeout] = null + socket[kIdleSocketValidation] = 2 + + if (client[kSocket] === socket && !socket.destroyed) { + client[kResume]() + } + }, 0) + socket[kIdleSocketValidationTimeout].unref?.() +} + +/** + * @param {import('./client.js')} client + */ function resumeH1 (client) { const socket = client[kSocket] @@ -55215,6 +55305,32 @@ function resumeH1 (client) { socket[kNoRef] = false } + if (client[kRunning] === 0 && client[kPending] > 0 && socket[kSocketUsed]) { + if (socket[kIdleSocketValidation] === 0) { + scheduleIdleSocketValidation(client, socket) + socket[kParser].readMore() + if (socket.destroyed) { + return + } + return + } + + if (socket[kIdleSocketValidation] === 1) { + socket[kParser].readMore() + if (socket.destroyed) { + return + } + return + } + } + + if (client[kRunning] === 0) { + socket[kParser].readMore() + if (socket.destroyed) { + return + } + } + if (client[kSize] === 0) { if (socket[kParser].timeoutType !== TIMEOUT_KEEP_ALIVE) { socket[kParser].setTimeout(client[kKeepAliveTimeoutValue], TIMEOUT_KEEP_ALIVE) @@ -55308,6 +55424,7 @@ function writeH1 (client, request) { } const socket = client[kSocket] + clearIdleSocketValidation(socket) const abort = (err) => { if (request.aborted || request.completed) { @@ -56629,9 +56746,10 @@ class Client extends DispatcherBase { autoSelectFamilyAttemptTimeout, // h2 maxConcurrentStreams, - allowH2 + allowH2, + webSocket } = {}) { - super() + super({ webSocket }) if (keepAlive !== undefined) { throw new InvalidArgumentError('unsupported keepAlive, use pipelining=0 instead') @@ -57164,15 +57282,24 @@ const { kDestroy, kClose, kClosed, kDestroyed, kDispatch, kInterceptors } = __nc const kOnDestroyed = Symbol('onDestroyed') const kOnClosed = Symbol('onClosed') const kInterceptedDispatch = Symbol('Intercepted Dispatch') +const kWebSocketOptions = Symbol('webSocketOptions') class DispatcherBase extends Dispatcher { - constructor () { + constructor (opts) { super() this[kDestroyed] = false this[kOnDestroyed] = null this[kClosed] = false this[kOnClosed] = [] + this[kWebSocketOptions] = opts?.webSocket ?? {} + } + + get webSocketOptions () { + return { + maxFragments: this[kWebSocketOptions].maxFragments ?? 131072, + maxPayloadSize: this[kWebSocketOptions].maxPayloadSize ?? 128 * 1024 * 1024 + } } get destroyed () { @@ -57736,8 +57863,8 @@ const kRemoveClient = Symbol('remove client') const kStats = Symbol('stats') class PoolBase extends DispatcherBase { - constructor () { - super() + constructor (opts) { + super(opts) this[kQueue] = new FixedQueue() this[kClients] = [] @@ -57997,8 +58124,6 @@ class Pool extends PoolBase { allowH2, ...options } = {}) { - super() - if (connections != null && (!Number.isFinite(connections) || connections < 0)) { throw new InvalidArgumentError('invalid connections') } @@ -58023,6 +58148,8 @@ class Pool extends PoolBase { }) } + super(options) + this[kInterceptors] = options.interceptors?.Pool && Array.isArray(options.interceptors.Pool) ? options.interceptors.Pool : [] @@ -63107,32 +63234,25 @@ function parseUnparsedAttributes (unparsedAttributes, cookieAttributeList = {}) // If the attribute-name case-insensitively matches the string // "SameSite", the user agent MUST process the cookie-av as follows: - // 1. Let enforcement be "Default". - let enforcement = 'Default' - const attributeValueLowercase = attributeValue.toLowerCase() - // 2. If cookie-av's attribute-value is a case-insensitive match for - // "None", set enforcement to "None". - if (attributeValueLowercase.includes('none')) { - enforcement = 'None' - } - // 3. If cookie-av's attribute-value is a case-insensitive match for - // "Strict", set enforcement to "Strict". - if (attributeValueLowercase.includes('strict')) { - enforcement = 'Strict' + // 1. If cookie-av's attribute-value is a case-insensitive match for + // "None", append an attribute to the cookie-attribute-list with an + // attribute-name of "SameSite" and an attribute-value of "None". + if (attributeValueLowercase === 'none') { + cookieAttributeList.sameSite = 'None' + } else if (attributeValueLowercase === 'strict') { + // 2. If cookie-av's attribute-value is a case-insensitive match for + // "Strict", append an attribute to the cookie-attribute-list with + // an attribute-name of "SameSite" and an attribute-value of + // "Strict". + cookieAttributeList.sameSite = 'Strict' + } else if (attributeValueLowercase === 'lax') { + // 3. If cookie-av's attribute-value is a case-insensitive match for + // "Lax", append an attribute to the cookie-attribute-list with an + // attribute-name of "SameSite" and an attribute-value of "Lax". + cookieAttributeList.sameSite = 'Lax' } - - // 4. If cookie-av's attribute-value is a case-insensitive match for - // "Lax", set enforcement to "Lax". - if (attributeValueLowercase.includes('lax')) { - enforcement = 'Lax' - } - - // 5. Append an attribute to the cookie-attribute-list with an - // attribute-name of "SameSite" and an attribute-value of - // enforcement. - cookieAttributeList.sameSite = enforcement } else { cookieAttributeList.unparsed ??= [] @@ -75838,40 +75958,35 @@ const tail = Buffer.from([0x00, 0x00, 0xff, 0xff]) const kBuffer = Symbol('kBuffer') const kLength = Symbol('kLength') -// Default maximum decompressed message size: 4 MB -const kDefaultMaxDecompressedSize = 4 * 1024 * 1024 - class PerMessageDeflate { /** @type {import('node:zlib').InflateRaw} */ #inflate #options = {} - /** @type {boolean} */ - #aborted = false - - /** @type {Function|null} */ - #currentCallback = null + #maxPayloadSize = 0 /** * @param {Map} extensions */ - constructor (extensions) { + constructor (extensions, options) { this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover') this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits') + + this.#maxPayloadSize = options.maxPayloadSize } + /** + * Decompress a compressed payload. + * @param {Buffer} chunk Compressed data + * @param {boolean} fin Final fragment flag + * @param {Function} callback Callback function + */ decompress (chunk, fin, callback) { // An endpoint uses the following algorithm to decompress a message. // 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the // payload of the message. // 2. Decompress the resulting data using DEFLATE. - - if (this.#aborted) { - callback(new MessageSizeExceededError()) - return - } - if (!this.#inflate) { let windowBits = Z_DEFAULT_WINDOWBITS @@ -75894,23 +76009,12 @@ class PerMessageDeflate { this.#inflate[kLength] = 0 this.#inflate.on('data', (data) => { - if (this.#aborted) { - return - } - this.#inflate[kLength] += data.length - if (this.#inflate[kLength] > kDefaultMaxDecompressedSize) { - this.#aborted = true + if (this.#maxPayloadSize > 0 && this.#inflate[kLength] > this.#maxPayloadSize) { + callback(new MessageSizeExceededError()) this.#inflate.removeAllListeners() - this.#inflate.destroy() this.#inflate = null - - if (this.#currentCallback) { - const cb = this.#currentCallback - this.#currentCallback = null - cb(new MessageSizeExceededError()) - } return } @@ -75923,14 +76027,13 @@ class PerMessageDeflate { }) } - this.#currentCallback = callback this.#inflate.write(chunk) if (fin) { this.#inflate.write(tail) } this.#inflate.flush(() => { - if (this.#aborted || !this.#inflate) { + if (!this.#inflate) { return } @@ -75938,7 +76041,6 @@ class PerMessageDeflate { this.#inflate[kBuffer].length = 0 this.#inflate[kLength] = 0 - this.#currentCallback = null callback(null, full) }) @@ -75974,6 +76076,12 @@ const { const { WebsocketFrameSend } = __nccwpck_require__(3264) const { closeWebSocketConnection } = __nccwpck_require__(86897) const { PerMessageDeflate } = __nccwpck_require__(19469) +const { MessageSizeExceededError } = __nccwpck_require__(68707) + +function failWebsocketConnectionWithCode (ws, code, reason) { + closeWebSocketConnection(ws, code, reason, Buffer.byteLength(reason)) + failWebsocketConnection(ws, reason) +} // This code was influenced by ws released under the MIT license. // Copyright (c) 2011 Einar Otto Stangvik @@ -75982,6 +76090,7 @@ const { PerMessageDeflate } = __nccwpck_require__(19469) class ByteParser extends Writable { #buffers = [] + #fragmentsBytes = 0 #byteOffset = 0 #loop = false @@ -75993,18 +76102,27 @@ class ByteParser extends Writable { /** @type {Map} */ #extensions + /** @type {number} */ + #maxFragments + + /** @type {number} */ + #maxPayloadSize + /** * @param {import('./websocket').WebSocket} ws * @param {Map|null} extensions + * @param {{ maxFragments?: number, maxPayloadSize?: number }} [options] */ - constructor (ws, extensions) { + constructor (ws, extensions, options = {}) { super() this.ws = ws this.#extensions = extensions == null ? new Map() : extensions + this.#maxFragments = options.maxFragments ?? 0 + this.#maxPayloadSize = options.maxPayloadSize ?? 0 if (this.#extensions.has('permessage-deflate')) { - this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions)) + this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions, options)) } } @@ -76020,6 +76138,19 @@ class ByteParser extends Writable { this.run(callback) } + #validatePayloadLength () { + if ( + this.#maxPayloadSize > 0 && + !isControlFrame(this.#info.opcode) && + this.#info.payloadLength + this.#fragmentsBytes > this.#maxPayloadSize + ) { + failWebsocketConnectionWithCode(this.ws, 1009, 'Payload size exceeds maximum allowed size') + return false + } + + return true + } + /** * Runs whenever a new chunk is received. * Callback is called whenever there are no more chunks buffering, @@ -76108,6 +76239,10 @@ class ByteParser extends Writable { if (payloadLength <= 125) { this.#info.payloadLength = payloadLength this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (payloadLength === 126) { this.#state = parserStates.PAYLOADLENGTH_16 } else if (payloadLength === 127) { @@ -76132,6 +76267,10 @@ class ByteParser extends Writable { this.#info.payloadLength = buffer.readUInt16BE(0) this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (this.#state === parserStates.PAYLOADLENGTH_64) { if (this.#byteOffset < 8) { return callback() @@ -76154,6 +76293,10 @@ class ByteParser extends Writable { this.#info.payloadLength = lower this.#state = parserStates.READ_DATA + + if (!this.#validatePayloadLength()) { + return + } } else if (this.#state === parserStates.READ_DATA) { if (this.#byteOffset < this.#info.payloadLength) { return callback() @@ -76166,42 +76309,58 @@ class ByteParser extends Writable { this.#state = parserStates.INFO } else { if (!this.#info.compressed) { - this.#fragments.push(body) + if (!this.writeFragments(body)) { + return + } + + if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) { + failWebsocketConnectionWithCode(this.ws, 1009, new MessageSizeExceededError().message) + return + } // If the frame is not fragmented, a message has been received. // If the frame is fragmented, it will terminate with a fin bit set // and an opcode of 0 (continuation), therefore we handle that when // parsing continuation frames, not here. if (!this.#info.fragmented && this.#info.fin) { - const fullMessage = Buffer.concat(this.#fragments) - websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage) - this.#fragments.length = 0 + websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments()) } this.#state = parserStates.INFO } else { - this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => { - if (error) { - failWebsocketConnection(this.ws, error.message) - return - } + this.#extensions.get('permessage-deflate').decompress( + body, + this.#info.fin, + (error, data) => { + if (error) { + const code = error instanceof MessageSizeExceededError ? 1009 : 1007 + failWebsocketConnectionWithCode(this.ws, code, error.message) + return + } - this.#fragments.push(data) + if (!this.writeFragments(data)) { + return + } + + if (this.#maxPayloadSize > 0 && this.#fragmentsBytes > this.#maxPayloadSize) { + failWebsocketConnectionWithCode(this.ws, 1009, new MessageSizeExceededError().message) + return + } + + if (!this.#info.fin) { + this.#state = parserStates.INFO + this.#loop = true + this.run(callback) + return + } + + websocketMessageReceived(this.ws, this.#info.binaryType, this.consumeFragments()) - if (!this.#info.fin) { - this.#state = parserStates.INFO this.#loop = true + this.#state = parserStates.INFO this.run(callback) - return } - - websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments)) - - this.#loop = true - this.#state = parserStates.INFO - this.#fragments.length = 0 - this.run(callback) - }) + ) this.#loop = false break @@ -76253,6 +76412,35 @@ class ByteParser extends Writable { return buffer } + writeFragments (fragment) { + if ( + this.#maxFragments > 0 && + this.#fragments.length === this.#maxFragments + ) { + failWebsocketConnectionWithCode(this.ws, 1008, 'Too many message fragments') + return false + } + + this.#fragmentsBytes += fragment.length + this.#fragments.push(fragment) + return true + } + + consumeFragments () { + const fragments = this.#fragments + + if (fragments.length === 1) { + this.#fragmentsBytes = 0 + return fragments.shift() + } + + const output = Buffer.concat(fragments, this.#fragmentsBytes) + this.#fragments = [] + this.#fragmentsBytes = 0 + + return output + } + parseCloseBody (data) { assert(data.length !== 1) @@ -77288,7 +77476,14 @@ class WebSocket extends EventTarget { // once this happens, the connection is open this[kResponse] = response - const parser = new ByteParser(this, parsedExtensions) + const webSocketOptions = this[kController]?.dispatcher?.webSocketOptions + const maxFragments = webSocketOptions?.maxFragments + const maxPayloadSize = webSocketOptions?.maxPayloadSize + + const parser = new ByteParser(this, parsedExtensions, { + maxFragments, + maxPayloadSize + }) parser.on('drain', onParserDrain) parser.on('error', onParserError.bind(this)) diff --git a/docs/advanced-usage.md b/docs/advanced-usage.md index 1b1e4fee..34ebaaf5 100644 --- a/docs/advanced-usage.md +++ b/docs/advanced-usage.md @@ -10,6 +10,7 @@ - [Alibaba Dragonwell](#Alibaba-Dragonwell) - [SapMachine](#SapMachine) - [GraalVM](#GraalVM) + - [GraalVM Community](#GraalVM-Community) - [JetBrains](#JetBrains) - [Installing custom Java package type](#Installing-custom-Java-package-type) - [Installing custom Java architecture](#Installing-custom-Java-architecture) @@ -172,6 +173,21 @@ steps: native-image -cp java HelloWorldApp ``` +### GraalVM Community +**NOTE:** GraalVM Community is available for stable JDK 17 and later releases. + +```yaml +steps: +- uses: actions/checkout@v6 +- uses: actions/setup-java@v5 + with: + distribution: 'graalvm-community' + java-version: '21' +- run: | + java -cp java HelloWorldApp + native-image -cp java HelloWorldApp +``` + ### JetBrains **NOTE:** JetBrains is only available for LTS versions on 11 or later (11, 17, 21, etc.). diff --git a/src/distributions/distribution-factory.ts b/src/distributions/distribution-factory.ts index 9cd459e6..0ff6597e 100644 --- a/src/distributions/distribution-factory.ts +++ b/src/distributions/distribution-factory.ts @@ -11,7 +11,10 @@ import {CorrettoDistribution} from './corretto/installer'; import {OracleDistribution} from './oracle/installer'; import {DragonwellDistribution} from './dragonwell/installer'; import {SapMachineDistribution} from './sapmachine/installer'; -import {GraalVMDistribution} from './graalvm/installer'; +import { + GraalVMCommunityDistribution, + GraalVMDistribution +} from './graalvm/installer'; import {JetBrainsDistribution} from './jetbrains/installer'; enum JavaDistribution { @@ -29,6 +32,7 @@ enum JavaDistribution { Dragonwell = 'dragonwell', SapMachine = 'sapmachine', GraalVM = 'graalvm', + GraalVMCommunity = 'graalvm-community', JetBrains = 'jetbrains' } @@ -74,6 +78,8 @@ export function getJavaDistribution( return new SapMachineDistribution(installerOptions); case JavaDistribution.GraalVM: return new GraalVMDistribution(installerOptions); + case JavaDistribution.GraalVMCommunity: + return new GraalVMCommunityDistribution(installerOptions); case JavaDistribution.JetBrains: return new JetBrainsDistribution(installerOptions); default: diff --git a/src/distributions/graalvm/installer.ts b/src/distributions/graalvm/installer.ts index fea3b8f1..e66f3242 100644 --- a/src/distributions/graalvm/installer.ts +++ b/src/distributions/graalvm/installer.ts @@ -2,6 +2,7 @@ import * as core from '@actions/core'; import * as tc from '@actions/tool-cache'; import fs from 'fs'; import path from 'path'; +import semver from 'semver'; import {JavaBase} from '../base-installer'; import {HttpCodes} from '@actions/http-client'; import {GraalVMEAVersion} from './models'; @@ -11,14 +12,24 @@ import { JavaInstallerResults } from '../base-models'; import { + convertVersionToSemver, extractJdkFile, getDownloadArchiveExtension, getGitHubHttpHeaders, - renameWinArchive + getNextPageUrlFromLinkHeader, + isVersionSatisfies, + MAX_PAGINATION_PAGES, + renameWinArchive, + validatePaginationUrl } from '../../util'; const GRAALVM_DL_BASE = 'https://download.oracle.com/graalvm'; const GRAALVM_DOWNLOAD_URL = 'https://www.graalvm.org/downloads/'; +const GRAALVM_COMMUNITY_RELEASES_URL = + 'https://api.github.com/repos/graalvm/graalvm-ce-builds/releases?per_page=100'; +const GRAALVM_COMMUNITY_RELEASES_PAGE_ORIGIN = 'https://api.github.com'; +const GRAALVM_COMMUNITY_DOWNLOAD_URL = + 'https://github.com/graalvm/graalvm-ce-builds/releases'; const IS_WINDOWS = process.platform === 'win32'; const GRAALVM_PLATFORM = IS_WINDOWS ? 'windows' : process.platform; const GRAALVM_MIN_VERSION = 17; @@ -26,9 +37,23 @@ const SUPPORTED_ARCHITECTURES = ['x64', 'aarch64'] as const; type SupportedArchitecture = (typeof SUPPORTED_ARCHITECTURES)[number]; type OsVersions = 'linux' | 'macos' | 'windows'; +interface GraalVMCommunityAsset { + name: string; + browser_download_url: string; +} + +interface GraalVMCommunityRelease { + draft: boolean; + prerelease: boolean; + assets: GraalVMCommunityAsset[]; +} + export class GraalVMDistribution extends JavaBase { - constructor(installerOptions: JavaInstallerOptions) { - super('GraalVM', installerOptions); + constructor( + installerOptions: JavaInstallerOptions, + distributionName = 'GraalVM' + ) { + super(distributionName, installerOptions); } protected async downloadTool( @@ -85,40 +110,14 @@ export class GraalVMDistribution extends JavaBase { protected async findPackageForDownload( range: string ): Promise { - // Add input validation - if (!range || typeof range !== 'string') { - throw new Error('Version range is required and must be a string'); - } - - const arch = this.distributionArchitecture(); - if (!SUPPORTED_ARCHITECTURES.includes(arch as SupportedArchitecture)) { - throw new Error( - `Unsupported architecture: ${this.architecture}. Supported architectures are: ${SUPPORTED_ARCHITECTURES.join(', ')}` - ); - } + this.validateVersionRange(range); + const arch = this.getSupportedArchitecture(); if (!this.stable) { return this.findEABuildDownloadUrl(`${range}-ea`); } - if (this.packageType !== 'jdk') { - throw new Error('GraalVM provides only the `jdk` package type'); - } - - const platform = this.getPlatform(); - const extension = getDownloadArchiveExtension(); - const major = range.includes('.') ? range.split('.')[0] : range; - const majorVersion = parseInt(major); - - if (isNaN(majorVersion)) { - throw new Error(`Invalid version format: ${range}`); - } - - if (majorVersion < GRAALVM_MIN_VERSION) { - throw new Error( - `GraalVM is only supported for JDK ${GRAALVM_MIN_VERSION} and later. Requested version: ${major}` - ); - } + const {platform, extension, major} = this.validateStableBuildRequest(range); const fileUrl = this.constructFileUrl( range, @@ -134,6 +133,60 @@ export class GraalVMDistribution extends JavaBase { return {url: fileUrl, version: range}; } + protected validateVersionRange(range: string): void { + if (!range || typeof range !== 'string') { + throw new Error('Version range is required and must be a string'); + } + } + + protected getSupportedArchitecture(): SupportedArchitecture { + const arch = this.distributionArchitecture(); + if (!SUPPORTED_ARCHITECTURES.includes(arch as SupportedArchitecture)) { + throw new Error( + `Unsupported architecture: ${this.architecture}. Supported architectures are: ${SUPPORTED_ARCHITECTURES.join(', ')}` + ); + } + + return arch as SupportedArchitecture; + } + + protected validateStableBuildRequest(range: string): { + arch: SupportedArchitecture; + platform: OsVersions; + extension: string; + major: string; + } { + const arch = this.getSupportedArchitecture(); + + if (this.packageType !== 'jdk') { + throw new Error( + `${this.distribution} provides only the \`jdk\` package type` + ); + } + + const platform = this.getPlatform(); + const extension = getDownloadArchiveExtension(); + const major = range.includes('.') ? range.split('.')[0] : range; + const majorVersion = parseInt(major); + + if (isNaN(majorVersion)) { + throw new Error(`Invalid version format: ${range}`); + } + + if (majorVersion < GRAALVM_MIN_VERSION) { + throw new Error( + `${this.distribution} is only supported for JDK ${GRAALVM_MIN_VERSION} and later. Requested version: ${major}` + ); + } + + return { + arch, + platform, + extension, + major + }; + } + private constructFileUrl( range: string, major: string, @@ -280,3 +333,140 @@ export class GraalVMDistribution extends JavaBase { return result; } } + +export class GraalVMCommunityDistribution extends GraalVMDistribution { + constructor(installerOptions: JavaInstallerOptions) { + super(installerOptions, 'GraalVM Community'); + } + + protected get toolcacheFolderName(): string { + return `Java_GraalVM_Community_${this.packageType}`; + } + + protected async findPackageForDownload( + range: string + ): Promise { + this.validateVersionRange(range); + + if (!this.stable) { + throw new Error('GraalVM Community does not provide early access builds'); + } + + const {arch, platform, extension} = this.validateStableBuildRequest(range); + const availableVersions = await this.getAvailableVersionsForPlatform( + platform, + arch, + extension + ); + + const satisfiedVersion = availableVersions + .filter(item => isVersionSatisfies(range, item.version)) + .sort((a, b) => -semver.compareBuild(a.version, b.version))[0]; + + if (!satisfiedVersion) { + const error = this.createVersionNotFoundError( + range, + availableVersions.map(item => item.version), + `Platform: ${platform}` + ); + error.message += `\nPlease check if this version is available at ${GRAALVM_COMMUNITY_DOWNLOAD_URL}.`; + throw error; + } + + return satisfiedVersion; + } + + private async getAvailableVersionsForPlatform( + platform: OsVersions, + arch: SupportedArchitecture, + extension: string + ): Promise { + const headers = getGitHubHttpHeaders(); + const availableVersions = new Map(); + let releasesUrl: string | null = GRAALVM_COMMUNITY_RELEASES_URL; + let pageCount = 0; + + while (releasesUrl) { + pageCount++; + const response = await this.http.getJson( + releasesUrl, + headers + ); + + for (const release of response.result ?? []) { + if (release.draft || release.prerelease) { + continue; + } + + for (const asset of release.assets ?? []) { + const releaseForPlatform = this.toCommunityReleaseForPlatform( + asset, + platform, + arch, + extension + ); + if (releaseForPlatform) { + availableVersions.set( + releaseForPlatform.version, + releaseForPlatform + ); + } + } + } + + const nextUrl = getNextPageUrlFromLinkHeader(response.headers); + if ( + nextUrl && + !validatePaginationUrl(nextUrl, GRAALVM_COMMUNITY_RELEASES_PAGE_ORIGIN) + ) { + core.warning( + `Ignoring pagination link with unexpected origin: ${nextUrl}` + ); + break; + } + + releasesUrl = nextUrl; + + if (!response.result || response.result.length === 0) { + break; + } + + if (pageCount >= MAX_PAGINATION_PAGES) { + core.warning( + `Reached pagination safeguard limit (${MAX_PAGINATION_PAGES} pages) while listing GraalVM Community releases.` + ); + break; + } + } + + return [...availableVersions.values()]; + } + + private toCommunityReleaseForPlatform( + asset: GraalVMCommunityAsset, + platform: OsVersions, + arch: SupportedArchitecture, + extension: string + ): JavaDownloadRelease | null { + const match = asset.name.match( + /^graalvm-community-jdk-(?\d+(?:\.\d+)+)_(?linux|macos|windows)-(?x64|aarch64)_bin\.(?tar\.gz|zip)$/ + ); + + if (!match?.groups) { + return null; + } + + if ( + match.groups.platform !== platform || + match.groups.arch !== arch || + match.groups.extension !== extension + ) { + return null; + } + + return { + version: convertVersionToSemver(match.groups.version), + url: asset.browser_download_url + }; + } +}