diff --git a/pkg/storage/queries_test.go b/pkg/storage/queries_test.go new file mode 100644 index 0000000..d8ae8ad --- /dev/null +++ b/pkg/storage/queries_test.go @@ -0,0 +1,149 @@ +package storage_test + +import ( + "database/sql" + "errors" + "testing" + + "github.com/salvacybersec/keyhunter/pkg/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// seedQueryFindings inserts three findings across two providers with mixed +// verified flags. Returns the DB, encryption key, and the saved IDs in +// insertion order: [openai-live, openai-dead, anthropic-live]. +func seedQueryFindings(t *testing.T) (*storage.DB, []byte, []int64) { + t.Helper() + db, err := storage.Open(":memory:") + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + encKey := makeTestKey() + + seeds := []storage.Finding{ + { + ProviderName: "openai", + KeyValue: "sk-proj-openai-live-1234567890", + Confidence: "high", + SourcePath: "/tmp/a.env", + SourceType: "file", + LineNumber: 1, + Verified: true, + VerifyStatus: "live", + VerifyHTTPCode: 200, + VerifyMetadata: map[string]string{"org": "acme"}, + }, + { + ProviderName: "openai", + KeyValue: "sk-proj-openai-dead-abcdefghij", + Confidence: "medium", + SourcePath: "/tmp/b.env", + SourceType: "file", + LineNumber: 2, + Verified: false, + }, + { + ProviderName: "anthropic", + KeyValue: "sk-ant-api03-livekey-abcdefghij", + Confidence: "high", + SourcePath: "/tmp/c.yaml", + SourceType: "file", + LineNumber: 3, + Verified: true, + VerifyStatus: "live", + VerifyHTTPCode: 200, + }, + } + + ids := make([]int64, 0, len(seeds)) + for _, f := range seeds { + id, err := db.SaveFinding(f, encKey) + require.NoError(t, err) + ids = append(ids, id) + } + return db, encKey, ids +} + +func TestListFindingsFiltered_ByProvider(t *testing.T) { + db, encKey, _ := seedQueryFindings(t) + + out, err := db.ListFindingsFiltered(encKey, storage.Filters{Provider: "openai"}) + require.NoError(t, err) + require.Len(t, out, 2) + for _, f := range out { + assert.Equal(t, "openai", f.ProviderName) + } +} + +func TestListFindingsFiltered_Verified(t *testing.T) { + db, encKey, _ := seedQueryFindings(t) + + verifiedTrue := true + out, err := db.ListFindingsFiltered(encKey, storage.Filters{Verified: &verifiedTrue}) + require.NoError(t, err) + require.Len(t, out, 2) + for _, f := range out { + assert.True(t, f.Verified, "expected only verified findings") + } + + verifiedFalse := false + out, err = db.ListFindingsFiltered(encKey, storage.Filters{Verified: &verifiedFalse}) + require.NoError(t, err) + require.Len(t, out, 1) + assert.False(t, out[0].Verified) +} + +func TestListFindingsFiltered_Pagination(t *testing.T) { + db, encKey, _ := seedQueryFindings(t) + + // Unpaginated baseline — should return all 3. + all, err := db.ListFindingsFiltered(encKey, storage.Filters{}) + require.NoError(t, err) + require.Len(t, all, 3) + + // Limit=1 Offset=1 returns the second row from the baseline order. + page, err := db.ListFindingsFiltered(encKey, storage.Filters{Limit: 1, Offset: 1}) + require.NoError(t, err) + require.Len(t, page, 1) + assert.Equal(t, all[1].ID, page[0].ID) +} + +func TestGetFinding_Hit(t *testing.T) { + db, encKey, ids := seedQueryFindings(t) + + f, err := db.GetFinding(ids[0], encKey) + require.NoError(t, err) + require.NotNil(t, f) + assert.Equal(t, "openai", f.ProviderName) + assert.Equal(t, "sk-proj-openai-live-1234567890", f.KeyValue) + assert.True(t, f.Verified) +} + +func TestGetFinding_Miss(t *testing.T) { + db, encKey, _ := seedQueryFindings(t) + + f, err := db.GetFinding(9999, encKey) + assert.Nil(t, f) + assert.True(t, errors.Is(err, sql.ErrNoRows), "expected sql.ErrNoRows, got %v", err) +} + +func TestDeleteFinding_Hit(t *testing.T) { + db, encKey, ids := seedQueryFindings(t) + + n, err := db.DeleteFinding(ids[1]) + require.NoError(t, err) + assert.Equal(t, int64(1), n) + + f, err := db.GetFinding(ids[1], encKey) + assert.Nil(t, f) + assert.True(t, errors.Is(err, sql.ErrNoRows)) +} + +func TestDeleteFinding_Miss(t *testing.T) { + db, _, _ := seedQueryFindings(t) + + n, err := db.DeleteFinding(9999) + require.NoError(t, err) + assert.Equal(t, int64(0), n) +}