Files
keyhunter/pkg/recon/sources/huggingface_test.go
salvacybersec 45f8782464 test(10-06): add failing tests for HuggingFaceSource
- httptest server routes /api/spaces and /api/models
- assertions: enabled, both endpoints hit, URL prefixes, auth header, ctx cancel, rate-limit token mode
2026-04-06 01:15:43 +03:00

205 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package sources
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/salvacybersec/keyhunter/pkg/providers"
"github.com/salvacybersec/keyhunter/pkg/recon"
)
// hfTestRegistry returns a registry holding two providers (OpenAI and
// Anthropic) with one keyword each. Keeping the keyword count fixed at two
// lets the sweep tests expect an exact number of findings:
// 2 endpoints x 2 keywords x 1 result per query = 4.
func hfTestRegistry(t *testing.T) *providers.Registry {
	t.Helper()

	fixtures := []providers.Provider{
		{Name: "OpenAI", Keywords: []string{"sk-proj"}},
		{Name: "Anthropic", Keywords: []string{"sk-ant"}},
	}
	return providers.NewRegistryFromProviders(fixtures)
}
// hfTestServer starts an httptest server exposing the two search endpoints
// the tests exercise: /api/spaces and /api/models.
//
// Each request atomically increments the matching hit counter and is answered
// with a one-element JSON array whose "id" embeds the "search" query value
// ("acme/space-<q>" or "acme/model-<q>"). When authSeen is non-nil, the
// Authorization header of every /api/spaces request is stored in it; requests
// to /api/models never touch authSeen.
//
// The caller owns the returned server and is responsible for closing it.
func hfTestServer(t *testing.T, spacesHits, modelsHits *int32, authSeen *string) *httptest.Server {
	t.Helper()

	// reply writes the canned single-result payload for the given id.
	reply := func(w http.ResponseWriter, id string) {
		w.Header().Set("Content-Type", "application/json")
		// The Encode error is intentionally discarded.
		_ = json.NewEncoder(w).Encode([]map[string]string{{"id": id}})
	}

	router := http.NewServeMux()
	router.HandleFunc("/api/spaces", func(w http.ResponseWriter, req *http.Request) {
		atomic.AddInt32(spacesHits, 1)
		if authSeen != nil {
			*authSeen = req.Header.Get("Authorization")
		}
		reply(w, fmt.Sprintf("acme/space-%s", req.URL.Query().Get("search")))
	})
	router.HandleFunc("/api/models", func(w http.ResponseWriter, req *http.Request) {
		atomic.AddInt32(modelsHits, 1)
		reply(w, fmt.Sprintf("acme/model-%s", req.URL.Query().Get("search")))
	})

	return httptest.NewServer(router)
}
// TestHuggingFaceEnabledAlwaysTrue checks that Enabled reports true both for
// a zero-value source and for one configured with a token.
func TestHuggingFaceEnabledAlwaysTrue(t *testing.T) {
	cases := []struct {
		src     *HuggingFaceSource
		failure string
	}{
		{&HuggingFaceSource{}, "HuggingFace should be enabled even without token"},
		{&HuggingFaceSource{Token: "hf_xxx"}, "HuggingFace should be enabled with token"},
	}

	for _, tc := range cases {
		if tc.src.Enabled(recon.Config{}) {
			continue
		}
		t.Fatal(tc.failure)
	}
}
// TestHuggingFaceSweepHitsBothEndpoints runs a single Sweep against the fake
// server and verifies that:
//   - exactly 4 findings are emitted (2 endpoints x 2 keywords x 1 result),
//   - /api/spaces and /api/models are each requested twice,
//   - every finding has SourceType "recon:huggingface",
//   - at least one finding's Source is a spaces URL and at least one is a
//     non-spaces huggingface.co URL.
func TestHuggingFaceSweepHitsBothEndpoints(t *testing.T) {
	var spacesHits, modelsHits int32
	ts := hfTestServer(t, &spacesHits, &modelsHits, nil)
	defer ts.Close()

	reg := hfTestRegistry(t)
	src := NewHuggingFaceSource(HuggingFaceConfig{
		Token:    "hf_test",
		BaseURL:  ts.URL,
		Registry: reg,
		Limiters: nil, // bypass rate limiter for tests
	})

	// Buffered so Sweep can emit without a concurrent reader; the capacity is
	// comfortably above the 4 findings expected.
	out := make(chan recon.Finding, 16)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := src.Sweep(ctx, "", out); err != nil {
		t.Fatalf("Sweep: %v", err)
	}
	close(out)

	var findings []recon.Finding
	for f := range out {
		findings = append(findings, f)
	}
	if len(findings) != 4 {
		t.Fatalf("expected 4 findings, got %d", len(findings))
	}

	// Load each counter once and report the loaded value, so the counters are
	// only ever accessed through sync/atomic.
	if got := atomic.LoadInt32(&spacesHits); got != 2 {
		t.Errorf("expected 2 /api/spaces hits, got %d", got)
	}
	if got := atomic.LoadInt32(&modelsHits); got != 2 {
		t.Errorf("expected 2 /api/models hits, got %d", got)
	}

	var sawSpace, sawModel bool
	for _, f := range findings {
		if f.SourceType != "recon:huggingface" {
			t.Errorf("wrong SourceType: %q", f.SourceType)
		}
		// The spaces prefix must be tested first: it is a strict extension of
		// the generic huggingface.co prefix used for models.
		switch {
		case strings.HasPrefix(f.Source, "https://huggingface.co/spaces/"):
			sawSpace = true
		case strings.HasPrefix(f.Source, "https://huggingface.co/"):
			sawModel = true
		default:
			t.Errorf("unexpected Source URL: %q", f.Source)
		}
	}
	if !sawSpace || !sawModel {
		t.Errorf("expected both space and model URLs; space=%v model=%v", sawSpace, sawModel)
	}
}
// TestHuggingFaceAuthorizationHeader verifies how the configured token is
// sent to the API: "Bearer <token>" in the Authorization header when a token
// is set, and no Authorization header when the token is empty.
//
// hfTestServer records the header only on /api/spaces requests. The no-token
// scenario therefore first asserts that this endpoint was actually hit;
// without that check the "empty header" expectation would pass vacuously for
// a sweep that never reached the server.
func TestHuggingFaceAuthorizationHeader(t *testing.T) {
	var authSeen string
	var spacesHits, modelsHits int32
	ts := hfTestServer(t, &spacesHits, &modelsHits, &authSeen)
	defer ts.Close()

	reg := hfTestRegistry(t)
	src := NewHuggingFaceSource(HuggingFaceConfig{
		Token:    "hf_secret",
		BaseURL:  ts.URL,
		Registry: reg,
		Limiters: nil,
	})

	out := make(chan recon.Finding, 16)
	// Both sweeps below share this single 5-second budget.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := src.Sweep(ctx, "", out); err != nil {
		t.Fatalf("Sweep: %v", err)
	}
	close(out)
	for range out {
		// Drain; findings are irrelevant to this test.
	}
	if authSeen != "Bearer hf_secret" {
		t.Errorf("expected 'Bearer hf_secret', got %q", authSeen)
	}

	// Without token
	authSeen = ""
	var spacesHits2, modelsHits2 int32
	ts2 := hfTestServer(t, &spacesHits2, &modelsHits2, &authSeen)
	defer ts2.Close()

	src2 := NewHuggingFaceSource(HuggingFaceConfig{
		BaseURL:  ts2.URL,
		Registry: reg,
		Limiters: nil,
	})

	out2 := make(chan recon.Finding, 16)
	if err := src2.Sweep(ctx, "", out2); err != nil {
		t.Fatalf("Sweep unauth: %v", err)
	}
	close(out2)
	for range out2 {
		// Drain; findings are irrelevant to this test.
	}

	// Guard against a vacuous pass: authSeen is only written by the
	// /api/spaces handler, so that handler must have run at least once.
	if atomic.LoadInt32(&spacesHits2) == 0 {
		t.Fatal("expected /api/spaces to be hit during unauthenticated sweep")
	}
	if authSeen != "" {
		t.Errorf("expected no Authorization header when token empty, got %q", authSeen)
	}
}
func TestHuggingFaceContextCancellation(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-r.Context().Done():
return
case <-time.After(2 * time.Second):
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("[]"))
}
}))
defer ts.Close()
reg := hfTestRegistry(t)
src := NewHuggingFaceSource(HuggingFaceConfig{
BaseURL: ts.URL,
Registry: reg,
Limiters: nil,
})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
out := make(chan recon.Finding, 16)
if err := src.Sweep(ctx, "", out); err == nil {
t.Fatal("expected error on cancelled context")
}
}
// TestHuggingFaceRateLimitTokenMode asserts that a source with a token
// reports a strictly larger RateLimit value than a source without one.
func TestHuggingFaceRateLimitTokenMode(t *testing.T) {
	authed := (&HuggingFaceSource{Token: "hf_xxx"}).RateLimit()
	anon := (&HuggingFaceSource{}).RateLimit()

	switch {
	case authed == anon:
		t.Fatal("rate limit should differ based on token presence")
	case authed < anon:
		t.Fatalf("authenticated rate (%v) should be faster (larger) than unauth (%v)",
			authed, anon)
	}
}