diff --git a/.github/workflows/pr-security-scan.yml b/.github/workflows/pr-security-scan.yml
new file mode 100644
index 0000000..def5612
--- /dev/null
+++ b/.github/workflows/pr-security-scan.yml
@@ -0,0 +1,49 @@
+name: PR Security Scan
+
+on:
+ pull_request:
+ branches: [main]
+
+permissions:
+ contents: read
+ security-events: write
+ actions: read
+ pull-requests: write
+
+jobs:
+ get-changed-files:
+ name: Get Changed Files
+ runs-on: ubuntu-latest
+ outputs:
+ files: ${{ steps.changed-files.outputs.all_changed_files }}
+ any_changed: ${{ steps.changed-files.outputs.any_changed }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Get changed files
+ id: changed-files
+ uses: tj-actions/changed-files@v46
+ with:
+ separator: ','
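+ # Comma separator matches the comma-separated list format the reusable
+ # workflow passes through to the CLI's --include-files flag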
+ # Exclude test files from security scan - they contain intentional
+ # path traversal test data that triggers false positives
+ files_ignore: |
+ **/*_test.go
+ **/testdata/**
+
+ security-scan:
+ name: Security Scan
+ needs: get-changed-files
+ if: needs.get-changed-files.outputs.any_changed == 'true'
+ uses: ./.github/workflows/reusable-security-scan.yml
+ with:
+ scan-type: repo
+ fail-on: 'CRITICAL,HIGH'
+ include-files: ${{ needs.get-changed-files.outputs.files }}
+ build-from-source: false
+ secrets:
+ api-token: ${{ secrets.ARMIS_API_TOKEN }}
+ tenant-id: ${{ secrets.ARMIS_TENANT_ID }}
diff --git a/.github/workflows/reusable-security-scan.yml b/.github/workflows/reusable-security-scan.yml
index 3d8bde8..32fbd65 100644
--- a/.github/workflows/reusable-security-scan.yml
+++ b/.github/workflows/reusable-security-scan.yml
@@ -35,6 +35,14 @@ on:
description: 'Scan timeout in minutes'
type: number
default: 60
+ include-files:
+ description: 'Comma-separated list of file paths to scan (relative to repository root)'
+ type: string
+ default: ''
+ build-from-source:
+ description: 'Build CLI from source instead of downloading release (for testing scanner changes)'
+ type: boolean
+ default: false
secrets:
api-token:
description: 'Armis API token for authentication'
@@ -43,6 +51,14 @@ on:
description: 'Tenant identifier for Armis Cloud'
required: true
+# Top-level permissions define the maximum permissions available to this workflow.
+# Job-level permissions further restrict as needed.
+permissions:
+ contents: read
+ security-events: write
+ actions: read
+ pull-requests: write
+
jobs:
security-scan:
name: Armis Security Scan
@@ -62,7 +78,7 @@ jobs:
- name: Run Armis Security Scan
id: armis_scan
- uses: ArmisSecurity/armis-cli@main
+ uses: ./
with:
scan-type: ${{ inputs.scan-type }}
scan-target: ${{ inputs.scan-target }}
@@ -73,6 +89,8 @@ jobs:
output-file: armis-results.sarif
image-tarball: ${{ inputs.image-tarball }}
scan-timeout: ${{ inputs.scan-timeout }}
+ include-files: ${{ inputs.include-files }}
+ build-from-source: ${{ inputs.build-from-source }}
continue-on-error: true
- name: Ensure SARIF exists
@@ -121,7 +139,76 @@ jobs:
if (counts.LOW > 0) body += `| 🔵 LOW | ${counts.LOW} |\n`;
if (counts.INFO > 0) body += `| ⚪ INFO | ${counts.INFO} |\n`;
body += `\n**Total: ${total}**\n`;
- body += `\n<details><summary>View full results</summary>\n\nSee the Security tab or download the \`armis-security-results\` artifact for the complete SARIF report.\n</details>`;
+
+ // Build detailed findings section
+ if (total > 0) {
+ body += `\n<details><summary>View all ${total} findings</summary>\n\n`;
+
+ // Group results by severity
+ const severityOrder = ['CRITICAL', 'HIGH', 'MEDIUM', 'LOW', 'INFO'];
+ const severityEmoji = { CRITICAL: '🔴', HIGH: '🟠', MEDIUM: '🟡', LOW: '🔵', INFO: '⚪' };
+
+ for (const severity of severityOrder) {
+ const severityResults = results.filter(r =>
+ (r.properties?.severity || 'INFO') === severity
+ );
+
+ if (severityResults.length > 0) {
+ body += `### ${severityEmoji[severity]} ${severity} (${severityResults.length})\n\n`;
+
+ for (const r of severityResults) {
+ const file = r.locations?.[0]?.physicalLocation?.artifactLocation?.uri || '';
+ const line = r.locations?.[0]?.physicalLocation?.region?.startLine || '';
+ const location = file ? (line ? `${file}:${line}` : file) : 'Unknown location';
+
+ // Parse title and description from message
+ const msgParts = (r.message?.text || '').split(': ');
+ const title = msgParts[0] || r.ruleId;
+ const description = msgParts.slice(1).join(': ') || '';
+
+ body += `<details><summary>${r.ruleId} - ${title}</summary>\n\n`;
+ body += `**Location:** \`${location}\`\n\n`;
+
+ if (description) {
+ body += `${description}\n\n`;
+ }
+
+ // Code snippet
+ const snippet = r.properties?.codeSnippet;
+ if (snippet) {
+ body += '```\n' + snippet + '\n```\n\n';
+ }
+
+ // CVEs and CWEs
+ const cves = r.properties?.cves || [];
+ const cwes = r.properties?.cwes || [];
+ if (cves.length > 0) {
+ body += `**CVEs:** ${cves.join(', ')}\n\n`;
+ }
+ if (cwes.length > 0) {
+ body += `**CWEs:** ${cwes.join(', ')}\n\n`;
+ }
+
+ // Package info
+ const pkg = r.properties?.package;
+ const version = r.properties?.version;
+ const fixVersion = r.properties?.fixVersion;
+ if (pkg) {
+ let pkgInfo = `**Package:** ${pkg}`;
+ if (version) pkgInfo += ` (${version})`;
+ if (fixVersion) pkgInfo += ` → Fix: ${fixVersion}`;
+ body += pkgInfo + '\n\n';
+ }
+
+ body += `</details>\n\n`;
+ }
+ }
+ }
+
+ body += `</details>`;
+ } else {
+ body += `\n<details><summary>View full results</summary>\n\nNo security issues found.\n</details>`;
+ }
// Find and update existing comment, or create new
const { data: comments } = await github.rest.issues.listComments({
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index 2f9013f..b48a79a 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -18,6 +18,7 @@ jobs:
pr-comment: false # No PR context for scheduled runs
upload-artifact: true
scan-timeout: 120 # 2 hours
+ build-from-source: false
secrets:
api-token: ${{ secrets.ARMIS_API_TOKEN }}
tenant-id: ${{ secrets.ARMIS_TENANT_ID }}
diff --git a/action.yml b/action.yml
index f9d7d53..2030718 100644
--- a/action.yml
+++ b/action.yml
@@ -46,6 +46,14 @@ inputs:
description: 'Scan timeout in minutes'
required: false
default: '60'
+ include-files:
+ description: 'Comma-separated list of file paths to scan (relative to repository root)'
+ required: false
+ default: ''
+ build-from-source:
+ description: 'Build CLI from source instead of downloading release (for testing)'
+ required: false
+ default: 'false'
outputs:
results:
@@ -58,11 +66,80 @@ outputs:
runs:
using: 'composite'
steps:
- - name: Install Armis CLI
+ - name: Setup Go
+ if: inputs.build-from-source == 'true'
+ uses: actions/setup-go@v5
+ with:
+ go-version: '1.23'
+
+ - name: Build Armis CLI from source
+ if: inputs.build-from-source == 'true'
+ shell: bash
+ run: |
+ echo "Building Armis CLI from source..."
+ go build -o armis-cli ./cmd/armis-cli
+ mkdir -p "$HOME/.local/bin"
+ mv armis-cli "$HOME/.local/bin/"
+ echo "$HOME/.local/bin" >> $GITHUB_PATH
+
+ - name: Install Armis CLI from release
+ if: inputs.build-from-source != 'true'
shell: bash
run: |
echo "Installing Armis CLI..."
- curl -sSL https://raw.githubusercontent.com/ArmisSecurity/armis-cli/main/scripts/install.sh | bash
+
+ # Detect OS and architecture
+ OS=$(uname -s | tr '[:upper:]' '[:lower:]')
+ case "$OS" in
+ linux*) OS="linux" ;;
+ darwin*) OS="darwin" ;;
+ *) echo "Error: Unsupported OS: $OS"; exit 1 ;;
+ esac
+
+ ARCH=$(uname -m)
+ case "$ARCH" in
+ x86_64|amd64) ARCH="amd64" ;;
+ aarch64|arm64) ARCH="arm64" ;;
+ *) echo "Error: Unsupported architecture: $ARCH"; exit 1 ;;
+ esac
+
+ # Download binary directly from GitHub releases
+ REPO="ArmisSecurity/armis-cli"
+ ARCHIVE_NAME="armis-cli-${OS}-${ARCH}.tar.gz"
+ CHECKSUMS_NAME="armis-cli-checksums.txt"
+ BASE_URL="https://github.com/${REPO}/releases/latest/download"
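+ # Resolves to e.g. (illustrative, assuming a linux/amd64 runner):
+ #   https://github.com/ArmisSecurity/armis-cli/releases/latest/download/armis-cli-linux-amd64.tar.gz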
+
+ TMP_DIR=$(mktemp -d)
+ trap 'rm -rf "$TMP_DIR"' EXIT
+
+ echo "Downloading $ARCHIVE_NAME..."
+ curl -fsSL "${BASE_URL}/${ARCHIVE_NAME}" -o "${TMP_DIR}/${ARCHIVE_NAME}"
+ curl -fsSL "${BASE_URL}/${CHECKSUMS_NAME}" -o "${TMP_DIR}/${CHECKSUMS_NAME}"
+
+ # Verify checksum
+ echo "Verifying checksum..."
+ cd "$TMP_DIR"
+ EXPECTED=$(grep "$ARCHIVE_NAME" "$CHECKSUMS_NAME" | awk '{print $1}')
+ if command -v sha256sum > /dev/null 2>&1; then
+ ACTUAL=$(sha256sum "$ARCHIVE_NAME" | awk '{print $1}')
+ else
+ ACTUAL=$(shasum -a 256 "$ARCHIVE_NAME" | awk '{print $1}')
+ fi
+
+ if [ "$EXPECTED" != "$ACTUAL" ]; then
+ echo "Error: Checksum verification failed!"
+ echo "Expected: $EXPECTED"
+ echo "Actual: $ACTUAL"
+ exit 1
+ fi
+ echo "Checksum verified successfully"
+
+ # Extract and install
+ tar -xzf "$ARCHIVE_NAME"
+ mkdir -p "$HOME/.local/bin"
+ mv armis-cli "$HOME/.local/bin/"
+ chmod +x "$HOME/.local/bin/armis-cli"
+ echo "Armis CLI installed successfully"
echo "$HOME/.local/bin" >> $GITHUB_PATH
- name: Verify Installation
@@ -85,6 +162,7 @@ runs:
NO_PROGRESS: ${{ inputs.no-progress }}
IMAGE_TARBALL: ${{ inputs.image-tarball }}
OUTPUT_FILE: ${{ inputs.output-file }}
+ INCLUDE_FILES: ${{ inputs.include-files }}
run: |
set +e
@@ -114,6 +192,10 @@ runs:
SCAN_ARGS+=("--tarball" "$IMAGE_TARBALL")
fi
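+ # Pass the comma-separated file list straight through; illustrative value:
+ #   INCLUDE_FILES="main.go,pkg/helper.go"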
+ if [ -n "$INCLUDE_FILES" ]; then
+ SCAN_ARGS+=("--include-files" "$INCLUDE_FILES")
+ fi
+
# Execute command safely without eval
if [ -n "$OUTPUT_FILE" ]; then
armis-cli "${SCAN_ARGS[@]}" > "$OUTPUT_FILE"
diff --git a/internal/cmd/scan.go b/internal/cmd/scan.go
index 7d37332..4f78c4e 100644
--- a/internal/cmd/scan.go
+++ b/internal/cmd/scan.go
@@ -10,6 +10,7 @@ var (
uploadTimeout int
includeNonExploitable bool
groupBy string
+ includeFiles []string
)
var scanCmd = &cobra.Command{
@@ -24,6 +25,7 @@ func init() {
scanCmd.PersistentFlags().IntVar(&uploadTimeout, "upload-timeout", 10, "Maximum time in minutes to wait for artifact upload to complete")
scanCmd.PersistentFlags().BoolVar(&includeNonExploitable, "include-non-exploitable", false, "Include findings marked as non-exploitable (only exploitable findings shown by default)")
scanCmd.PersistentFlags().StringVar(&groupBy, "group-by", "none", "Group findings by: none, cwe, severity, file")
+ scanCmd.PersistentFlags().StringSliceVar(&includeFiles, "include-files", nil, "Comma-separated list of file paths to include in scan (relative to repository root)")
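+ // Illustrative usage (the flag value is a single comma-separated list):
+ //   armis-cli scan repo --include-files "main.go,pkg/helper.go"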
if rootCmd != nil {
rootCmd.AddCommand(scanCmd)
}
diff --git a/internal/cmd/scan_repo.go b/internal/cmd/scan_repo.go
index c283d84..8718f2c 100644
--- a/internal/cmd/scan_repo.go
+++ b/internal/cmd/scan_repo.go
@@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"os"
+ "path/filepath"
"time"
"github.com/ArmisSecurity/armis-cli/internal/api"
@@ -47,6 +48,22 @@ var scanRepoCmd = &cobra.Command{
scanTimeoutDuration := time.Duration(scanTimeout) * time.Minute
scanner := repo.NewScanner(client, noProgress, tid, limit, includeTests, scanTimeoutDuration, includeNonExploitable)
+ // Handle --include-files flag for targeted file scanning
+ // Security: Path traversal protection is enforced by ParseFileList which
+ // validates all paths using SafeJoinPath to ensure they don't escape the
+ // repository root. Invalid or traversal paths are rejected with an error.
+ if len(includeFiles) > 0 {
+ absPath, err := filepath.Abs(repoPath)
+ if err != nil {
+ return fmt.Errorf("failed to resolve path: %w", err)
+ }
+ fileList, err := repo.ParseFileList(absPath, includeFiles)
+ if err != nil {
+ return fmt.Errorf("invalid --include-files: %w", err)
+ }
+ scanner = scanner.WithIncludeFiles(fileList)
+ }
+
ctx, cancel := NewSignalContext()
defer cancel()
diff --git a/internal/output/sarif.go b/internal/output/sarif.go
index a21f71e..717408d 100644
--- a/internal/output/sarif.go
+++ b/internal/output/sarif.go
@@ -60,7 +60,14 @@ type sarifResult struct {
}
type sarifResultProperties struct {
- Severity string `json:"severity"`
+ Severity string `json:"severity"`
+ Type string `json:"type,omitempty"`
+ CodeSnippet string `json:"codeSnippet,omitempty"`
+ CVEs []string `json:"cves,omitempty"`
+ CWEs []string `json:"cwes,omitempty"`
+ Package string `json:"package,omitempty"`
+ Version string `json:"version,omitempty"`
+ FixVersion string `json:"fixVersion,omitempty"`
}
type sarifMessage struct {
@@ -138,8 +145,17 @@ func buildRules(findings []model.Finding) ([]sarifRule, map[string]int) {
return rules, ruleIndexMap
}
+// maxSarifResultsCapacity caps the initial capacity of the SARIF results slice
+// to prevent resource exhaustion from extremely large finding lists (CWE-770).
+// Findings beyond this count are still appended; only the pre-allocation is bounded.
+
func convertToSarifResults(findings []model.Finding, ruleIndexMap map[string]int) []sarifResult {
- results := make([]sarifResult, 0, len(findings))
+ // Cap the initial capacity to prevent excessive memory allocation (CWE-770)
+ capacity := len(findings)
+ if capacity > maxSarifResultsCapacity {
+ capacity = maxSarifResultsCapacity
+ }
+ results := make([]sarifResult, 0, capacity)
for _, finding := range findings {
result := sarifResult{
@@ -150,7 +166,14 @@ func convertToSarifResults(findings []model.Finding, ruleIndexMap map[string]int
Text: finding.Title + ": " + finding.Description,
},
Properties: &sarifResultProperties{
- Severity: string(finding.Severity),
+ Severity: string(finding.Severity),
+ Type: string(finding.Type),
+ CodeSnippet: util.MaskSecretInLine(finding.CodeSnippet), // Defense-in-depth: always sanitize
+ CVEs: finding.CVEs,
+ CWEs: finding.CWEs,
+ Package: finding.Package,
+ Version: finding.Version,
+ FixVersion: finding.FixVersion,
},
}
diff --git a/internal/scan/repo/files.go b/internal/scan/repo/files.go
new file mode 100644
index 0000000..3180377
--- /dev/null
+++ b/internal/scan/repo/files.go
@@ -0,0 +1,119 @@
+// Package repo provides repository scanning functionality.
+package repo
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/ArmisSecurity/armis-cli/internal/util"
+)
+
+// MaxFiles is the maximum number of files that can be specified via --include-files.
+// This limit prevents resource exhaustion from extremely large file lists.
+const MaxFiles = 1000
+
+// FileList represents a list of files to be scanned.
+type FileList struct {
+ files []string
+ repoRoot string
+}
+
+// ParseFileList parses file paths from the --include-files flag.
+// It accepts both relative paths (to repoRoot) and absolute paths,
+// normalizing all to relative paths.
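+//
+// Example:
+//
+//	fl, err := ParseFileList("/path/to/repo", []string{"main.go", "pkg/helper.go"})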
+func ParseFileList(repoRoot string, files []string) (*FileList, error) {
+ absRoot, err := filepath.Abs(repoRoot)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve repo root: %w", err)
+ }
+
+ fl := &FileList{repoRoot: absRoot}
+ for _, f := range files {
+ if err := fl.addFile(f); err != nil {
+ return nil, err
+ }
+ }
+ return fl, nil
+}
+
+func (fl *FileList) addFile(path string) error {
+ // Check file count limit to prevent resource exhaustion
+ if len(fl.files) >= MaxFiles {
+ return fmt.Errorf("too many files: maximum %d files allowed", MaxFiles)
+ }
+
+ if path == "" {
+ return nil // Skip empty paths
+ }
+
+ // Normalize path separators
+ path = filepath.FromSlash(path)
+
+ // Convert absolute paths to relative
+ if filepath.IsAbs(path) {
+ // Security: Resolve symlinks to prevent path traversal attacks (CWE-22).
+ // Using filepath.EvalSymlinks ensures we compare actual filesystem paths,
+ // preventing symlink-based escapes from the repository root.
+ evalPath, err := filepath.EvalSymlinks(path)
+ if err != nil {
+ // Path doesn't exist yet - fall back to Clean for normalization
+ evalPath = filepath.Clean(path)
+ }
+ evalRoot, err := filepath.EvalSymlinks(fl.repoRoot)
+ if err != nil {
+ evalRoot = filepath.Clean(fl.repoRoot)
+ }
+
+ // Use filepath.Rel to check containment - it returns an error or
+ // a path starting with ".." if the path is outside the root
+ relCheck, err := filepath.Rel(evalRoot, evalPath)
+ if err != nil || strings.HasPrefix(relCheck, "..") {
+ return fmt.Errorf("absolute path %q is outside repository root %q", path, fl.repoRoot)
+ }
+
+ rel, err := filepath.Rel(fl.repoRoot, path)
+ if err != nil {
+ return fmt.Errorf("cannot make path relative to repo: %s", path)
+ }
+ path = rel
+ }
+
+ // Validate path doesn't escape repo root using SafeJoinPath
+ if _, err := util.SafeJoinPath(fl.repoRoot, path); err != nil {
+ return fmt.Errorf("invalid path %q: %w", path, err)
+ }
+
+ fl.files = append(fl.files, path)
+ return nil
+}
+
+// Files returns the validated list of relative file paths.
+func (fl *FileList) Files() []string {
+ return fl.files
+}
+
+// RepoRoot returns the absolute path to the repository root.
+func (fl *FileList) RepoRoot() string {
+ return fl.repoRoot
+}
+
+// ValidateExistence returns the files that exist on disk, along with
+// warnings for missing files and for paths that are directories.
+func (fl *FileList) ValidateExistence() (existing []string, warnings []string) {
+ for _, f := range fl.files {
+ absPath := filepath.Join(fl.repoRoot, f)
+ info, err := os.Stat(absPath)
+ if err != nil {
+ warnings = append(warnings, fmt.Sprintf("file not found: %s", f))
+ continue
+ }
+ // Skip directories - we only scan files
+ if info.IsDir() {
+ warnings = append(warnings, fmt.Sprintf("skipping directory: %s", f))
+ continue
+ }
+ existing = append(existing, f)
+ }
+ return
+}
diff --git a/internal/scan/repo/files_test.go b/internal/scan/repo/files_test.go
new file mode 100644
index 0000000..1e018ae
--- /dev/null
+++ b/internal/scan/repo/files_test.go
@@ -0,0 +1,223 @@
+package repo
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+func TestParseFileList(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create test files
+ if err := os.WriteFile(filepath.Join(tmpDir, "main.go"), []byte("package main"), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+ if err := os.MkdirAll(filepath.Join(tmpDir, "pkg"), 0750); err != nil {
+ t.Fatalf("Failed to create test dir: %v", err)
+ }
+ if err := os.WriteFile(filepath.Join(tmpDir, "pkg", "helper.go"), []byte("package pkg"), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ tests := []struct {
+ name string
+ files []string
+ wantLen int
+ wantErr bool
+ }{
+ {
+ name: "valid relative paths",
+ files: []string{"main.go", "pkg/helper.go"},
+ wantLen: 2,
+ wantErr: false,
+ },
+ {
+ name: "path traversal rejected",
+ files: []string{"../etc/passwd"},
+ wantErr: true,
+ },
+ {
+ name: "empty list",
+ files: []string{},
+ wantLen: 0,
+ wantErr: false,
+ },
+ {
+ name: "empty string in list is skipped",
+ files: []string{"main.go", "", "pkg/helper.go"},
+ wantLen: 2,
+ wantErr: false,
+ },
+ {
+ name: "absolute path converted to relative",
+ files: []string{filepath.Join(tmpDir, "main.go")},
+ wantLen: 1,
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ fl, err := ParseFileList(tmpDir, tt.files)
+ if tt.wantErr {
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(fl.Files()) != tt.wantLen {
+ t.Errorf("got %d files, want %d", len(fl.Files()), tt.wantLen)
+ }
+ })
+ }
+}
+
+func TestFileListValidateExistence(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create one existing file
+ if err := os.WriteFile(filepath.Join(tmpDir, "exists.go"), []byte("package main"), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ // Create a directory to test directory skipping
+ if err := os.MkdirAll(filepath.Join(tmpDir, "subdir"), 0750); err != nil {
+ t.Fatalf("Failed to create test dir: %v", err)
+ }
+
+ fl, err := ParseFileList(tmpDir, []string{"exists.go", "missing.go", "subdir"})
+ if err != nil {
+ t.Fatalf("ParseFileList failed: %v", err)
+ }
+
+ existing, warnings := fl.ValidateExistence()
+
+ if len(existing) != 1 {
+ t.Errorf("expected 1 existing file, got %d", len(existing))
+ }
+ if existing[0] != "exists.go" {
+ t.Errorf("expected exists.go, got %s", existing[0])
+ }
+ if len(warnings) != 2 {
+ t.Errorf("expected 2 warnings (missing file + directory), got %d", len(warnings))
+ }
+}
+
+func TestParseFileListPathTraversal(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ traversalPaths := []string{
+ "../etc/passwd",
+ "foo/../../etc/passwd",
+ "./foo/../../../etc/passwd",
+ }
+
+ for _, path := range traversalPaths {
+ t.Run(path, func(t *testing.T) {
+ _, err := ParseFileList(tmpDir, []string{path})
+ if err == nil {
+ t.Errorf("expected error for path traversal attempt: %s", path)
+ }
+ })
+ }
+}
+
+func TestParseFileListAbsolutePathOutsideRepo(t *testing.T) {
+ // Create two separate temp directories - one is the "repo root", the other is "outside"
+ repoDir := t.TempDir()
+ outsideDir := t.TempDir()
+
+ // Create a file in the outside directory to get a real absolute path
+ outsideFile := filepath.Join(outsideDir, "outside.go")
+ if err := os.WriteFile(outsideFile, []byte("package outside"), 0600); err != nil {
+ t.Fatalf("Failed to create outside file: %v", err)
+ }
+
+ // Test that an absolute path outside the repo root is rejected
+ _, err := ParseFileList(repoDir, []string{outsideFile})
+ if err == nil {
+ t.Errorf("expected error for absolute path outside repo: %s", outsideFile)
+ }
+ // Verify the error message is clear about the issue
+ if err != nil && !strings.Contains(err.Error(), "outside repository root") {
+ t.Errorf("expected error message to mention 'outside repository root', got: %s", err.Error())
+ }
+}
+
+func TestFileListRepoRoot(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ fl, err := ParseFileList(tmpDir, []string{})
+ if err != nil {
+ t.Fatalf("ParseFileList failed: %v", err)
+ }
+
+ // RepoRoot should return an absolute path
+ root := fl.RepoRoot()
+ if !filepath.IsAbs(root) {
+ t.Errorf("RepoRoot should return absolute path, got: %s", root)
+ }
+}
+
+func TestFileListFiles(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Create test file
+ if err := os.WriteFile(filepath.Join(tmpDir, "test.go"), []byte("package main"), 0600); err != nil {
+ t.Fatalf("Failed to create test file: %v", err)
+ }
+
+ fl, err := ParseFileList(tmpDir, []string{"test.go"})
+ if err != nil {
+ t.Fatalf("ParseFileList failed: %v", err)
+ }
+
+ files := fl.Files()
+ if len(files) != 1 {
+ t.Fatalf("expected 1 file, got %d", len(files))
+ }
+ if files[0] != "test.go" {
+ t.Errorf("expected test.go, got %s", files[0])
+ }
+}
+
+func TestParseFileListMaxFilesLimit(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Generate more files than the limit
+ files := make([]string, MaxFiles+1)
+ for i := range files {
+ files[i] = "file.go" // Doesn't need to exist for this test
+ }
+
+ _, err := ParseFileList(tmpDir, files)
+ if err == nil {
+ t.Errorf("expected error when exceeding MaxFiles limit (%d), got nil", MaxFiles)
+ }
+ if err != nil && !strings.Contains(err.Error(), "too many files") {
+ t.Errorf("expected error message to mention 'too many files', got: %s", err.Error())
+ }
+}
+
+func TestParseFileListAtMaxFilesLimit(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // Generate exactly MaxFiles files (should succeed)
+ files := make([]string, MaxFiles)
+ for i := range files {
+ files[i] = "file.go" // Doesn't need to exist for this test
+ }
+
+ fl, err := ParseFileList(tmpDir, files)
+ if err != nil {
+ t.Errorf("expected success at exactly MaxFiles limit (%d), got error: %v", MaxFiles, err)
+ }
+ if fl != nil && len(fl.Files()) != MaxFiles {
+ t.Errorf("expected %d files, got %d", MaxFiles, len(fl.Files()))
+ }
+}
diff --git a/internal/scan/repo/repo.go b/internal/scan/repo/repo.go
index 5ecc17f..2f4d728 100644
--- a/internal/scan/repo/repo.go
+++ b/internal/scan/repo/repo.go
@@ -33,6 +33,7 @@ type Scanner struct {
timeout time.Duration
includeNonExploitable bool
pollInterval time.Duration
+ includeFiles *FileList
}
// NewScanner creates a new repository scanner with the given configuration.
@@ -55,6 +56,12 @@ func (s *Scanner) WithPollInterval(d time.Duration) *Scanner {
return s
}
+// WithIncludeFiles sets a specific list of files to scan instead of the entire directory.
+func (s *Scanner) WithIncludeFiles(fl *FileList) *Scanner {
+ s.includeFiles = fl
+ return s
+}
+
// Scan scans a repository at the given path.
func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, error) {
// Validate path to prevent path traversal
@@ -76,30 +83,78 @@ func (s *Scanner) Scan(ctx context.Context, path string) (*model.ScanResult, err
return nil, fmt.Errorf("path is not a directory: %s", absPath)
}
- ignoreMatcher, err := LoadIgnorePatterns(absPath)
- if err != nil {
- return nil, fmt.Errorf("failed to load ignore patterns: %w", err)
- }
+ var size int64
+ var tarFunc func() error
+ var ignoreMatcher *IgnoreMatcher
- size, err := calculateDirSize(absPath, s.includeTests, ignoreMatcher)
- if err != nil {
- return nil, fmt.Errorf("failed to calculate directory size: %w", err)
+ pr, pw := io.Pipe()
+ // The pipe reader is deferred-closed to ensure cleanup on all code paths.
+ // Error is intentionally ignored (nolint:errcheck) because:
+ // 1. PipeReader.Close() rarely returns meaningful errors
+ // 2. The critical close is PipeWriter.Close() which signals EOF to the reader
+ // 3. Any actual read errors will surface through the main error flow
+ defer pr.Close() //nolint:errcheck
+
+ if s.includeFiles != nil {
+ // Targeted file scanning mode - scan only specified files
+ existing, warnings := s.includeFiles.ValidateExistence()
+ for _, w := range warnings {
+ fmt.Fprintf(os.Stderr, "Warning: %s\n", w)
+ }
+
+ if len(existing) == 0 {
+ return nil, fmt.Errorf("no files to scan: all specified files are missing or are directories")
+ }
+
+ var err error
+ size, err = calculateFilesSize(absPath, existing)
+ if err != nil {
+ return nil, fmt.Errorf("failed to calculate files size: %w", err)
+ }
+
+ tarFunc = func() error {
+ defer pw.Close() //nolint:errcheck // signals EOF to reader
+ return s.tarGzFiles(absPath, existing, pw)
+ }
+ } else {
+ // Full directory scanning mode (existing behavior)
+ var err error
+ ignoreMatcher, err = LoadIgnorePatterns(absPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load ignore patterns: %w", err)
+ }
+
+ size, err = calculateDirSize(absPath, s.includeTests, ignoreMatcher)
+ if err != nil {
+ return nil, fmt.Errorf("failed to calculate directory size: %w", err)
+ }
+
+ tarFunc = func() error {
+ defer pw.Close() //nolint:errcheck // signals EOF to reader
+ return s.tarGzDirectory(absPath, pw, ignoreMatcher)
+ }
}
if size > MaxRepoSize {
return nil, fmt.Errorf("directory size (%d bytes) exceeds maximum allowed size (%d bytes)", size, MaxRepoSize)
}
- pr, pw := io.Pipe()
-
spinner := progress.NewSpinnerWithContext(ctx, "Creating a compressed archive...", s.noProgress)
spinner.Start()
defer spinner.Stop()
errChan := make(chan error, 1)
go func() {
- defer pw.Close() //nolint:errcheck // signals EOF to reader
- errChan <- s.tarGzDirectory(absPath, pw, ignoreMatcher)
+ // Security: Check context before starting expensive tar operation
+ // to prevent resource leaks if context is already canceled.
+ select {
+ case <-ctx.Done():
+ pw.Close() //nolint:errcheck,gosec // Close pipe to unblock StartIngest
+ errChan <- ctx.Err()
+ return
+ default:
+ }
+ errChan <- tarFunc()
}()
time.Sleep(500 * time.Millisecond)
@@ -220,6 +275,125 @@ func (s *Scanner) tarGzDirectory(sourcePath string, writer io.Writer, ignoreMatc
})
}
+// isPathContained verifies that absPath is contained within baseDir.
+// This is a defense-in-depth check to prevent path traversal attacks,
+// complementing the SafeJoinPath validation performed at file list parsing.
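+// For example, isPathContained("/repo", "/repo/pkg/a.go") is true, while
+// isPathContained("/repo", "/repo/../etc/passwd") is false.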
+func isPathContained(baseDir, absPath string) bool {
+ rel, err := filepath.Rel(baseDir, absPath)
+ if err != nil {
+ return false
+ }
+ // Path escapes if it starts with ".." or is absolute
+ return !strings.HasPrefix(rel, "..") && !filepath.IsAbs(rel)
+}
+
+func (s *Scanner) tarGzFiles(repoRoot string, files []string, writer io.Writer) (err error) {
+ gzWriter := gzip.NewWriter(writer)
+ defer func() {
+ if closeErr := gzWriter.Close(); closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }()
+
+ tarWriter := tar.NewWriter(gzWriter)
+ defer func() {
+ if closeErr := tarWriter.Close(); closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }()
+
+ // Note: The 'err' variables declared with := inside the loop shadow the named return value.
+ // This is intentional - direct returns like 'return copyErr' still set the named return,
+ // allowing the deferred close functions above to check if an error occurred.
+ //
+ // Security: TOCTOU (Time-of-Check-Time-of-Use) between ValidateExistence and file
+ // operations is an acceptable risk here. Files that disappear between validation and
+ // usage are gracefully handled with warnings. Symlinks are explicitly skipped to
+ // prevent escape attacks. Path traversal is prevented by SafeJoinPath validation
+ // and defense-in-depth isPathContained checks.
+ filesWritten := 0
+ for _, relPath := range files {
+ absPath := filepath.Join(repoRoot, relPath)
+
+ // Defense-in-depth: verify path is within repo root
+ if !isPathContained(repoRoot, absPath) {
+ fmt.Fprintf(os.Stderr, "Warning: skipping path outside repository: %s\n", relPath)
+ continue
+ }
+
+ // Use Lstat (not Stat) so the symlink check below actually fires;
+ // Stat follows symlinks and would never report ModeSymlink.
+ info, err := os.Lstat(absPath)
+ if err != nil {
+ // Skip files that don't exist (may have been deleted)
+ fmt.Fprintf(os.Stderr, "Warning: skipping %s: %v\n", relPath, err)
+ continue
+ }
+
+ // Skip directories - we only handle files
+ if info.IsDir() {
+ continue
+ }
+
+ // Skip symlinks for security
+ if info.Mode()&os.ModeSymlink != 0 {
+ fmt.Fprintf(os.Stderr, "Warning: skipping symlink %s\n", relPath)
+ continue
+ }
+
+ header, err := tar.FileInfoHeader(info, "")
+ if err != nil {
+ return err
+ }
+
+ // Use forward slashes for tar paths
+ header.Name = filepath.ToSlash(relPath)
+
+ if err := tarWriter.WriteHeader(header); err != nil {
+ return err
+ }
+
+ file, err := os.Open(absPath) // #nosec G304 - path is validated via SafeJoinPath
+ if err != nil {
+ return err
+ }
+ _, copyErr := io.Copy(tarWriter, file)
+ closeErr := file.Close()
+ if copyErr != nil {
+ return copyErr
+ }
+ if closeErr != nil {
+ return closeErr
+ }
+ filesWritten++
+ }
+
+ if filesWritten == 0 {
+ return fmt.Errorf("no files were added to archive")
+ }
+
+ return nil
+}
+
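+// calculateFilesSize sums the sizes of the given repo-relative files,
+// skipping directories, symlinks, and any path that escapes repoRoot.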
+func calculateFilesSize(repoRoot string, files []string) (int64, error) {
+ var size int64
+ for _, relPath := range files {
+ absPath := filepath.Join(repoRoot, relPath)
+
+ // Defense-in-depth: verify path is within repo root
+ if !isPathContained(repoRoot, absPath) {
+ continue // Skip paths outside repository
+ }
+
+ // Lstat so symlinks are detected (and excluded below) rather than followed
+ info, err := os.Lstat(absPath)
+ if err != nil {
+ continue // Skip non-existent files
+ }
+ if !info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
+ size += info.Size()
+ }
+ }
+ return size, nil
+}
+
func calculateDirSize(path string, includeTests bool, ignoreMatcher *IgnoreMatcher) (int64, error) {
var size int64
err := filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {