diff --git a/pkg/providers/registry.go b/pkg/providers/registry.go index e7696c9..a71c098 100644 --- a/pkg/providers/registry.go +++ b/pkg/providers/registry.go @@ -39,6 +39,21 @@ func NewRegistry() (*Registry, error) { }, nil } +// NewRegistryFromProviders builds a Registry from an explicit slice of providers +// without touching the embedded YAML files. Intended for tests that need a +// minimal registry with synthetic providers. +func NewRegistryFromProviders(ps []Provider) *Registry { + index := make(map[string]int, len(ps)) + var keywords []string + for i, p := range ps { + index[p.Name] = i + keywords = append(keywords, p.Keywords...) + } + builder := ahocorasick.NewAhoCorasickBuilder(ahocorasick.Opts{DFA: true}) + ac := builder.Build(keywords) + return &Registry{providers: ps, index: index, ac: ac} +} + // List returns all loaded providers. func (r *Registry) List() []Provider { return r.providers diff --git a/pkg/verify/verifier_test.go b/pkg/verify/verifier_test.go index ba81245..9ae379e 100644 --- a/pkg/verify/verifier_test.go +++ b/pkg/verify/verifier_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "regexp" "strings" + "sync/atomic" "testing" "time" @@ -172,6 +173,89 @@ func TestVerify_HTTPRejected(t *testing.T) { assert.True(t, strings.Contains(strings.ToLower(res.Error), "https"), "error should mention HTTPS: %q", res.Error) } +func TestVerifyAll_MultipleFindings(t *testing.T) { + var hits int32 + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + v := newTestVerifier(t, srv, 5*time.Second) + prov := testProvider(providers.VerifySpec{URL: srv.URL}) + reg := providers.NewRegistryFromProviders([]providers.Provider{prov}) + + var findings []engine.Finding + for i := 0; i < 5; i++ { + findings = append(findings, testFinding("sk-test-keyvalue")) + } + + out := v.VerifyAll(context.Background(), findings, reg, 3) + var got int + for res := range out { + assert.Equal(t, StatusLive, res.Status) + got++ + } + assert.Equal(t, 5, got) + assert.Equal(t, int32(5), atomic.LoadInt32(&hits)) +} + +func TestVerifyAll_MissingProvider(t *testing.T) { + v := NewHTTPVerifier(5 * time.Second) + reg := providers.NewRegistryFromProviders(nil) + + f := engine.Finding{ProviderName: "nonexistent", KeyValue: "x", KeyMasked: "****"} + out := v.VerifyAll(context.Background(), []engine.Finding{f}, reg, 2) + + res, ok := <-out + require.True(t, ok, "expected a result before close") + assert.Equal(t, StatusUnknown, res.Status) + assert.Contains(t, strings.ToLower(res.Error), "not found") + + _, more := <-out + assert.False(t, more, "channel should close after single result") +} + +func TestVerifyAll_ContextCancellation(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-time.After(100 * time.Millisecond): + case <-r.Context().Done(): + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + v := newTestVerifier(t, srv, 5*time.Second) + prov := testProvider(providers.VerifySpec{URL: srv.URL}) + reg := providers.NewRegistryFromProviders([]providers.Provider{prov}) + + findings := make([]engine.Finding, 100) + for i := range findings { + findings[i] = testFinding("sk-test-keyvalue") + } + + ctx, cancel := context.WithCancel(context.Background()) + out := v.VerifyAll(ctx, findings, reg, 4) + time.AfterFunc(50*time.Millisecond, cancel) + + done := make(chan int, 1) + go func() { + n := 0 + for range out { + n++ + } + done <- n + }() + + select { + case n := <-done: + assert.Less(t, n, 100, "expected partial results due to cancellation, got %d", n) + case <-time.After(3 * time.Second): + t.Fatal("channel did not close within 3s after cancellation") + } +} + func TestVerify_Timeout(t *testing.T) { srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(300 * time.Millisecond)