Files
keyhunter/cmd/scan.go
salvacybersec b151e88a29 feat(04-05): wire all Phase 4 sources through scan command
- Add --git, --url, --clipboard, --since, --max-file-size, --insecure flags
- Introduce selectSource dispatcher with sourceFlags struct
- Dispatch to Dir/File/Git/Stdin/URL/Clipboard sources based on args+flags
- Reject mutually exclusive source selectors with clear error
- Forward --exclude patterns into DirSource
- Args changed to MaximumNArgs(1) to allow --url/--clipboard without positional
2026-04-05 15:23:12 +03:00

293 lines
8.6 KiB
Go

package cmd
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"time"
"github.com/salvacybersec/keyhunter/pkg/config"
"github.com/salvacybersec/keyhunter/pkg/engine"
"github.com/salvacybersec/keyhunter/pkg/engine/sources"
"github.com/salvacybersec/keyhunter/pkg/output"
"github.com/salvacybersec/keyhunter/pkg/providers"
"github.com/salvacybersec/keyhunter/pkg/storage"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
// Engine and output flags for the scan command.
var (
	flagWorkers int      // --workers; values <= 0 fall back to config, then CPU*8 (see RunE)
	flagVerify  bool     // --verify; forwarded to engine.ScanConfig.Verify
	flagUnmask  bool     // --unmask; forwarded to the engine and the table printer
	flagOutput  string   // --output; "table" or "json"
	flagExclude []string // --exclude; extra glob patterns forwarded to DirSource
)

// Source-selection flags (Phase 4).
var (
	flagGit         bool   // --git; treat the target as a git repository
	flagURL         string // --url; remote http(s) URL to fetch and scan
	flagClipboard   bool   // --clipboard; scan clipboard contents
	flagSince       string // --since; YYYY-MM-DD lower bound for --git
	flagMaxFileSize int64  // --max-file-size; registered in init, not read in this file
	flagInsecure    bool   // --insecure; registered in init, not read in this file
)
// scanCmd implements the scan subcommand. It resolves a source from the
// positional argument and source flags, then runs the detection engine over
// it. Every finding is saved to the local database and the results are
// printed.
//
// Exit code semantics (CLI-05 / OUT-06): 0=clean, 1=found, 2=error.
var scanCmd = &cobra.Command{
	Use:   "scan [path|stdin|-]",
	Short: "Scan files, directories, git history, stdin, URLs, or clipboard for leaked API keys",
	Args:  cobra.MaximumNArgs(1),
	RunE: func(cmd *cobra.Command, args []string) error {
		// Validate --output before doing any work, so a typo fails fast
		// instead of silently falling back to table output after the scan.
		switch flagOutput {
		case "table", "json":
		default:
			return fmt.Errorf("scan: unsupported --output %q (want table or json)", flagOutput)
		}

		// --max-file-size and --insecure are registered in init but are not
		// forwarded to any source yet; tell the user instead of ignoring them.
		for _, name := range []string{"max-file-size", "insecure"} {
			if cmd.Flags().Changed(name) {
				fmt.Fprintf(os.Stderr, "warning: --%s is not yet supported and will be ignored\n", name)
			}
		}

		// Load config; a positive scan.workers from viper overrides it.
		cfg := config.Load()
		if w := viper.GetInt("scan.workers"); w > 0 {
			cfg.Workers = w
		}
		// Worker count precedence: --workers, then config, then CPU*8.
		workers := flagWorkers
		if workers <= 0 {
			workers = cfg.Workers
		}
		if workers <= 0 {
			workers = runtime.NumCPU() * 8
		}

		reg, err := providers.NewRegistry()
		if err != nil {
			return fmt.Errorf("loading providers: %w", err)
		}
		eng := engine.NewEngine(reg)

		src, err := selectSource(args, sourceFlags{
			Git:       flagGit,
			URL:       flagURL,
			Clipboard: flagClipboard,
			Since:     flagSince,
			Excludes:  flagExclude,
		})
		if err != nil {
			return err
		}
		scanCfg := engine.ScanConfig{
			Workers: workers,
			Verify:  flagVerify,
			Unmask:  flagUnmask,
		}

		// Open the database, creating its parent directory if needed.
		dbPath := viper.GetString("database.path")
		if dbPath == "" {
			dbPath = cfg.DBPath
		}
		if err := os.MkdirAll(filepath.Dir(dbPath), 0700); err != nil {
			return fmt.Errorf("creating database directory: %w", err)
		}
		db, err := storage.Open(dbPath)
		if err != nil {
			return fmt.Errorf("opening database: %w", err)
		}
		defer db.Close()

		// The key is derived from the passphrase plus a per-installation salt
		// kept in the settings table (created on first run), so the same
		// installation derives the same key on every run.
		encKey, err := loadOrCreateEncKey(db, cfg.Passphrase)
		if err != nil {
			return fmt.Errorf("preparing encryption key: %w", err)
		}

		ch, err := eng.Scan(context.Background(), src, scanCfg)
		if err != nil {
			return fmt.Errorf("starting scan: %w", err)
		}

		findings := collectAndPersistFindings(ch, db, encKey)
		if err := writeScanOutput(findings); err != nil {
			return err
		}

		if len(findings) > 0 {
			// os.Exit does not run deferred calls, so the deferred db.Close
			// above would be skipped; close the database explicitly first.
			db.Close()
			os.Exit(1)
		}
		return nil
	},
}

// collectAndPersistFindings drains ch and returns every finding in arrival
// order. Each finding is also passed to db.SaveFinding together with encKey.
// A failed save is reported on stderr as a warning and does not stop the
// scan.
func collectAndPersistFindings(ch <-chan engine.Finding, db *storage.DB, encKey []byte) []engine.Finding {
	var findings []engine.Finding
	for f := range ch {
		findings = append(findings, f)
		rec := storage.Finding{
			ProviderName: f.ProviderName,
			KeyValue:     f.KeyValue,
			KeyMasked:    f.KeyMasked,
			Confidence:   f.Confidence,
			SourcePath:   f.Source,
			SourceType:   f.SourceType,
			LineNumber:   f.LineNumber,
		}
		if _, err := db.SaveFinding(rec, encKey); err != nil {
			fmt.Fprintf(os.Stderr, "warning: failed to save finding: %v\n", err)
		}
	}
	return findings
}

// writeScanOutput prints findings to stdout in the --output format.
//   - "json" emits an indented array that carries only the masked key (full
//     JSON output is planned for Phase 6).
//   - Anything else goes to output.PrintFindings, with --unmask passed through.
func writeScanOutput(findings []engine.Finding) error {
	switch flagOutput {
	case "json":
		type jsonFinding struct {
			Provider   string `json:"provider"`
			KeyMasked  string `json:"key_masked"`
			Confidence string `json:"confidence"`
			Source     string `json:"source"`
			Line       int    `json:"line"`
		}
		// Non-nil slice so zero findings encode as [] rather than null.
		out := make([]jsonFinding, 0, len(findings))
		for _, f := range findings {
			out = append(out, jsonFinding{
				Provider:   f.ProviderName,
				KeyMasked:  f.KeyMasked,
				Confidence: f.Confidence,
				Source:     f.Source,
				Line:       f.LineNumber,
			})
		}
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", " ")
		if err := enc.Encode(out); err != nil {
			return fmt.Errorf("encoding JSON output: %w", err)
		}
	default:
		output.PrintFindings(findings, flagUnmask)
	}
	return nil
}
// sourceFlags captures the CLI inputs that control source selection.
// Extracted into a struct so selectSource is straightforward to unit test.
type sourceFlags struct {
	// Git selects GitSource for the positional target (--git).
	Git bool
	// URL, when non-empty, selects URLSource and forbids a positional arg (--url).
	URL string
	// Clipboard selects ClipboardSource and forbids a positional arg (--clipboard).
	Clipboard bool
	// Since is a YYYY-MM-DD date applied as GitSource.Since when Git is set (--since).
	Since string
	// Excludes are extra glob patterns forwarded to DirSource (--exclude).
	Excludes []string
}
// selectSource inspects positional args and source flags, validates that
// exactly one source is specified, and returns the appropriate Source.
//
// Dispatch rules:
//   - --url / --clipboard: no positional arg, mutually exclusive with --git and each other
//   - --git <path>: uses GitSource (optionally filtered by --since=YYYY-MM-DD)
//   - target == "stdin" or "-": uses StdinSource
//   - target is a directory: uses DirSource (forwards --exclude patterns)
//   - target is a file: uses FileSource
//
// --since only has meaning together with --git. Passing it with any other
// source is rejected rather than silently ignored.
func selectSource(args []string, f sourceFlags) (sources.Source, error) {
	// Count explicit source selectors that are mutually exclusive.
	explicitCount := 0
	for _, selected := range []bool{f.URL != "", f.Clipboard, f.Git} {
		if selected {
			explicitCount++
		}
	}
	if explicitCount > 1 {
		return nil, fmt.Errorf("scan: --git, --url, and --clipboard are mutually exclusive")
	}
	if f.Since != "" && !f.Git {
		return nil, fmt.Errorf("scan: --since requires --git")
	}

	// Clipboard and URL take no positional argument.
	if f.Clipboard {
		if len(args) > 0 {
			return nil, fmt.Errorf("scan: --clipboard does not accept a positional argument")
		}
		return sources.NewClipboardSource(), nil
	}
	if f.URL != "" {
		if len(args) > 0 {
			return nil, fmt.Errorf("scan: --url does not accept a positional argument")
		}
		return sources.NewURLSource(f.URL), nil
	}

	if len(args) == 0 {
		return nil, fmt.Errorf("scan: missing target (path, stdin, -, or a source flag)")
	}
	target := args[0]

	if target == "stdin" || target == "-" {
		if f.Git {
			return nil, fmt.Errorf("scan: --git cannot be combined with stdin")
		}
		return sources.NewStdinSource(), nil
	}

	if f.Git {
		gs := sources.NewGitSource(target)
		if f.Since != "" {
			t, err := time.Parse("2006-01-02", f.Since)
			if err != nil {
				return nil, fmt.Errorf("scan: --since must be YYYY-MM-DD: %w", err)
			}
			gs.Since = t
		}
		return gs, nil
	}

	info, err := os.Stat(target)
	if err != nil {
		return nil, fmt.Errorf("scan: stat %q: %w", target, err)
	}
	if info.IsDir() {
		return sources.NewDirSource(target, f.Excludes...), nil
	}
	return sources.NewFileSource(target), nil
}
// loadOrCreateEncKey returns the key derived from passphrase and the
// installation's salt via storage.DeriveKey.
//
// The salt is stored hex-encoded in the settings table under
// "encryption.salt". If that setting is absent, a fresh salt is generated
// with storage.NewSalt() and persisted before the key is derived. Errors
// from reading, decoding, generating or storing the salt are wrapped and
// returned.
func loadOrCreateEncKey(db *storage.DB, passphrase string) ([]byte, error) {
	const saltKey = "encryption.salt"

	stored, ok, err := db.GetSetting(saltKey)
	if err != nil {
		return nil, fmt.Errorf("reading salt from settings: %w", err)
	}

	if ok {
		existing, decErr := hex.DecodeString(stored)
		if decErr != nil {
			return nil, fmt.Errorf("decoding stored salt: %w", decErr)
		}
		return storage.DeriveKey([]byte(passphrase), existing), nil
	}

	// No salt recorded yet: create one and save it, so later runs derive the
	// same key.
	fresh, err := storage.NewSalt()
	if err != nil {
		return nil, fmt.Errorf("generating salt: %w", err)
	}
	if setErr := db.SetSetting(saltKey, hex.EncodeToString(fresh)); setErr != nil {
		return nil, fmt.Errorf("storing salt: %w", setErr)
	}
	return storage.DeriveKey([]byte(passphrase), fresh), nil
}
// init registers the scan command's flags. It also binds the --workers flag
// to the "scan.workers" viper key, which RunE reads back.
func init() {
	fs := scanCmd.Flags()

	// Engine and output flags.
	fs.IntVar(&flagWorkers, "workers", 0, "number of worker goroutines (default: CPU*8)")
	fs.BoolVar(&flagVerify, "verify", false, "actively verify found keys (opt-in, Phase 5)")
	fs.BoolVar(&flagUnmask, "unmask", false, "show full key values (default: masked)")
	fs.StringVar(&flagOutput, "output", "table", "output format: table, json (full JSON output in Phase 6)")
	fs.StringSliceVar(&flagExclude, "exclude", nil, "extra glob patterns to exclude (e.g. *.min.js)")

	// Phase 4 source-selection flags.
	fs.BoolVar(&flagGit, "git", false, "treat target as a git repo and scan full history")
	fs.StringVar(&flagURL, "url", "", "fetch and scan a remote http(s) URL (no positional arg)")
	fs.BoolVar(&flagClipboard, "clipboard", false, "scan current clipboard contents")
	fs.StringVar(&flagSince, "since", "", "for --git: only scan commits after YYYY-MM-DD")
	fs.Int64Var(&flagMaxFileSize, "max-file-size", 0, "max file size in bytes to scan (0 = unlimited)")
	fs.BoolVar(&flagInsecure, "insecure", false, "for --url: skip TLS certificate verification")

	// The bind error is discarded. "workers" is registered just above, so the
	// Lookup result is not nil.
	_ = viper.BindPFlag("scan.workers", fs.Lookup("workers"))
}