merge: plan 01-03 storage layer
This commit is contained in:
32
pkg/storage/crypto.go
Normal file
32
pkg/storage/crypto.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"golang.org/x/crypto/argon2"
|
||||
)
|
||||
|
||||
const (
|
||||
argon2Time uint32 = 1
|
||||
argon2Memory uint32 = 64 * 1024 // 64 MB — RFC 9106 Section 7.3
|
||||
argon2Threads uint8 = 4
|
||||
argon2KeyLen uint32 = 32 // AES-256 key length
|
||||
saltSize = 16
|
||||
)
|
||||
|
||||
// DeriveKey produces a 32-byte AES-256 key from a passphrase and salt using Argon2id.
|
||||
// Uses RFC 9106 Section 7.3 recommended parameters.
|
||||
// Given the same passphrase and salt, always returns the same key.
|
||||
func DeriveKey(passphrase []byte, salt []byte) []byte {
|
||||
return argon2.IDKey(passphrase, salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)
|
||||
}
|
||||
|
||||
// NewSalt generates a cryptographically random 16-byte salt.
|
||||
// Store alongside the database and reuse on each key derivation.
|
||||
func NewSalt() ([]byte, error) {
|
||||
salt := make([]byte, saltSize)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
57
pkg/storage/db.go
Normal file
57
pkg/storage/db.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var schemaSQLBytes []byte
|
||||
|
||||
// DB wraps the sql.DB connection with KeyHunter-specific behavior.
|
||||
type DB struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// Open opens or creates a SQLite database at path, runs embedded schema migrations,
|
||||
// and enables WAL mode for better concurrent read performance.
|
||||
// Use ":memory:" for tests.
|
||||
func Open(path string) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening database: %w", err)
|
||||
}
|
||||
|
||||
// Enable WAL mode for concurrent reads
|
||||
if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("enabling WAL mode: %w", err)
|
||||
}
|
||||
|
||||
// Enable foreign keys
|
||||
if _, err := sqlDB.Exec("PRAGMA foreign_keys=ON"); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("enabling foreign keys: %w", err)
|
||||
}
|
||||
|
||||
// Run schema migrations
|
||||
if _, err := sqlDB.Exec(string(schemaSQLBytes)); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("running schema migrations: %w", err)
|
||||
}
|
||||
|
||||
return &DB{sql: sqlDB}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection.
|
||||
func (db *DB) Close() error {
|
||||
return db.sql.Close()
|
||||
}
|
||||
|
||||
// SQL returns the underlying sql.DB for advanced use cases.
|
||||
func (db *DB) SQL() *sql.DB {
|
||||
return db.sql
|
||||
}
|
||||
@@ -1,23 +1,127 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/salvacybersec/keyhunter/pkg/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDBOpen verifies SQLite database opens and creates schema.
|
||||
// Stub: will be implemented when db.go exists (Plan 03).
|
||||
func TestDBOpen(t *testing.T) {
|
||||
t.Skip("stub — implement after db.go exists")
|
||||
db, err := storage.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Verify schema tables exist
|
||||
rows, err := db.SQL().Query("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
require.NoError(t, rows.Scan(&name))
|
||||
tables = append(tables, name)
|
||||
}
|
||||
assert.Contains(t, tables, "findings")
|
||||
assert.Contains(t, tables, "scans")
|
||||
assert.Contains(t, tables, "settings")
|
||||
}
|
||||
|
||||
// TestEncryptDecryptRoundtrip verifies AES-256-GCM encrypt/decrypt roundtrip.
|
||||
// Stub: will be implemented when encrypt.go exists (Plan 03).
|
||||
func TestEncryptDecryptRoundtrip(t *testing.T) {
|
||||
t.Skip("stub — implement after encrypt.go exists")
|
||||
key := make([]byte, 32) // all-zero key for test
|
||||
for i := range key {
|
||||
key[i] = byte(i)
|
||||
}
|
||||
plaintext := []byte("sk-proj-supersecretapikey1234")
|
||||
|
||||
ciphertext, err := storage.Encrypt(plaintext, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, len(ciphertext), len(plaintext), "ciphertext should be longer than plaintext")
|
||||
|
||||
recovered, err := storage.Decrypt(ciphertext, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext, recovered)
|
||||
}
|
||||
|
||||
// TestArgon2KeyDerivation verifies Argon2id produces 32-byte key deterministically.
|
||||
// Stub: will be implemented when crypto.go exists (Plan 03).
|
||||
func TestArgon2KeyDerivation(t *testing.T) {
|
||||
t.Skip("stub — implement after crypto.go exists")
|
||||
func TestEncryptNonDeterministic(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
plain := []byte("test-key")
|
||||
ct1, err1 := storage.Encrypt(plain, key)
|
||||
ct2, err2 := storage.Encrypt(plain, key)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
assert.NotEqual(t, ct1, ct2, "same plaintext encrypted twice should produce different ciphertext")
|
||||
}
|
||||
|
||||
func TestDecryptWrongKey(t *testing.T) {
|
||||
key1 := make([]byte, 32)
|
||||
key2 := make([]byte, 32)
|
||||
key2[0] = 0xFF
|
||||
|
||||
ct, err := storage.Encrypt([]byte("secret"), key1)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.Decrypt(ct, key2)
|
||||
assert.Error(t, err, "decryption with wrong key should fail")
|
||||
}
|
||||
|
||||
func TestArgon2KeyDerivation(t *testing.T) {
|
||||
passphrase := []byte("my-secure-passphrase")
|
||||
salt := []byte("1234567890abcdef") // 16 bytes
|
||||
|
||||
key1 := storage.DeriveKey(passphrase, salt)
|
||||
key2 := storage.DeriveKey(passphrase, salt)
|
||||
|
||||
assert.Equal(t, 32, len(key1), "derived key must be 32 bytes")
|
||||
assert.Equal(t, key1, key2, "same passphrase+salt must produce same key")
|
||||
}
|
||||
|
||||
func TestNewSalt(t *testing.T) {
|
||||
salt1, err1 := storage.NewSalt()
|
||||
salt2, err2 := storage.NewSalt()
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
assert.Equal(t, 16, len(salt1))
|
||||
assert.NotEqual(t, salt1, salt2, "two salts should differ")
|
||||
}
|
||||
|
||||
func TestSaveFindingEncrypted(t *testing.T) {
|
||||
db, err := storage.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Derive a test key
|
||||
key := storage.DeriveKey([]byte("testpassphrase"), []byte("testsalt1234xxxx"))
|
||||
|
||||
plainKey := "sk-proj-test1234567890abcdefghijklmnopqr"
|
||||
f := storage.Finding{
|
||||
ProviderName: "openai",
|
||||
KeyValue: plainKey,
|
||||
Confidence: "high",
|
||||
SourcePath: "/test/file.env",
|
||||
SourceType: "file",
|
||||
LineNumber: 42,
|
||||
}
|
||||
|
||||
id, err := db.SaveFinding(f, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, id, int64(0))
|
||||
|
||||
findings, err := db.ListFindings(key)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, findings, 1)
|
||||
assert.Equal(t, plainKey, findings[0].KeyValue)
|
||||
assert.Equal(t, "openai", findings[0].ProviderName)
|
||||
// Verify masking
|
||||
assert.Equal(t, "sk-proj-...opqr", findings[0].KeyMasked)
|
||||
|
||||
// Verify encryption contract: raw BLOB bytes in the database must NOT contain the plaintext key.
|
||||
// This confirms Encrypt() was called before INSERT, not that the key was stored verbatim.
|
||||
var rawBlob []byte
|
||||
require.NoError(t, db.SQL().QueryRow("SELECT key_value FROM findings WHERE id = ?", id).Scan(&rawBlob))
|
||||
assert.False(t, bytes.Contains(rawBlob, []byte(plainKey)),
|
||||
"raw database BLOB must not contain plaintext key — encryption was not applied")
|
||||
}
|
||||
|
||||
52
pkg/storage/encrypt.go
Normal file
52
pkg/storage/encrypt.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrCiphertextTooShort is returned when ciphertext is shorter than the GCM nonce size.
|
||||
var ErrCiphertextTooShort = errors.New("ciphertext too short")
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM with a random nonce.
|
||||
// The nonce is prepended to the returned ciphertext.
|
||||
// key must be exactly 32 bytes (AES-256).
|
||||
func Encrypt(plaintext []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Seal appends encrypted data to nonce, so nonce is prepended
|
||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext produced by Encrypt.
|
||||
// Expects the nonce to be prepended to the ciphertext.
|
||||
func Decrypt(ciphertext []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return nil, ErrCiphertextTooShort
|
||||
}
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
return gcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
103
pkg/storage/findings.go
Normal file
103
pkg/storage/findings.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Finding represents a detected API key with metadata.
|
||||
// KeyValue is always plaintext in this struct — encryption happens at the storage boundary.
|
||||
type Finding struct {
|
||||
ID int64
|
||||
ScanID int64
|
||||
ProviderName string
|
||||
KeyValue string // plaintext — encrypted before storage, decrypted after retrieval
|
||||
KeyMasked string // first8...last4, stored plaintext
|
||||
Confidence string
|
||||
SourcePath string
|
||||
SourceType string
|
||||
LineNumber int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// MaskKey returns the masked form of a key: first 8 chars + "..." + last 4 chars.
|
||||
// If the key is too short (< 12 chars), returns the full key masked with asterisks.
|
||||
func MaskKey(key string) string {
|
||||
if len(key) < 12 {
|
||||
return "****"
|
||||
}
|
||||
return key[:8] + "..." + key[len(key)-4:]
|
||||
}
|
||||
|
||||
// SaveFinding encrypts the finding's KeyValue and persists the finding to the database.
|
||||
// encKey must be a 32-byte AES-256 key (from DeriveKey).
|
||||
func (db *DB) SaveFinding(f Finding, encKey []byte) (int64, error) {
|
||||
encrypted, err := Encrypt([]byte(f.KeyValue), encKey)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("encrypting key value: %w", err)
|
||||
}
|
||||
|
||||
masked := f.KeyMasked
|
||||
if masked == "" {
|
||||
masked = MaskKey(f.KeyValue)
|
||||
}
|
||||
|
||||
// Use NULL for scan_id when not set (zero value) to satisfy FK constraint
|
||||
var scanID interface{}
|
||||
if f.ScanID != 0 {
|
||||
scanID = f.ScanID
|
||||
} else {
|
||||
scanID = sql.NullInt64{}
|
||||
}
|
||||
|
||||
res, err := db.sql.Exec(
|
||||
`INSERT INTO findings (scan_id, provider_name, key_value, key_masked, confidence, source_path, source_type, line_number)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
scanID, f.ProviderName, encrypted, masked, f.Confidence, f.SourcePath, f.SourceType, f.LineNumber,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("inserting finding: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// ListFindings retrieves all findings, decrypting key values using encKey.
|
||||
// encKey must be the same 32-byte key used during SaveFinding.
|
||||
func (db *DB) ListFindings(encKey []byte) ([]Finding, error) {
|
||||
rows, err := db.sql.Query(
|
||||
`SELECT id, scan_id, provider_name, key_value, key_masked, confidence,
|
||||
source_path, source_type, line_number, created_at
|
||||
FROM findings ORDER BY created_at DESC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying findings: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var findings []Finding
|
||||
for rows.Next() {
|
||||
var f Finding
|
||||
var encrypted []byte
|
||||
var createdAt string
|
||||
var scanID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&f.ID, &scanID, &f.ProviderName, &encrypted, &f.KeyMasked,
|
||||
&f.Confidence, &f.SourcePath, &f.SourceType, &f.LineNumber, &createdAt,
|
||||
)
|
||||
if scanID.Valid {
|
||||
f.ScanID = scanID.Int64
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scanning finding row: %w", err)
|
||||
}
|
||||
plain, err := Decrypt(encrypted, encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting finding %d: %w", f.ID, err)
|
||||
}
|
||||
f.KeyValue = string(plain)
|
||||
f.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
findings = append(findings, f)
|
||||
}
|
||||
return findings, rows.Err()
|
||||
}
|
||||
35
pkg/storage/schema.sql
Normal file
35
pkg/storage/schema.sql
Normal file
@@ -0,0 +1,35 @@
|
||||
-- KeyHunter database schema
|
||||
-- Version: 1
|
||||
|
||||
CREATE TABLE IF NOT EXISTS scans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
started_at DATETIME NOT NULL,
|
||||
finished_at DATETIME,
|
||||
source_path TEXT,
|
||||
finding_count INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS findings (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
scan_id INTEGER REFERENCES scans(id),
|
||||
provider_name TEXT NOT NULL,
|
||||
key_value BLOB NOT NULL,
|
||||
key_masked TEXT NOT NULL,
|
||||
confidence TEXT NOT NULL,
|
||||
source_path TEXT,
|
||||
source_type TEXT,
|
||||
line_number INTEGER,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Indexes for common queries
|
||||
CREATE INDEX IF NOT EXISTS idx_findings_scan_id ON findings(scan_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_findings_provider ON findings(provider_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_findings_created ON findings(created_at DESC);
|
||||
Reference in New Issue
Block a user