package sources

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/salvacybersec/keyhunter/pkg/providers"
	"github.com/salvacybersec/keyhunter/pkg/recon"
)

// hfTestRegistry builds a minimal registry with two keywords so tests can
// assert an exact Finding count (2 endpoints × 2 keywords × 1 result = 4).
func hfTestRegistry(t *testing.T) *providers.Registry {
	t.Helper()
	return providers.NewRegistryFromProviders([]providers.Provider{
		{Name: "OpenAI", Keywords: []string{"sk-proj"}},
		{Name: "Anthropic", Keywords: []string{"sk-ant"}},
	})
}

// hfTestServer starts an httptest server exposing /api/spaces and /api/models.
//
// Every request increments the matching counter (spacesHits or modelsHits) and
// is answered with a one-element JSON array whose "id" embeds the request's
// "search" query parameter. When authSeen is non-nil, the Authorization header
// of each /api/spaces request is stored through it. The caller is responsible
// for closing the returned server.
func hfTestServer(t *testing.T, spacesHits, modelsHits *int32, authSeen *string) *httptest.Server {
	t.Helper()

	// writeHit sends the single-result payload shared by both endpoints.
	writeHit := func(w http.ResponseWriter, id string) {
		payload := []map[string]string{{"id": id}}
		w.Header().Set("Content-Type", "application/json")
		// Encode errors are deliberately ignored: this is a best-effort test
		// fixture and the handler has nowhere to report them.
		_ = json.NewEncoder(w).Encode(payload)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/api/spaces", func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(spacesHits, 1)
		if authSeen != nil {
			*authSeen = r.Header.Get("Authorization")
		}
		q := r.URL.Query().Get("search")
		writeHit(w, fmt.Sprintf("acme/space-%s", q))
	})
	mux.HandleFunc("/api/models", func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(modelsHits, 1)
		q := r.URL.Query().Get("search")
		writeHit(w, fmt.Sprintf("acme/model-%s", q))
	})
	return httptest.NewServer(mux)
}

// TestHuggingFaceEnabledAlwaysTrue checks that Enabled reports true both with
// and without a token configured.
func TestHuggingFaceEnabledAlwaysTrue(t *testing.T) {
	if !(&HuggingFaceSource{}).Enabled(recon.Config{}) {
		t.Fatal("HuggingFace should be enabled even without token")
	}
	if !(&HuggingFaceSource{Token: "hf_xxx"}).Enabled(recon.Config{}) {
		t.Fatal("HuggingFace should be enabled with token")
	}
}

// TestHuggingFaceSweepHitsBothEndpoints runs one Sweep against the stub server
// and verifies the finding count, the per-endpoint request counts, the
// SourceType, and that both space and model URLs were emitted.
func TestHuggingFaceSweepHitsBothEndpoints(t *testing.T) {
	var spacesHits, modelsHits int32
	ts := hfTestServer(t, &spacesHits, &modelsHits, nil)
	defer ts.Close()

	reg := hfTestRegistry(t)
	src := NewHuggingFaceSource(HuggingFaceConfig{
		Token:    "hf_test",
		BaseURL:  ts.URL,
		Registry: reg,
		Limiters: nil, // bypass rate limiter for tests
	})

	out := make(chan recon.Finding, 16)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := src.Sweep(ctx, "", out); err != nil {
		t.Fatalf("Sweep: %v", err)
	}
	close(out)

	var findings []recon.Finding
	for f := range out {
		findings = append(findings, f)
	}
	if len(findings) != 4 {
		t.Fatalf("expected 4 findings, got %d", len(findings))
	}

	// Report the atomically loaded value rather than re-reading the counter
	// non-atomically in the failure message.
	if got := atomic.LoadInt32(&spacesHits); got != 2 {
		t.Errorf("expected 2 /api/spaces hits, got %d", got)
	}
	if got := atomic.LoadInt32(&modelsHits); got != 2 {
		t.Errorf("expected 2 /api/models hits, got %d", got)
	}

	var sawSpace, sawModel bool
	for _, f := range findings {
		if f.SourceType != "recon:huggingface" {
			t.Errorf("wrong SourceType: %q", f.SourceType)
		}
		// The spaces prefix must be tested first: it is a strict extension of
		// the generic huggingface.co prefix used for models.
		switch {
		case strings.HasPrefix(f.Source, "https://huggingface.co/spaces/"):
			sawSpace = true
		case strings.HasPrefix(f.Source, "https://huggingface.co/"):
			sawModel = true
		default:
			t.Errorf("unexpected Source URL: %q", f.Source)
		}
	}
	if !sawSpace || !sawModel {
		t.Errorf("expected both space and model URLs; space=%v model=%v", sawSpace, sawModel)
	}
}

// TestHuggingFaceAuthorizationHeader verifies that a configured token is sent
// as a Bearer Authorization header and that no Authorization header is sent
// when the token is empty.
func TestHuggingFaceAuthorizationHeader(t *testing.T) {
	var authSeen string
	var s, m int32
	ts := hfTestServer(t, &s, &m, &authSeen)
	defer ts.Close()

	reg := hfTestRegistry(t)
	src := NewHuggingFaceSource(HuggingFaceConfig{
		Token:    "hf_secret",
		BaseURL:  ts.URL,
		Registry: reg,
		Limiters: nil,
	})

	out := make(chan recon.Finding, 16)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := src.Sweep(ctx, "", out); err != nil {
		t.Fatalf("Sweep: %v", err)
	}
	close(out)
	for range out {
	}
	if authSeen != "Bearer hf_secret" {
		t.Errorf("expected 'Bearer hf_secret', got %q", authSeen)
	}

	// Without token
	authSeen = ""
	var s2, m2 int32
	ts2 := hfTestServer(t, &s2, &m2, &authSeen)
	defer ts2.Close()

	src2 := NewHuggingFaceSource(HuggingFaceConfig{
		BaseURL:  ts2.URL,
		Registry: reg,
		Limiters: nil,
	})
	out2 := make(chan recon.Finding, 16)
	if err := src2.Sweep(ctx, "", out2); err != nil {
		t.Fatalf("Sweep unauth: %v", err)
	}
	close(out2)
	for range out2 {
	}

	// authSeen was reset to "" above, so an empty value only proves the header
	// was absent if the /api/spaces handler actually ran.
	if atomic.LoadInt32(&s2) == 0 {
		t.Fatal("expected at least one /api/spaces request in unauthenticated sweep")
	}
	if authSeen != "" {
		t.Errorf("expected no Authorization header when token empty, got %q", authSeen)
	}
}

// TestHuggingFaceContextCancellation verifies that Sweep returns an error when
// its context expires before the server answers.
func TestHuggingFaceContextCancellation(t *testing.T) {
	// The handler stalls for 2s (or until the request context is done), which
	// is far longer than the 100ms deadline given to Sweep below.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-r.Context().Done():
			return
		case <-time.After(2 * time.Second):
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("[]"))
		}
	}))
	defer ts.Close()

	reg := hfTestRegistry(t)
	src := NewHuggingFaceSource(HuggingFaceConfig{
		BaseURL:  ts.URL,
		Registry: reg,
		Limiters: nil,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	out := make(chan recon.Finding, 16)
	if err := src.Sweep(ctx, "", out); err == nil {
		t.Fatal("expected error on cancelled context")
	}
}

// TestHuggingFaceRateLimitTokenMode verifies that RateLimit differs between
// token and token-less sources, and that the value with a token is the larger
// of the two.
func TestHuggingFaceRateLimitTokenMode(t *testing.T) {
	withTok := &HuggingFaceSource{Token: "hf_xxx"}
	noTok := &HuggingFaceSource{}

	if withTok.RateLimit() == noTok.RateLimit() {
		t.Fatal("rate limit should differ based on token presence")
	}
	if withTok.RateLimit() < noTok.RateLimit() {
		t.Fatalf("authenticated rate (%v) should be faster (larger) than unauth (%v)",
			withTok.RateLimit(), noTok.RateLimit())
	}
}