diff --git a/.github/workflows/pr-security-scan.yml b/.github/workflows/pr-security-scan.yml
new file mode 100644
index 0000000..def5612
--- /dev/null
+++ b/.github/workflows/pr-security-scan.yml
@@ -0,0 +1,49 @@
+name: PR Security Scan
+
+on:
+  pull_request:
+    branches: [main]
+
+permissions:
+  contents: read
+  security-events: write
+  actions: read
+  pull-requests: write
+
+jobs:
+  get-changed-files:
+    name: Get Changed Files
+    runs-on: ubuntu-latest
+    outputs:
+      files: ${{ steps.changed-files.outputs.all_changed_files }}
+      any_changed: ${{ steps.changed-files.outputs.any_changed }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v46
+        with:
+          separator: ','
+          # Exclude test files from security scan - they contain intentional
+          # path traversal test data that triggers false positives
+          files_ignore: |
+            **/*_test.go
+            **/testdata/**
+
+  security-scan:
+    name: Security Scan
+    needs: get-changed-files
+    if: needs.get-changed-files.outputs.any_changed == 'true'
+    uses: ./.github/workflows/reusable-security-scan.yml
+    with:
+      scan-type: repo
+      fail-on: 'CRITICAL,HIGH'
+      include-files: ${{ needs.get-changed-files.outputs.files }}
+      build-from-source: false
+    secrets:
+      api-token: ${{ secrets.ARMIS_API_TOKEN }}
+      tenant-id: ${{ secrets.ARMIS_TENANT_ID }}
diff --git a/.github/workflows/reusable-security-scan.yml b/.github/workflows/reusable-security-scan.yml
index 3d8bde8..32fbd65 100644
--- a/.github/workflows/reusable-security-scan.yml
+++ b/.github/workflows/reusable-security-scan.yml
@@ -35,6 +35,14 @@ on:
         description: 'Scan timeout in minutes'
         type: number
         default: 60
+      include-files:
+        description: 'Comma-separated list of file paths to scan (relative to repository root)'
+        type: string
+        default: ''
+      build-from-source:
+        description: 'Build CLI from source instead of downloading release (for testing scanner changes)'
+        type: boolean
+        default: false
     secrets:
       api-token:
         description: 'Armis API token for authentication'
@@ -43,6 +51,14 @@
       tenant-id:
         description: 'Tenant identifier for Armis Cloud'
         required: true
 
+# Top-level permissions define the maximum permissions available to this workflow.
+# Job-level permissions further restrict as needed.
+permissions:
+  contents: read
+  security-events: write
+  actions: read
+  pull-requests: write
+
 jobs:
   security-scan:
     name: Armis Security Scan
@@ -62,7 +78,7 @@
 
       - name: Run Armis Security Scan
         id: armis_scan
-        uses: ArmisSecurity/armis-cli@main
+        uses: ./
        with:
          scan-type: ${{ inputs.scan-type }}
          scan-target: ${{ inputs.scan-target }}
@@ -73,6 +89,8 @@
          output-file: armis-results.sarif
          image-tarball: ${{ inputs.image-tarball }}
          scan-timeout: ${{ inputs.scan-timeout }}
+          include-files: ${{ inputs.include-files }}
+          build-from-source: ${{ inputs.build-from-source }}
         continue-on-error: true
 
       - name: Ensure SARIF exists
@@ -121,7 +139,76 @@
             if (counts.LOW > 0) body += `| 🔵 LOW | ${counts.LOW} |\n`;
             if (counts.INFO > 0) body += `| ⚪ INFO | ${counts.INFO} |\n`;
             body += `\n**Total: ${total}**\n`;
-            body += `\n<details><summary>View full results</summary>\n\nSee the Security tab or download the \`armis-security-results\` artifact for the complete SARIF report.\n</details>`;
+
+            // Build detailed findings section
+            if (total > 0) {
+              body += `\n<details><summary>View all ${total} findings</summary>\n\n`;
+
+              // Group results by severity
+              const severityOrder = ['CRITICAL', 'HIGH', 'MEDIUM', 'LOW', 'INFO'];
+              const severityEmoji = { CRITICAL: '🔴', HIGH: '🟠', MEDIUM: '🟡', LOW: '🔵', INFO: '⚪' };
+
+              for (const severity of severityOrder) {
+                const severityResults = results.filter(r =>
+                  (r.properties?.severity || 'INFO') === severity
+                );
+
+                if (severityResults.length > 0) {
+                  body += `### ${severityEmoji[severity]} ${severity} (${severityResults.length})\n\n`;
+
+                  for (const r of severityResults) {
+                    const file = r.locations?.[0]?.physicalLocation?.artifactLocation?.uri || '';
+                    const line = r.locations?.[0]?.physicalLocation?.region?.startLine || '';
+                    const location = file ? (line ? `${file}:${line}` : file) : 'Unknown location';
+
+                    // Parse title and description from message
+                    const msgParts = (r.message?.text || '').split(': ');
+                    const title = msgParts[0] || r.ruleId;
+                    const description = msgParts.slice(1).join(': ') || '';
+
+                    body += `<details><summary>${r.ruleId} - ${title}</summary>\n\n`;
+                    body += `**Location:** \`${location}\`\n\n`;
+
+                    if (description) {
+                      body += `${description}\n\n`;
+                    }
+
+                    // Code snippet
+                    const snippet = r.properties?.codeSnippet;
+                    if (snippet) {
+                      body += '```\n' + snippet + '\n```\n\n';
+                    }
+
+                    // CVEs and CWEs
+                    const cves = r.properties?.cves || [];
+                    const cwes = r.properties?.cwes || [];
+                    if (cves.length > 0) {
+                      body += `**CVEs:** ${cves.join(', ')}\n\n`;
+                    }
+                    if (cwes.length > 0) {
+                      body += `**CWEs:** ${cwes.join(', ')}\n\n`;
+                    }
+
+                    // Package info
+                    const pkg = r.properties?.package;
+                    const version = r.properties?.version;
+                    const fixVersion = r.properties?.fixVersion;
+                    if (pkg) {
+                      let pkgInfo = `**Package:** ${pkg}`;
+                      if (version) pkgInfo += ` (${version})`;
+                      if (fixVersion) pkgInfo += ` → Fix: ${fixVersion}`;
+                      body += pkgInfo + '\n\n';
+                    }
+
+                    body += `</details>\n\n`;
+                  }
+                }
+              }
+
+              body += `</details>`;
+            } else {
+              body += `\n<details><summary>View full results</summary>\n\nNo security issues found.\n</details>`;
+            }
 
             // Find and update existing comment, or create new
             const { data: comments } = await github.rest.issues.listComments({
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index 2f9013f..b48a79a 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -18,6 +18,7 @@ jobs:
       pr-comment: false  # No PR context for scheduled runs
       upload-artifact: true
       scan-timeout: 120  # 2 hours
+      build-from-source: false
     secrets:
       api-token: ${{ secrets.ARMIS_API_TOKEN }}
       tenant-id: ${{ secrets.ARMIS_TENANT_ID }}
diff --git a/action.yml b/action.yml
index f9d7d53..2030718 100644
--- a/action.yml
+++ b/action.yml
@@ -46,6 +46,14 @@ inputs:
     description: 'Scan timeout in minutes'
     required: false
     default: '60'
+  include-files:
+    description: 'Comma-separated list of file paths to scan (relative to repository root)'
+    required: false
+    default: ''
+  build-from-source:
+    description: 'Build CLI from source instead of downloading release (for testing)'
+    required: false
+    default: 'false'
 
 outputs:
   results:
@@ -58,11 +66,80 @@
 runs:
   using: 'composite'
   steps:
-    - name: Install Armis CLI
+    - name: Setup Go
+      if: inputs.build-from-source == 'true'
+      uses: actions/setup-go@v5
+      with:
+        go-version: '1.23'
+
+    - name: Build Armis CLI from source
+      if: inputs.build-from-source == 'true'
+      shell: bash
+      run: |
+        echo "Building Armis CLI from source..."
+        go build -o armis-cli ./cmd/armis-cli
+        mkdir -p "$HOME/.local/bin"
+        mv armis-cli "$HOME/.local/bin/"
+        echo "$HOME/.local/bin" >> $GITHUB_PATH
+
+    - name: Install Armis CLI from release
+      if: inputs.build-from-source != 'true'
       shell: bash
       run: |
         echo "Installing Armis CLI..."
-        curl -sSL https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.sh | bash
+
+        # Detect OS and architecture
+        OS=$(uname -s | tr '[:upper:]' '[:lower:]')
+        case "$OS" in
+          linux*) OS="linux" ;;
+          darwin*) OS="darwin" ;;
+          *) echo "Error: Unsupported OS: $OS"; exit 1 ;;
+        esac
+
+        ARCH=$(uname -m)
+        case "$ARCH" in
+          x86_64|amd64) ARCH="amd64" ;;
+          aarch64|arm64) ARCH="arm64" ;;
+          *) echo "Error: Unsupported architecture: $ARCH"; exit 1 ;;
+        esac
+
+        # Download binary directly from GitHub releases
+        REPO="ArmisSecurity/armis-cli"
+        ARCHIVE_NAME="armis-cli-${OS}-${ARCH}.tar.gz"
+        CHECKSUMS_NAME="armis-cli-checksums.txt"
+        BASE_URL="https://github.com/${REPO}/releases/latest/download"
+
+        TMP_DIR=$(mktemp -d)
+        trap 'rm -rf "$TMP_DIR"' EXIT
+
+        echo "Downloading $ARCHIVE_NAME..."
+        curl -fsSL "${BASE_URL}/${ARCHIVE_NAME}" -o "${TMP_DIR}/${ARCHIVE_NAME}"
+        curl -fsSL "${BASE_URL}/${CHECKSUMS_NAME}" -o "${TMP_DIR}/${CHECKSUMS_NAME}"
+
+        # Verify checksum
+        echo "Verifying checksum..."
+        cd "$TMP_DIR"
+        EXPECTED=$(grep "$ARCHIVE_NAME" "$CHECKSUMS_NAME" | awk '{print $1}')
+        if command -v sha256sum > /dev/null 2>&1; then
+          ACTUAL=$(sha256sum "$ARCHIVE_NAME" | awk '{print $1}')
+        else
+          ACTUAL=$(shasum -a 256 "$ARCHIVE_NAME" | awk '{print $1}')
+        fi
+
+        if [ "$EXPECTED" != "$ACTUAL" ]; then
+          echo "Error: Checksum verification failed!"
+ echo "Expected: $EXPECTED" + echo "Actual: $ACTUAL" + exit 1 + fi + echo "Checksum verified successfully" + + # Extract and install + tar -xzf "$ARCHIVE_NAME" + mkdir -p "$HOME/.local/bin" + mv armis-cli "$HOME/.local/bin/" + chmod +x "$HOME/.local/bin/armis-cli" + echo "Armis CLI installed successfully" echo "$HOME/.local/bin" >> $GITHUB_PATH - name: Verify Installation @@ -85,6 +162,7 @@ runs: NO_PROGRESS: ${{ inputs.no-progress }} IMAGE_TARBALL: ${{ inputs.image-tarball }} OUTPUT_FILE: ${{ inputs.output-file }} + INCLUDE_FILES: ${{ inputs.include-files }} run: | set +e @@ -114,6 +192,10 @@ runs: SCAN_ARGS+=("--tarball" "$IMAGE_TARBALL") fi + if [ -n "$INCLUDE_FILES" ]; then + SCAN_ARGS+=("--include-files" "$INCLUDE_FILES") + fi + # Execute command safely without eval if [ -n "$OUTPUT_FILE" ]; then armis-cli "${SCAN_ARGS[@]}" > "$OUTPUT_FILE" diff --git a/internal/cmd/scan.go b/internal/cmd/scan.go index 7d37332..4f78c4e 100644 --- a/internal/cmd/scan.go +++ b/internal/cmd/scan.go @@ -10,6 +10,7 @@ var ( uploadTimeout int includeNonExploitable bool groupBy string + includeFiles []string ) var scanCmd = &cobra.Command{ @@ -24,6 +25,7 @@ func init() { scanCmd.PersistentFlags().IntVar(&uploadTimeout, "upload-timeout", 10, "Maximum time in minutes to wait for artifact upload to complete") scanCmd.PersistentFlags().BoolVar(&includeNonExploitable, "include-non-exploitable", false, "Include findings marked as non-exploitable (only exploitable findings shown by default)") scanCmd.PersistentFlags().StringVar(&groupBy, "group-by", "none", "Group findings by: none, cwe, severity, file") + scanCmd.PersistentFlags().StringSliceVar(&includeFiles, "include-files", nil, "Comma-separated list of file paths to include in scan (relative to repository root)") if rootCmd != nil { rootCmd.AddCommand(scanCmd) } diff --git a/internal/cmd/scan_repo.go b/internal/cmd/scan_repo.go index c283d84..8718f2c 100644 --- a/internal/cmd/scan_repo.go +++ b/internal/cmd/scan_repo.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "os" + "path/filepath" "time" "github.com/ArmisSecurity/armis-cli/internal/api" @@ -47,6 +48,22 @@ var scanRepoCmd = &cobra.Command{ scanTimeoutDuration := time.Duration(scanTimeout) * time.Minute scanner := repo.NewScanner(client, noProgress, tid, limit, includeTests, scanTimeoutDuration, includeNonExploitable) + // Handle --include-files flag for targeted file scanning + // Security: Path traversal protection is enforced by ParseFileList which + // validates all paths using SafeJoinPath to ensure they don't escape the + // repository root. Invalid or traversal paths are rejected with an error. 
+		if len(includeFiles) > 0 {
+			absPath, err := filepath.Abs(repoPath)
+			if err != nil {
+				return fmt.Errorf("failed to resolve path: %w", err)
+			}
+			fileList, err := repo.ParseFileList(absPath, includeFiles)
+			if err != nil {
+				return fmt.Errorf("invalid --include-files: %w", err)
+			}
+			scanner = scanner.WithIncludeFiles(fileList)
+		}
+
 		ctx, cancel := NewSignalContext()
 		defer cancel()
diff --git a/internal/output/sarif.go b/internal/output/sarif.go
index a21f71e..717408d 100644
--- a/internal/output/sarif.go
+++ b/internal/output/sarif.go
@@ -60,7 +60,14 @@ type sarifResult struct {
 }
 
 type sarifResultProperties struct {
-	Severity string `json:"severity"`
+	Severity    string   `json:"severity"`
+	Type        string   `json:"type,omitempty"`
+	CodeSnippet string   `json:"codeSnippet,omitempty"`
+	CVEs        []string `json:"cves,omitempty"`
+	CWEs        []string `json:"cwes,omitempty"`
+	Package     string   `json:"package,omitempty"`
+	Version     string   `json:"version,omitempty"`
+	FixVersion  string   `json:"fixVersion,omitempty"`
 }
 
 type sarifMessage struct {
@@ -138,8 +145,17 @@ func buildRules(findings []model.Finding) ([]sarifRule, map[string]int) {
 	return rules, ruleIndexMap
 }
 
+// maxSarifResultsCapacity is the maximum initial capacity for the SARIF results slice,
+// to prevent resource exhaustion from extremely large finding lists (CWE-770).
+const maxSarifResultsCapacity = 10000
+
 func convertToSarifResults(findings []model.Finding, ruleIndexMap map[string]int) []sarifResult {
-	results := make([]sarifResult, 0, len(findings))
+	// Cap the initial capacity to prevent excessive memory allocation (CWE-770)
+	capacity := len(findings)
+	if capacity > maxSarifResultsCapacity {
+		capacity = maxSarifResultsCapacity
+	}
+	results := make([]sarifResult, 0, capacity)
 
 	for _, finding := range findings {
 		result := sarifResult{
@@ -150,7 +166,14 @@
 			Message: sarifMessage{
 				Text: finding.Title + ": " + finding.Description,
 			},
 			Properties: &sarifResultProperties{
-				Severity: string(finding.Severity),
+				Severity:    string(finding.Severity),
+				Type:        string(finding.Type),
+				CodeSnippet: util.MaskSecretInLine(finding.CodeSnippet), // Defense-in-depth: always sanitize
+				CVEs:        finding.CVEs,
+				CWEs:        finding.CWEs,
+				Package:     finding.Package,
+				Version:     finding.Version,
+				FixVersion:  finding.FixVersion,
 			},
 		}
diff --git a/internal/scan/repo/files.go b/internal/scan/repo/files.go
new file mode 100644
index 0000000..3180377
--- /dev/null
+++ b/internal/scan/repo/files.go
@@ -0,0 +1,119 @@
+// Package repo provides repository scanning functionality.
+package repo
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/ArmisSecurity/armis-cli/internal/util"
+)
+
+// MaxFiles is the maximum number of files that can be specified via --include-files.
+// This limit prevents resource exhaustion from extremely large file lists.
+const MaxFiles = 1000
+
+// FileList represents a list of files to be scanned.
+type FileList struct {
+	files    []string
+	repoRoot string
+}
+
+// ParseFileList parses file paths from the --include-files flag.
+// It accepts both relative paths (to repoRoot) and absolute paths,
+// normalizing all to relative paths.
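+// For example, given a repoRoot of "/repo", both "pkg/util.go" and
+// "/repo/pkg/util.go" normalize to "pkg/util.go" (illustrative paths).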
+func ParseFileList(repoRoot string, files []string) (*FileList, error) {
+	absRoot, err := filepath.Abs(repoRoot)
+	if err != nil {
+		return nil, fmt.Errorf("failed to resolve repo root: %w", err)
+	}
+
+	fl := &FileList{repoRoot: absRoot}
+	for _, f := range files {
+		if err := fl.addFile(f); err != nil {
+			return nil, err
+		}
+	}
+	return fl, nil
+}
+
+func (fl *FileList) addFile(path string) error {
+	// Check file count limit to prevent resource exhaustion
+	if len(fl.files) >= MaxFiles {
+		return fmt.Errorf("too many files: maximum %d files allowed", MaxFiles)
+	}
+
+	if path == "" {
+		return nil // Skip empty paths
+	}
+
+	// Normalize path separators
+	path = filepath.FromSlash(path)
+
+	// Convert absolute paths to relative
+	if filepath.IsAbs(path) {
+		// Security: Resolve symlinks to prevent path traversal attacks (CWE-22).
+		// Using filepath.EvalSymlinks ensures we compare actual filesystem paths,
+		// preventing symlink-based escapes from the repository root.
+		evalPath, err := filepath.EvalSymlinks(path)
+		if err != nil {
+			// Path doesn't exist yet - fall back to Clean for normalization
+			evalPath = filepath.Clean(path)
+		}
+		evalRoot, err := filepath.EvalSymlinks(fl.repoRoot)
+		if err != nil {
+			evalRoot = filepath.Clean(fl.repoRoot)
+		}
+
+		// Use filepath.Rel to check containment - it returns an error or
+		// a path starting with ".." if the path is outside the root
+		relCheck, err := filepath.Rel(evalRoot, evalPath)
+		if err != nil || strings.HasPrefix(relCheck, "..") {
+			return fmt.Errorf("absolute path %q is outside repository root %q", path, fl.repoRoot)
+		}
+
+		rel, err := filepath.Rel(fl.repoRoot, path)
+		if err != nil {
+			return fmt.Errorf("cannot make path relative to repo: %s", path)
+		}
+		path = rel
+	}
+
+	// Validate path doesn't escape repo root using SafeJoinPath
+	if _, err := util.SafeJoinPath(fl.repoRoot, path); err != nil {
+		return fmt.Errorf("invalid path %q: %w", path, err)
+	}
+
+	fl.files = append(fl.files, path)
+	return nil
+}
+
+// Files returns the validated list of relative file paths.
+func (fl *FileList) Files() []string {
+	return fl.files
+}
+
+// RepoRoot returns the absolute path to the repository root.
+func (fl *FileList) RepoRoot() string {
+	return fl.repoRoot
+}
+
+// ValidateExistence checks which files exist and returns warnings for missing files.
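+// Directories are reported as warnings and skipped, since only regular files are scanned.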
+func (fl *FileList) ValidateExistence() (existing []string, warnings []string) {
+	for _, f := range fl.files {
+		absPath := filepath.Join(fl.repoRoot, f)
+		info, err := os.Stat(absPath)
+		if err != nil {
+			warnings = append(warnings, fmt.Sprintf("file not found: %s", f))
+			continue
+		}
+		// Skip directories - we only scan files
+		if info.IsDir() {
+			warnings = append(warnings, fmt.Sprintf("skipping directory: %s", f))
+			continue
+		}
+		existing = append(existing, f)
+	}
+	return
+}
diff --git a/internal/scan/repo/files_test.go b/internal/scan/repo/files_test.go
new file mode 100644
index 0000000..1e018ae
--- /dev/null
+++ b/internal/scan/repo/files_test.go
@@ -0,0 +1,223 @@
+package repo
+
+import (
+	"os"
+	"path/filepath"
+	"strings"
+	"testing"
+)
+
+func TestParseFileList(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Create test files
+	if err := os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte("package main"), 0600); err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+	if err := os.MkdirAll(filepath.Join(tmpDir, "pkg"), 0750); err != nil {
+		t.Fatalf("Failed to create test dir: %v", err)
+	}
+	if err := os.WriteFile(filepath.Join(tmpDir, "pkg", "helper.go"), []byte("package pkg"), 0600); err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+
+	tests := []struct {
+		name    string
+		files   []string
+		wantLen int
+		wantErr bool
+	}{
+		{
+			name:    "valid relative paths",
+			files:   []string{"main.go", "pkg/helper.go"},
+			wantLen: 2,
+			wantErr: false,
+		},
+		{
+			name:    "path traversal rejected",
+			files:   []string{"../etc/passwd"},
+			wantErr: true,
+		},
+		{
+			name:    "empty list",
+			files:   []string{},
+			wantLen: 0,
+			wantErr: false,
+		},
+		{
+			name:    "empty string in list is skipped",
+			files:   []string{"main.go", "", "pkg/helper.go"},
+			wantLen: 2,
+			wantErr: false,
+		},
+		{
+			name:    "absolute path converted to relative",
+			files:   []string{filepath.Join(tmpDir, "main.go")},
+			wantLen: 1,
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			fl, err := ParseFileList(tmpDir, tt.files)
+			if tt.wantErr {
+				if err == nil {
+					t.Error("expected error, got nil")
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("unexpected error: %v", err)
+			}
+			if len(fl.Files()) != tt.wantLen {
+				t.Errorf("got %d files, want %d", len(fl.Files()), tt.wantLen)
+			}
+		})
+	}
+}
+
+func TestFileListValidateExistence(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Create one existing file
+	if err := os.WriteFile(filepath.Join(tmpDir, "exists.go"), []byte("package main"), 0600); err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+
+	// Create a directory to test directory skipping
+	if err := os.MkdirAll(filepath.Join(tmpDir, "subdir"), 0750); err != nil {
+		t.Fatalf("Failed to create test dir: %v", err)
+	}
+
+	fl, err := ParseFileList(tmpDir, []string{"exists.go", "missing.go", "subdir"})
+	if err != nil {
+		t.Fatalf("ParseFileList failed: %v", err)
+	}
+
+	existing, warnings := fl.ValidateExistence()
+
+	if len(existing) != 1 {
+		t.Errorf("expected 1 existing file, got %d", len(existing))
+	}
+	if existing[0] != "exists.go" {
+		t.Errorf("expected exists.go, got %s", existing[0])
+	}
+	if len(warnings) != 2 {
+		t.Errorf("expected 2 warnings (missing file + directory), got %d", len(warnings))
+	}
+}
+
+func TestParseFileListPathTraversal(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	traversalPaths := []string{
+		"../etc/passwd",
+		"foo/../../etc/passwd",
+		"./foo/../../../etc/passwd",
+	}
+
+	for _, path := range traversalPaths {
+		t.Run(path, func(t *testing.T) {
+			_, err := ParseFileList(tmpDir, []string{path})
+			if err == nil {
+				t.Errorf("expected error for path traversal attempt: %s", path)
+			}
+		})
+	}
+}
+
+func TestParseFileListAbsolutePathOutsideRepo(t *testing.T) {
+	// Create two separate temp directories - one is the "repo root", the other is "outside"
+	repoDir := t.TempDir()
+	outsideDir := t.TempDir()
+
+	// Create a file in the outside directory to get a real absolute path
+	outsideFile := filepath.Join(outsideDir, "outside.go")
+	if err := os.WriteFile(outsideFile, []byte("package outside"), 0600); err != nil {
+		t.Fatalf("Failed to create outside file: %v", err)
+	}
+
+	// Test that an absolute path outside the repo root is rejected
+	_, err := ParseFileList(repoDir, []string{outsideFile})
+	if err == nil {
+		t.Errorf("expected error for absolute path outside repo: %s", outsideFile)
+	}
+	// Verify the error message is clear about the issue
+	if err != nil && !strings.Contains(err.Error(), "outside repository root") {
+		t.Errorf("expected error message to mention 'outside repository root', got: %s", err.Error())
+	}
+}
+
+func TestFileListRepoRoot(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	fl, err := ParseFileList(tmpDir, []string{})
+	if err != nil {
+		t.Fatalf("ParseFileList failed: %v", err)
+	}
+
+	// RepoRoot should return an absolute path
+	root := fl.RepoRoot()
+	if !filepath.IsAbs(root) {
+		t.Errorf("RepoRoot should return absolute path, got: %s", root)
+	}
+}
+
+func TestFileListFiles(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Create test file
+	if err := os.WriteFile(filepath.Join(tmpDir, "test.go"), []byte("package main"), 0600); err != nil {
+		t.Fatalf("Failed to create test file: %v", err)
+	}
+
+	fl, err := ParseFileList(tmpDir, []string{"test.go"})
+	if err != nil {
+		t.Fatalf("ParseFileList failed: %v", err)
+	}
+
+	files := fl.Files()
+	if len(files) != 1 {
+		t.Fatalf("expected 1 file, got %d", len(files))
+	}
+	if files[0] != "test.go" {
+		t.Errorf("expected test.go, got %s", files[0])
+	}
+}
+
+func TestParseFileListMaxFilesLimit(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Generate more files than the limit
+	files := make([]string, MaxFiles+1)
+	for i := range files {
+		files[i] = "file.go" // Doesn't need to exist for this test
+	}
+
+	_, err := ParseFileList(tmpDir, files)
+	if err == nil {
+		t.Errorf("expected error when exceeding MaxFiles limit (%d), got nil", MaxFiles)
+	}
+	if err != nil && !strings.Contains(err.Error(), "too many files") {
+		t.Errorf("expected error message to mention 'too many files', got: %s", err.Error())
+	}
+}
+
+func TestParseFileListAtMaxFilesLimit(t *testing.T) {
+	tmpDir := t.TempDir()
+
+	// Generate exactly MaxFiles files (should succeed)
+	files := make([]string, MaxFiles)
+	for i := range files {
+		files[i] = "file.go" // Doesn't need to exist for this test
+	}
+
+	fl, err := ParseFileList(tmpDir, files)
+	if err != nil {
+		t.Errorf("expected success at exactly MaxFiles limit (%d), got error: %v", MaxFiles, err)
+	}
+	if fl != nil && len(fl.Files()) != MaxFiles {
+		t.Errorf("expected %d files, got %d", MaxFiles, len(fl.Files()))
+	}
+}
diff --git a/internal/scan/repo/repo.go b/internal/scan/repo/repo.go
index 5ecc17f..2f4d728 100644
--- a/internal/scan/repo/repo.go
+++ b/internal/scan/repo/repo.go
@@ -33,6 +33,7 @@ type Scanner struct {
 	timeout               time.Duration
 	includeNonExploitable bool
 	pollInterval          time.Duration
+	includeFiles          *FileList
 }
 
 // NewScanner creates a new repository scanner with the given configuration.
@@ -55,6 +56,12 @@ func (s *Scanner) WithPollInterval(d time.Duration) *Scanner {
 	return s
 }
 
+// WithIncludeFiles sets a specific list of files to scan instead of the entire directory.
+func (s *Scanner) WithIncludeFiles(fl *FileList) *Scanner {
+	s.includeFiles = fl
+	return s
+}
+
 // Scan scans a repository at the given path.
 func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, error) {
 	// Validate path to prevent path traversal
@@ -76,30 +83,78 @@
 		return nil, fmt.Errorf("path is not a directory: %s", absPath)
 	}
 
-	ignoreMatcher, err := LoadIgnorePatterns(absPath)
-	if err != nil {
-		return nil, fmt.Errorf("failed to load ignore patterns: %w", err)
-	}
+	var size int64
+	var tarFunc func() error
+	var ignoreMatcher *IgnoreMatcher
 
-	size, err := calculateDirSize(absPath, s.includeTests, ignoreMatcher)
-	if err != nil {
-		return nil, fmt.Errorf("failed to calculate directory size: %w", err)
+	pr, pw := io.Pipe()
+	// The pipe reader is deferred-closed to ensure cleanup on all code paths.
+	// The error is intentionally ignored (nolint:errcheck) because:
+	// 1. PipeReader.Close() rarely returns meaningful errors
+	// 2. The critical close is PipeWriter.Close(), which signals EOF to the reader
+	// 3. Any actual read errors will surface through the main error flow
+	defer pr.Close() //nolint:errcheck
+
+	if s.includeFiles != nil {
+		// Targeted file scanning mode - scan only the specified files
+		existing, warnings := s.includeFiles.ValidateExistence()
+		for _, w := range warnings {
+			fmt.Fprintf(os.Stderr, "Warning: %s\n", w)
+		}
+
+		if len(existing) == 0 {
+			return nil, fmt.Errorf("no files to scan: all specified files are missing or are directories")
+		}
+
+		var err error
+		size, err = calculateFilesSize(absPath, existing)
+		if err != nil {
+			return nil, fmt.Errorf("failed to calculate files size: %w", err)
+		}
+
+		tarFunc = func() error {
+			defer pw.Close() //nolint:errcheck // signals EOF to reader
+			return s.tarGzFiles(absPath, existing, pw)
+		}
+	} else {
+		// Full directory scanning mode (existing behavior)
+		var err error
+		ignoreMatcher, err = LoadIgnorePatterns(absPath)
+		if err != nil {
+			return nil, fmt.Errorf("failed to load ignore patterns: %w", err)
+		}
+
+		size, err = calculateDirSize(absPath, s.includeTests, ignoreMatcher)
+		if err != nil {
+			return nil, fmt.Errorf("failed to calculate directory size: %w", err)
+		}
+
+		tarFunc = func() error {
+			defer pw.Close() //nolint:errcheck // signals EOF to reader
+			return s.tarGzDirectory(absPath, pw, ignoreMatcher)
+		}
 	}
 
 	if size > MaxRepoSize {
 		return nil, fmt.Errorf("directory size (%d bytes) exceeds maximum allowed size (%d bytes)", size, MaxRepoSize)
 	}
 
-	pr, pw := io.Pipe()
-
 	spinner := progress.NewSpinnerWithContext(ctx, "Creating a compressed archive...", s.noProgress)
 	spinner.Start()
 	defer spinner.Stop()
 
 	errChan := make(chan error, 1)
 	go func() {
-		defer pw.Close() //nolint:errcheck // signals EOF to reader
-		errChan <- s.tarGzDirectory(absPath, pw, ignoreMatcher)
+		// Security: Check context before starting the expensive tar operation
+		// to prevent resource leaks if the context is already canceled.
+		select {
+		case <-ctx.Done():
+			pw.Close() //nolint:errcheck,gosec // Close pipe to unblock StartIngest
+			errChan <- ctx.Err()
+			return
+		default:
+		}
+		errChan <- tarFunc()
 	}()
 
 	time.Sleep(500 * time.Millisecond)
@@ -220,6 +275,125 @@
 	})
 }
 
+// isPathContained verifies that absPath is contained within baseDir.
+// This is a defense-in-depth check to prevent path traversal attacks,
+// complementing the SafeJoinPath validation performed at file list parsing.
+func isPathContained(baseDir, absPath string) bool {
+	rel, err := filepath.Rel(baseDir, absPath)
+	if err != nil {
+		return false
+	}
+	// The path escapes if it starts with ".." or is absolute
+	return !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel)
+}
+
+func (s *Scanner) tarGzFiles(repoRoot string, files []string, writer io.Writer) (err error) {
+	gzWriter := gzip.NewWriter(writer)
+	defer func() {
+		if closeErr := gzWriter.Close(); closeErr != nil && err == nil {
+			err = closeErr
+		}
+	}()
+
+	tarWriter := tar.NewWriter(gzWriter)
+	defer func() {
+		if closeErr := tarWriter.Close(); closeErr != nil && err == nil {
+			err = closeErr
+		}
+	}()
+
+	// Note: The 'err' variables declared with := inside the loop shadow the named return value.
+	// This is intentional - direct returns like 'return copyErr' still set the named return,
+	// allowing the deferred close functions above to check if an error occurred.
+	//
+	// Security: TOCTOU (Time-of-Check-Time-of-Use) between ValidateExistence and file
+	// operations is an acceptable risk here. Files that disappear between validation and
+	// use are gracefully handled with warnings. Symlinks are explicitly skipped to
+	// prevent escape attacks. Path traversal is prevented by SafeJoinPath validation
+	// and defense-in-depth isPathContained checks.
+	filesWritten := 0
+	for _, relPath := range files {
+		absPath := filepath.Join(repoRoot, relPath)
+
+		// Defense-in-depth: verify the path is within the repo root
+		if !isPathContained(repoRoot, absPath) {
+			fmt.Fprintf(os.Stderr, "Warning: skipping path outside repository: %s\n", relPath)
+			continue
+		}
+
+		// Use Lstat rather than Stat: Stat follows symlinks, so ModeSymlink
+		// would never be set and the symlink check below would be dead code.
+		info, err := os.Lstat(absPath)
+		if err != nil {
+			// Skip files that don't exist (may have been deleted)
+			fmt.Fprintf(os.Stderr, "Warning: skipping %s: %v\n", relPath, err)
+			continue
+		}
+
+		// Skip directories - we only handle files
+		if info.IsDir() {
+			continue
+		}
+
+		// Skip symlinks for security
+		if info.Mode()&os.ModeSymlink != 0 {
+			fmt.Fprintf(os.Stderr, "Warning: skipping symlink %s\n", relPath)
+			continue
+		}
+
+		header, err := tar.FileInfoHeader(info, "")
+		if err != nil {
+			return err
+		}
+
+		// Use forward slashes for tar paths
+		header.Name = filepath.ToSlash(relPath)
+
+		if err := tarWriter.WriteHeader(header); err != nil {
+			return err
+		}
+
+		file, err := os.Open(absPath) // #nosec G304 - path is validated via SafeJoinPath
+		if err != nil {
+			return err
+		}
+		_, copyErr := io.Copy(tarWriter, file)
+		closeErr := file.Close()
+		if copyErr != nil {
+			return copyErr
+		}
+		if closeErr != nil {
+			return closeErr
+		}
+		filesWritten++
+	}
+
+	if filesWritten == 0 {
+		return fmt.Errorf("no files were added to archive")
+	}
+
+	return nil
+}
+
+func calculateFilesSize(repoRoot string, files []string) (int64, error) {
+	var size int64
+	for _, relPath := range files {
+		absPath := filepath.Join(repoRoot, relPath)
+
+		// Defense-in-depth: verify the path is within the repo root
+		if !isPathContained(repoRoot, absPath) {
+			continue // Skip paths outside repository
+		}
+
+		// Lstat (not Stat) so symlinks are detectable, matching tarGzFiles
+		info, err := os.Lstat(absPath)
+		if err != nil {
+			continue // Skip non-existent files
+		}
+		if !info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
+			size += info.Size()
+		}
+	}
+	return size, nil
+}
+
 func calculateDirSize(path string, includeTests bool, ignoreMatcher *IgnoreMatcher) (int64, error) {
 	var size int64
 	err := filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {