diff --git a/__test__/unsafe-pr-checkout-helper.test.ts b/__test__/unsafe-pr-checkout-helper.test.ts index 9634618..9efa246 100644 --- a/__test__/unsafe-pr-checkout-helper.test.ts +++ b/__test__/unsafe-pr-checkout-helper.test.ts @@ -13,6 +13,7 @@ const PR_MERGE_SHA = '2222222222222222222222222222222222222222' const SAFE_BASE_SHA = '3333333333333333333333333333333333333333' const WORKFLOW_RUN_HEAD_COMMIT_SHA = '4444444444444444444444444444444444444444' const BASE_QUALIFIED_REPO = 'some-owner/some-repo' +const FORK_QUALIFIED_REPO = 'another-repo/fork' function setContext(eventName: string, payload: object): void { ;(github.context as {eventName: string}).eventName = eventName @@ -25,7 +26,7 @@ function forkPullRequestTargetPayload(): object { pull_request: { head: { sha: PR_HEAD_SHA, - repo: {id: FORK_REPO_ID} + repo: {id: FORK_REPO_ID, full_name: FORK_QUALIFIED_REPO} }, merge_commit_sha: PR_MERGE_SHA } @@ -38,7 +39,7 @@ function sameRepoPullRequestTargetPayload(): object { pull_request: { head: { sha: PR_HEAD_SHA, - repo: {id: BASE_REPO_ID} + repo: {id: BASE_REPO_ID, full_name: BASE_QUALIFIED_REPO} }, merge_commit_sha: PR_MERGE_SHA } @@ -51,7 +52,7 @@ function forkWorkflowRunPayload(): object { workflow_run: { event: 'pull_request', head_commit: {id: WORKFLOW_RUN_HEAD_COMMIT_SHA}, - head_repository: {id: FORK_REPO_ID} + head_repository: {id: FORK_REPO_ID, full_name: FORK_QUALIFIED_REPO} } } } @@ -164,7 +165,7 @@ describe('unsafe-pr-checkout-helper', () => { setContext('pull_request_target', forkPullRequestTargetPayload()) expect(() => assertSafePrCheckout({ - qualifiedRepository: 'attacker/fork', + qualifiedRepository: FORK_QUALIFIED_REPO, ref: 'refs/heads/main', commit: '', allowUnsafePrCheckout: false @@ -172,13 +173,25 @@ describe('unsafe-pr-checkout-helper', () => { ).toThrow() }) + it('allows pull_request_target checkout of an unrelated third-party repo', () => { + setContext('pull_request_target', forkPullRequestTargetPayload()) + expect(() => + assertSafePrCheckout({ + qualifiedRepository: 'some-other/unrelated', + ref: 'refs/heads/main', + commit: '', + allowUnsafePrCheckout: false + }) + ).not.toThrow() + }) + it('refuses pull_request_target ignoring repository case differences', () => { setContext('pull_request_target', forkPullRequestTargetPayload()) expect(() => assertSafePrCheckout({ - qualifiedRepository: 'SOME-OWNER/SOME-REPO', + qualifiedRepository: FORK_QUALIFIED_REPO.toUpperCase(), ref: '', - commit: PR_HEAD_SHA, + commit: '', allowUnsafePrCheckout: false }) ).toThrow() diff --git a/dist/index.js b/dist/index.js index cf9067d..0acfdf1 100644 --- a/dist/index.js +++ b/dist/index.js @@ -2793,9 +2793,11 @@ function assertSafePrCheckout(input) { return; } let prHeadRepoId; + let prHeadRepoFullName; const prShas = []; if (eventName === 'pull_request_target') { prHeadRepoId = (0, ref_helper_1.fromPayload)('pull_request.head.repo.id'); + prHeadRepoFullName = (0, ref_helper_1.fromPayload)('pull_request.head.repo.full_name'); pushIfSha(prShas, (0, ref_helper_1.fromPayload)('pull_request.head.sha')); pushIfSha(prShas, (0, ref_helper_1.fromPayload)('pull_request.merge_commit_sha')); } @@ -2805,7 +2807,13 @@ function assertSafePrCheckout(input) { return; } prHeadRepoId = (0, ref_helper_1.fromPayload)('workflow_run.head_repository.id'); + prHeadRepoFullName = (0, ref_helper_1.fromPayload)('workflow_run.head_repository.full_name'); pushIfSha(prShas, (0, ref_helper_1.fromPayload)('workflow_run.head_commit.id')); + // For `pull_request_target`-triggered workflow_run, `head_sha` is the base + // default branch SHA (not the PR head) + if (wrEvent !== 'pull_request_target') { + pushIfSha(prShas, (0, ref_helper_1.fromPayload)('workflow_run.head_sha')); + } } // (A) Fork PR? if (typeof prHeadRepoId !== 'number' || prHeadRepoId === baseRepoId) { @@ -2813,12 +2821,12 @@ function assertSafePrCheckout(input) { } // (B) We cannot check for all fork PR refs so check to see // if the resolved input points to the fork PR sha we have in the payload - const baseQualifiedRepository = `${github.context.repo.owner}/${github.context.repo.repo}`; - const repositoryDiffersFromBase = input.qualifiedRepository.toLowerCase() !== - baseQualifiedRepository.toLowerCase(); + const repositoryMatchesPrHead = typeof prHeadRepoFullName === 'string' && + input.qualifiedRepository.toLowerCase() === + prHeadRepoFullName.toLowerCase(); const refMatchesPullPattern = PR_REF_PATTERN.test(input.ref); const commitMatchesPrHeadSha = !!input.commit && prShas.includes(input.commit.toLowerCase()); - if (!repositoryDiffersFromBase && + if (!repositoryMatchesPrHead && !refMatchesPullPattern && !commitMatchesPrHeadSha) { return; diff --git a/src/unsafe-pr-checkout-helper.ts b/src/unsafe-pr-checkout-helper.ts index 860d5ed..3e956f3 100644 --- a/src/unsafe-pr-checkout-helper.ts +++ b/src/unsafe-pr-checkout-helper.ts @@ -6,7 +6,7 @@ const PR_REF_PATTERN = /^refs\/pull\/[0-9]+\/(?:head|merge)$/ export interface IUnsafePrCheckoutInput { qualifiedRepository: string ref: string - commit: string + commit: string | undefined allowUnsafePrCheckout: boolean } @@ -26,10 +26,12 @@ export function assertSafePrCheckout(input: IUnsafePrCheckoutInput): void { } let prHeadRepoId: unknown + let prHeadRepoFullName: unknown const prShas: string[] = [] if (eventName === 'pull_request_target') { prHeadRepoId = fromPayload('pull_request.head.repo.id') + prHeadRepoFullName = fromPayload('pull_request.head.repo.full_name') pushIfSha(prShas, fromPayload('pull_request.head.sha')) pushIfSha(prShas, fromPayload('pull_request.merge_commit_sha')) } else { @@ -38,7 +40,13 @@ export function assertSafePrCheckout(input: IUnsafePrCheckoutInput): void { return } prHeadRepoId = fromPayload('workflow_run.head_repository.id') + prHeadRepoFullName = fromPayload('workflow_run.head_repository.full_name') pushIfSha(prShas, fromPayload('workflow_run.head_commit.id')) + // For `pull_request_target`-triggered workflow_run, `head_sha` is the base + // default branch SHA (not the PR head) + if (wrEvent !== 'pull_request_target') { + pushIfSha(prShas, fromPayload('workflow_run.head_sha')) + } } // (A) Fork PR? @@ -48,16 +56,16 @@ export function assertSafePrCheckout(input: IUnsafePrCheckoutInput): void { // (B) We cannot check for all fork PR refs so check to see // if the resolved input points to the fork PR sha we have in the payload - const baseQualifiedRepository = `${github.context.repo.owner}/${github.context.repo.repo}` - const repositoryDiffersFromBase = - input.qualifiedRepository.toLowerCase() !== - baseQualifiedRepository.toLowerCase() + const repositoryMatchesPrHead = + typeof prHeadRepoFullName === 'string' && + input.qualifiedRepository.toLowerCase() === + prHeadRepoFullName.toLowerCase() const refMatchesPullPattern = PR_REF_PATTERN.test(input.ref) const commitMatchesPrHeadSha = !!input.commit && prShas.includes(input.commit.toLowerCase()) if ( - !repositoryDiffersFromBase && + !repositoryMatchesPrHead && !refMatchesPullPattern && !commitMatchesPrHeadSha ) {