- TestVerifyAll_MultipleFindings: 5 findings via 3-worker pool
- TestVerifyAll_MissingProvider: unknown provider yields StatusUnknown
- TestVerifyAll_ContextCancellation: cancellation closes channel early
- Add providers.NewRegistryFromProviders test helper
274 lines
8.4 KiB
Go
274 lines
8.4 KiB
Go
package verify
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"regexp"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/salvacybersec/keyhunter/pkg/engine"
|
|
"github.com/salvacybersec/keyhunter/pkg/providers"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// newTestVerifier builds an HTTPVerifier whose transport trusts the given
|
|
// httptest TLS server's self-signed cert.
|
|
func newTestVerifier(t *testing.T, srv *httptest.Server, timeout time.Duration) *HTTPVerifier {
|
|
t.Helper()
|
|
v := NewHTTPVerifier(timeout)
|
|
v.Client = srv.Client()
|
|
v.Client.Timeout = timeout
|
|
return v
|
|
}
|
|
|
|
func testFinding(key string) engine.Finding {
|
|
return engine.Finding{
|
|
ProviderName: "testprov",
|
|
KeyValue: key,
|
|
KeyMasked: engine.MaskKey(key + "padding1234"),
|
|
}
|
|
}
|
|
|
|
func testProvider(spec providers.VerifySpec) providers.Provider {
|
|
return providers.Provider{Name: "testprov", Verify: spec}
|
|
}
|
|
|
|
func TestVerify_Live_200(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{URL: srv.URL, Method: "GET"})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusLive, res.Status)
|
|
assert.Equal(t, 200, res.HTTPCode)
|
|
}
|
|
|
|
func TestVerify_Dead_401(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{URL: srv.URL})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusDead, res.Status)
|
|
assert.Equal(t, 401, res.HTTPCode)
|
|
}
|
|
|
|
func TestVerify_RateLimited_429_WithRetryAfter(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Retry-After", "30")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{URL: srv.URL})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusRateLimited, res.Status)
|
|
assert.Equal(t, 30*time.Second, res.RetryAfter)
|
|
}
|
|
|
|
func TestVerify_MetadataExtraction(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(`{"organization":{"name":"Acme"},"tier":"plus"}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{
|
|
URL: srv.URL,
|
|
MetadataPaths: map[string]string{"org": "organization.name", "tier": "tier"},
|
|
})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
require.Equal(t, StatusLive, res.Status)
|
|
assert.Equal(t, "Acme", res.Metadata["org"])
|
|
assert.Equal(t, "plus", res.Metadata["tier"])
|
|
}
|
|
|
|
func TestVerify_KeySubstitution_InHeader(t *testing.T) {
|
|
var gotAuth string
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotAuth = r.Header.Get("Authorization")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{
|
|
URL: srv.URL,
|
|
Headers: map[string]string{"Authorization": "Bearer {{KEY}}"},
|
|
})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusLive, res.Status)
|
|
assert.Equal(t, "Bearer sk-test-keyvalue", gotAuth)
|
|
}
|
|
|
|
func TestVerify_KeySubstitution_InBody(t *testing.T) {
|
|
var gotBody string
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
b, _ := io.ReadAll(r.Body)
|
|
gotBody = string(b)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{
|
|
URL: srv.URL,
|
|
Method: "POST",
|
|
Body: `{"api_key":"{{KEY}}"}`,
|
|
})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusLive, res.Status)
|
|
assert.Equal(t, `{"api_key":"sk-test-keyvalue"}`, gotBody)
|
|
}
|
|
|
|
func TestVerify_KeySubstitution_InURL(t *testing.T) {
|
|
var gotKey string
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotKey = r.URL.Query().Get("key")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
p := testProvider(providers.VerifySpec{URL: srv.URL + "/v1/models?key={{KEY}}"})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusLive, res.Status)
|
|
assert.Equal(t, "sk-test-keyvalue", gotKey)
|
|
}
|
|
|
|
func TestVerify_MissingURL_Unknown(t *testing.T) {
|
|
v := NewHTTPVerifier(5 * time.Second)
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), testProvider(providers.VerifySpec{}))
|
|
assert.Equal(t, StatusUnknown, res.Status)
|
|
}
|
|
|
|
func TestVerify_HTTPRejected(t *testing.T) {
|
|
v := NewHTTPVerifier(5 * time.Second)
|
|
p := testProvider(providers.VerifySpec{URL: "http://example.com/verify"})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusError, res.Status)
|
|
assert.True(t, strings.Contains(strings.ToLower(res.Error), "https"), "error should mention HTTPS: %q", res.Error)
|
|
}
|
|
|
|
func TestVerifyAll_MultipleFindings(t *testing.T) {
|
|
var hits int32
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddInt32(&hits, 1)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
prov := testProvider(providers.VerifySpec{URL: srv.URL})
|
|
reg := providers.NewRegistryFromProviders([]providers.Provider{prov})
|
|
|
|
var findings []engine.Finding
|
|
for i := 0; i < 5; i++ {
|
|
findings = append(findings, testFinding("sk-test-keyvalue"))
|
|
}
|
|
|
|
out := v.VerifyAll(context.Background(), findings, reg, 3)
|
|
var got int
|
|
for res := range out {
|
|
assert.Equal(t, StatusLive, res.Status)
|
|
got++
|
|
}
|
|
assert.Equal(t, 5, got)
|
|
assert.Equal(t, int32(5), atomic.LoadInt32(&hits))
|
|
}
|
|
|
|
func TestVerifyAll_MissingProvider(t *testing.T) {
|
|
v := NewHTTPVerifier(5 * time.Second)
|
|
reg := providers.NewRegistryFromProviders(nil)
|
|
|
|
f := engine.Finding{ProviderName: "nonexistent", KeyValue: "x", KeyMasked: "****"}
|
|
out := v.VerifyAll(context.Background(), []engine.Finding{f}, reg, 2)
|
|
|
|
res, ok := <-out
|
|
require.True(t, ok, "expected a result before close")
|
|
assert.Equal(t, StatusUnknown, res.Status)
|
|
assert.Contains(t, strings.ToLower(res.Error), "not found")
|
|
|
|
_, more := <-out
|
|
assert.False(t, more, "channel should close after single result")
|
|
}
|
|
|
|
func TestVerifyAll_ContextCancellation(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-time.After(100 * time.Millisecond):
|
|
case <-r.Context().Done():
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 5*time.Second)
|
|
prov := testProvider(providers.VerifySpec{URL: srv.URL})
|
|
reg := providers.NewRegistryFromProviders([]providers.Provider{prov})
|
|
|
|
findings := make([]engine.Finding, 100)
|
|
for i := range findings {
|
|
findings[i] = testFinding("sk-test-keyvalue")
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
out := v.VerifyAll(ctx, findings, reg, 4)
|
|
time.AfterFunc(50*time.Millisecond, cancel)
|
|
|
|
done := make(chan int, 1)
|
|
go func() {
|
|
n := 0
|
|
for range out {
|
|
n++
|
|
}
|
|
done <- n
|
|
}()
|
|
|
|
select {
|
|
case n := <-done:
|
|
assert.Less(t, n, 100, "expected partial results due to cancellation, got %d", n)
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("channel did not close within 3s after cancellation")
|
|
}
|
|
}
|
|
|
|
func TestVerify_Timeout(t *testing.T) {
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(300 * time.Millisecond)
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer srv.Close()
|
|
|
|
v := newTestVerifier(t, srv, 50*time.Millisecond)
|
|
p := testProvider(providers.VerifySpec{URL: srv.URL})
|
|
res := v.Verify(context.Background(), testFinding("sk-test-keyvalue"), p)
|
|
|
|
assert.Equal(t, StatusError, res.Status)
|
|
assert.True(t, regexp.MustCompile(`(?i)timeout|deadline|canceled`).MatchString(res.Error),
|
|
"expected timeout-like error, got %q", res.Error)
|
|
}
|