package cmd

import (
	"context"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"

	"github.com/salvacybersec/keyhunter/pkg/config"
	"github.com/salvacybersec/keyhunter/pkg/engine"
	"github.com/salvacybersec/keyhunter/pkg/engine/sources"
	"github.com/salvacybersec/keyhunter/pkg/output"
	"github.com/salvacybersec/keyhunter/pkg/providers"
	"github.com/salvacybersec/keyhunter/pkg/storage"
	"github.com/salvacybersec/keyhunter/pkg/verify"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)

// Flag values for `keyhunter scan`; they are registered in init.
//
// NOTE(review): flagMaxFileSize (--max-file-size) and flagInsecure (--insecure)
// are registered but never read in this file, so as far as this file is
// concerned neither flag changes scan behavior. Confirm whether they are
// consumed elsewhere or still need wiring into the sources.
var (
	flagWorkers       int
	flagVerify        bool
	flagUnmask        bool
	flagOutput        string
	flagExclude       []string
	flagGit           bool
	flagURL           string
	flagClipboard     bool
	flagSince         string
	flagMaxFileSize   int64
	flagInsecure      bool
	flagVerifyTimeout time.Duration
	flagVerifyWorkers int
)

// scanCmd implements `keyhunter scan`. It does the following, in order:
//   - picks a source from the positional argument and source flags;
//   - runs the detection engine and collects every finding;
//   - optionally verifies the findings behind a consent gate;
//   - persists them with a per-installation encryption key;
//   - renders them in the requested output format.
var scanCmd = &cobra.Command{
	Use:   "scan [path|stdin|-]",
	Short: "Scan files, directories, git history, stdin, URLs, or clipboard for leaked API keys",
	Args:  cobra.MaximumNArgs(1),
	RunE: func(cmd *cobra.Command, args []string) error {
		// Load config. A positive scan.workers value from viper (config file,
		// or the --workers flag bound to it in init) overrides the loaded one.
		cfg := config.Load()
		if viper.GetInt("scan.workers") > 0 {
			cfg.Workers = viper.GetInt("scan.workers")
		}

		// Worker precedence: --workers flag, then config, then CPU*8.
		workers := flagWorkers
		if workers <= 0 {
			workers = cfg.Workers
		}
		if workers <= 0 {
			workers = runtime.NumCPU() * 8
		}

		reg, err := providers.NewRegistry()
		if err != nil {
			return fmt.Errorf("loading providers: %w", err)
		}

		// Initialize engine and select source based on flags/args.
		eng := engine.NewEngine(reg)
		src, err := selectSource(args, sourceFlags{
			Git:       flagGit,
			URL:       flagURL,
			Clipboard: flagClipboard,
			Since:     flagSince,
			Excludes:  flagExclude,
		})
		if err != nil {
			return err
		}
		scanCfg := engine.ScanConfig{
			Workers: workers,
			Verify:  flagVerify,
			Unmask:  flagUnmask,
		}

		// Open the database, creating its parent directory (owner-only) first.
		dbPath := viper.GetString("database.path")
		if dbPath == "" {
			dbPath = cfg.DBPath
		}
		if err := os.MkdirAll(filepath.Dir(dbPath), 0700); err != nil {
			return fmt.Errorf("creating database directory: %w", err)
		}
		db, err := storage.Open(dbPath)
		if err != nil {
			return fmt.Errorf("opening database: %w", err)
		}
		defer db.Close()

		// The encryption key is derived from the passphrase and a per-installation
		// salt kept in the settings table. The salt is created on first run and
		// reused afterwards, so the derived key stays stable across runs.
		encKey, err := loadOrCreateEncKey(db, cfg.Passphrase)
		if err != nil {
			return fmt.Errorf("preparing encryption key: %w", err)
		}

		ctx := context.Background()
		ch, err := eng.Scan(ctx, src, scanCfg)
		if err != nil {
			return fmt.Errorf("starting scan: %w", err)
		}

		// Collect every finding before saving anything, so verification can fill
		// in the verify_* fields before persistence.
		var findings []engine.Finding
		for f := range ch {
			findings = append(findings, f)
		}

		// Phase 5: with --verify, gate on consent and route findings through the
		// HTTPVerifier. Declined consent skips verification, but the unverified
		// findings are still printed and persisted.
		if flagVerify && len(findings) > 0 {
			granted, consentErr := verify.EnsureConsent(db, os.Stdin, os.Stderr)
			if consentErr != nil {
				return fmt.Errorf("consent check: %w", consentErr)
			}
			if !granted {
				fmt.Fprintln(os.Stderr, "Verification skipped (consent not granted). Run `keyhunter legal` for details.")
			} else {
				verifier := verify.NewHTTPVerifier(flagVerifyTimeout)
				resultsCh := verifier.VerifyAll(ctx, findings, reg, flagVerifyWorkers)

				// VerifyAll preserves neither order nor identity, only the
				// provider/masked-key tuple, so findings are indexed by that tuple.
				// Several findings can share a tuple (e.g. the same key found in
				// several files). Each tuple therefore maps to EVERY matching
				// index. Mapping to a single index left all but the last
				// duplicate unverified.
				// NOTE(review): if distinct keys can mask identically, the tuple
				// cannot tell them apart and a result may be misattributed.
				byTuple := make(map[string][]int, len(findings))
				for i, f := range findings {
					k := f.ProviderName + "|" + f.KeyMasked
					byTuple[k] = append(byTuple[k], i)
				}
				for r := range resultsCh {
					for _, i := range byTuple[r.ProviderName+"|"+r.KeyMasked] {
						fi := &findings[i]
						fi.Verified = true
						fi.VerifyStatus = r.Status
						fi.VerifyHTTPCode = r.HTTPCode
						fi.VerifyMetadata = r.Metadata
						if r.Error != "" {
							fi.VerifyError = r.Error
						}
					}
				}
			}
		}

		// Persist all findings, with verify_* fields populated if verification
		// ran. A failed save is reported on stderr but does not abort the scan.
		// NOTE(review): VerifyError is set above but not copied into
		// storage.Finding here; confirm whether storage has a field for it.
		for _, f := range findings {
			storeFinding := storage.Finding{
				ProviderName:   f.ProviderName,
				KeyValue:       f.KeyValue,
				KeyMasked:      f.KeyMasked,
				Confidence:     f.Confidence,
				SourcePath:     f.Source,
				SourceType:     f.SourceType,
				LineNumber:     f.LineNumber,
				Verified:       f.Verified,
				VerifyStatus:   f.VerifyStatus,
				VerifyHTTPCode: f.VerifyHTTPCode,
				VerifyMetadata: f.VerifyMetadata,
			}
			if _, err := db.SaveFinding(storeFinding, encKey); err != nil {
				fmt.Fprintf(os.Stderr, "warning: failed to save finding: %v\n", err)
			}
		}

		// Output via the formatter registry (OUT-01..04).
		if err := renderScanOutput(findings, flagOutput, flagUnmask, os.Stdout); err != nil {
			return err
		}

		// OUT-06 exit codes: 0=clean, 1=findings, 2=error (errors returned via
		// RunE -> root.Execute -> os.Exit(2)).
		if len(findings) > 0 {
			// os.Exit does not run deferred calls, so the deferred db.Close
			// above would never fire on this path. Close explicitly first.
			db.Close()
			os.Exit(1)
		}
		return nil
	},
}

// renderScanOutput dispatches findings through the formatter registry. It is
// the single entry point used by scan RunE (and tests) to format output.
//
// Returns an error wrapping output.ErrUnknownFormat when name does not match a
// registered formatter; the error message includes the valid format list.
func renderScanOutput(findings []engine.Finding, name string, unmask bool, w io.Writer) error { formatter, err := output.Get(name) if err != nil { return fmt.Errorf("%w (valid: %s)", err, strings.Join(output.Names(), ", ")) } if err := formatter.Format(findings, w, output.Options{ Unmask: unmask, ToolName: "keyhunter", ToolVersion: versionString(), }); err != nil { return fmt.Errorf("rendering %s output: %w", name, err) } return nil } // version is the compiled tool version. Override with: // // go build -ldflags "-X github.com/salvacybersec/keyhunter/cmd.version=1.2.3" var version = "dev" // versionString returns the compiled tool version (used by SARIF output). func versionString() string { return version } // sourceFlags captures the CLI inputs that control source selection. // Extracted into a struct so selectSource is straightforward to unit test. type sourceFlags struct { Git bool URL string Clipboard bool Since string Excludes []string } // selectSource inspects positional args and source flags, validates that // exactly one source is specified, and returns the appropriate Source. // // Dispatch rules: // - --url / --clipboard: no positional arg, mutually exclusive with --git and each other // - --git : uses GitSource (optionally filtered by --since=YYYY-MM-DD) // - target == "stdin" or "-": uses StdinSource // - target is a directory: uses DirSource (forwards --exclude patterns) // - target is a file: uses FileSource func selectSource(args []string, f sourceFlags) (sources.Source, error) { // Count explicit source selectors that are mutually exclusive. explicitCount := 0 if f.URL != "" { explicitCount++ } if f.Clipboard { explicitCount++ } if f.Git { explicitCount++ } if explicitCount > 1 { return nil, fmt.Errorf("scan: --git, --url, and --clipboard are mutually exclusive") } // Clipboard and URL take no positional argument. 
if f.Clipboard { if len(args) > 0 { return nil, fmt.Errorf("scan: --clipboard does not accept a positional argument") } return sources.NewClipboardSource(), nil } if f.URL != "" { if len(args) > 0 { return nil, fmt.Errorf("scan: --url does not accept a positional argument") } return sources.NewURLSource(f.URL), nil } if len(args) == 0 { return nil, fmt.Errorf("scan: missing target (path, stdin, -, or a source flag)") } target := args[0] if target == "stdin" || target == "-" { if f.Git { return nil, fmt.Errorf("scan: --git cannot be combined with stdin") } return sources.NewStdinSource(), nil } if f.Git { gs := sources.NewGitSource(target) if f.Since != "" { t, err := time.Parse("2006-01-02", f.Since) if err != nil { return nil, fmt.Errorf("scan: --since must be YYYY-MM-DD: %w", err) } gs.Since = t } return gs, nil } info, err := os.Stat(target) if err != nil { return nil, fmt.Errorf("scan: stat %q: %w", target, err) } if info.IsDir() { return sources.NewDirSource(target, f.Excludes...), nil } return sources.NewFileSource(target), nil } // loadOrCreateEncKey loads the per-installation salt from the settings table. // On first run it generates a new random salt with storage.NewSalt() and persists it. // The salt is hex-encoded in the settings table under key "encryption.salt". func loadOrCreateEncKey(db *storage.DB, passphrase string) ([]byte, error) { const saltKey = "encryption.salt" saltHex, found, err := db.GetSetting(saltKey) if err != nil { return nil, fmt.Errorf("reading salt from settings: %w", err) } var salt []byte if !found { // First run: generate and persist a new random salt. 
salt, err = storage.NewSalt() if err != nil { return nil, fmt.Errorf("generating salt: %w", err) } if err := db.SetSetting(saltKey, hex.EncodeToString(salt)); err != nil { return nil, fmt.Errorf("storing salt: %w", err) } } else { salt, err = hex.DecodeString(saltHex) if err != nil { return nil, fmt.Errorf("decoding stored salt: %w", err) } } return storage.DeriveKey([]byte(passphrase), salt), nil } func init() { scanCmd.Flags().IntVar(&flagWorkers, "workers", 0, "number of worker goroutines (default: CPU*8)") scanCmd.Flags().BoolVar(&flagVerify, "verify", false, "actively verify found keys (opt-in, Phase 5)") scanCmd.Flags().BoolVar(&flagUnmask, "unmask", false, "show full key values (default: masked)") scanCmd.Flags().StringVar(&flagOutput, "output", "table", "output format: table, json, sarif, csv") scanCmd.Flags().StringSliceVar(&flagExclude, "exclude", nil, "extra glob patterns to exclude (e.g. *.min.js)") // Phase 4 source-selection flags. scanCmd.Flags().BoolVar(&flagGit, "git", false, "treat target as a git repo and scan full history") scanCmd.Flags().StringVar(&flagURL, "url", "", "fetch and scan a remote http(s) URL (no positional arg)") scanCmd.Flags().BoolVar(&flagClipboard, "clipboard", false, "scan current clipboard contents") scanCmd.Flags().StringVar(&flagSince, "since", "", "for --git: only scan commits after YYYY-MM-DD") scanCmd.Flags().Int64Var(&flagMaxFileSize, "max-file-size", 0, "max file size in bytes to scan (0 = unlimited)") scanCmd.Flags().BoolVar(&flagInsecure, "insecure", false, "for --url: skip TLS certificate verification") // Phase 5 verification flags. scanCmd.Flags().DurationVar(&flagVerifyTimeout, "verify-timeout", 10*time.Second, "per-key verification HTTP timeout (default 10s)") scanCmd.Flags().IntVar(&flagVerifyWorkers, "verify-workers", 10, "parallel workers for key verification (default 10)") _ = viper.BindPFlag("scan.workers", scanCmd.Flags().Lookup("workers")) }