package sources

import (
	"context"
	"encoding/base64"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/salvacybersec/keyhunter/pkg/providers"
	"github.com/salvacybersec/keyhunter/pkg/recon"
)

// kaggleTestRegistry returns a provider registry holding a single "openai"
// provider whose only keyword is "sk-proj-".
func kaggleTestRegistry() *providers.Registry {
	return providers.NewRegistryFromProviders([]providers.Provider{
		{Name: "openai", Keywords: []string{"sk-proj-"}},
	})
}

// newKaggleSource builds a KaggleSource from the given credentials, backed by
// the test registry and a fresh limiter registry. BaseURL is set to baseURL
// (the tests pass an httptest server URL) and WebBaseURL is pinned to
// https://www.kaggle.com so Source values asserted on by the tests are stable.
func newKaggleSource(t *testing.T, user, key, baseURL string) *KaggleSource {
	t.Helper()
	s := NewKaggleSource(user, key, kaggleTestRegistry(), recon.NewLimiterRegistry())
	s.BaseURL = baseURL
	s.WebBaseURL = "https://www.kaggle.com"
	return s
}

// TestKaggle_Enabled checks that Enabled reports true only when both the user
// and the key are non-empty.
func TestKaggle_Enabled(t *testing.T) {
	reg := kaggleTestRegistry()
	lim := recon.NewLimiterRegistry()

	cases := []struct {
		name      string
		user, key string
		want      bool
	}{
		{name: "no credentials", user: "", key: "", want: false},
		{name: "user only", user: "user", key: "", want: false},
		{name: "key only", user: "", key: "key", want: false},
		{name: "user and key", user: "user", key: "key", want: true},
	}

	for _, c := range cases {
		// Subtests run synchronously here (no t.Parallel), so capturing c is
		// safe regardless of the Go version's loop-variable semantics.
		t.Run(c.name, func(t *testing.T) {
			s := NewKaggleSource(c.user, c.key, reg, lim)
			if got := s.Enabled(recon.Config{}); got != c.want {
				t.Errorf("Enabled(user=%q,key=%q) = %v, want %v", c.user, c.key, got, c.want)
			}
		})
	}
}

// TestKaggle_Sweep_BasicAuthAndFindings verifies that Sweep sends HTTP Basic
// credentials, queries with the provider keyword and pageSize=50, and emits
// one finding per returned notebook ref with the expected Source/SourceType.
func TestKaggle_Sweep_BasicAuthAndFindings(t *testing.T) {
	// Captured by the handler; read only after Sweep has returned.
	var gotAuth string
	var gotQuery string

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotAuth = r.Header.Get("Authorization")
		gotQuery = r.URL.Query().Get("search")
		if r.URL.Query().Get("pageSize") != "50" {
			t.Errorf("expected pageSize=50, got %q", r.URL.Query().Get("pageSize"))
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`[{"ref":"alice/notebook-one","title":"one"},{"ref":"bob/notebook-two","title":"two"}]`))
	}))
	defer srv.Close()

	s := newKaggleSource(t, "testuser", "testkey", srv.URL)

	out := make(chan recon.Finding, 8)
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := s.Sweep(ctx, "", out); err != nil {
		t.Fatalf("Sweep returned error: %v", err)
	}
	// Close so the range over out below terminates once it is drained.
	close(out)

	if !strings.HasPrefix(gotAuth, "Basic ") {
		t.Fatalf("expected Basic auth header, got %q", gotAuth)
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(gotAuth, "Basic "))
	if err != nil {
		t.Fatalf("failed to decode Basic auth: %v", err)
	}
	if string(decoded) != "testuser:testkey" {
		t.Fatalf("expected credentials testuser:testkey, got %q", string(decoded))
	}

	if gotQuery != "sk-proj-" {
		t.Errorf("expected search=sk-proj-, got %q", gotQuery)
	}

	var findings []recon.Finding
	for f := range out {
		findings = append(findings, f)
	}
	if len(findings) != 2 {
		t.Fatalf("expected 2 findings, got %d", len(findings))
	}

	// Maps each expected Source URL to whether it has been seen.
	wantSources := map[string]bool{
		"https://www.kaggle.com/code/alice/notebook-one": false,
		"https://www.kaggle.com/code/bob/notebook-two":   false,
	}
	for _, f := range findings {
		if f.SourceType != "recon:kaggle" {
			t.Errorf("expected SourceType recon:kaggle, got %q", f.SourceType)
		}
		if _, ok := wantSources[f.Source]; !ok {
			t.Errorf("unexpected Source: %q", f.Source)
		}
		wantSources[f.Source] = true
	}
	for src, seen := range wantSources {
		if !seen {
			t.Errorf("missing expected finding source: %s", src)
		}
	}
}

// TestKaggle_Sweep_MissingCredentials_NoHTTP verifies that Sweep returns nil
// and performs no HTTP request when either the user or the key is empty.
func TestKaggle_Sweep_MissingCredentials_NoHTTP(t *testing.T) {
	var calls int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&calls, 1)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("[]"))
	}))
	defer srv.Close()

	s := newKaggleSource(t, "testuser", "", srv.URL)
	out := make(chan recon.Finding, 1)
	if err := s.Sweep(context.Background(), "", out); err != nil {
		t.Fatalf("expected nil error for missing key, got %v", err)
	}
	close(out)

	s2 := newKaggleSource(t, "", "testkey", srv.URL)
	out2 := make(chan recon.Finding, 1)
	if err := s2.Sweep(context.Background(), "", out2); err != nil {
		t.Fatalf("expected nil error for missing user, got %v", err)
	}
	close(out2)

	if n := atomic.LoadInt32(&calls); n != 0 {
		t.Fatalf("expected 0 HTTP calls when credentials missing, got %d", n)
	}
}

// TestKaggle_Sweep_Unauthorized verifies that a 401 response from the API is
// surfaced as an error matching ErrUnauthorized.
func TestKaggle_Sweep_Unauthorized(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		_, _ = w.Write([]byte("bad creds"))
	}))
	defer srv.Close()

	s := newKaggleSource(t, "testuser", "testkey", srv.URL)
	out := make(chan recon.Finding, 1)

	err := s.Sweep(context.Background(), "", out)
	if err == nil {
		t.Fatal("expected error on 401")
	}
	if !errors.Is(err, ErrUnauthorized) {
		t.Fatalf("expected ErrUnauthorized, got %v", err)
	}
}

// TestKaggle_Sweep_CtxCancellation verifies that Sweep called with an
// already-cancelled context returns an error matching context.Canceled.
func TestKaggle_Sweep_CtxCancellation(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate a slow endpoint, but give up as soon as the request is
		// cancelled so the deferred srv.Close() (which waits for outstanding
		// requests) cannot stall the test for the full delay.
		select {
		case <-r.Context().Done():
			return
		case <-time.After(2 * time.Second):
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("[]"))
	}))
	defer srv.Close()

	s := newKaggleSource(t, "testuser", "testkey", srv.URL)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel before Sweep starts

	out := make(chan recon.Finding, 1)
	err := s.Sweep(ctx, "", out)
	if err == nil {
		t.Fatal("expected error from cancelled context")
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("expected context.Canceled, got %v", err)
	}
}

// TestKaggle_ReconSourceInterface asserts at compile time that *KaggleSource
// satisfies recon.ReconSource and checks the static metadata it reports.
func TestKaggle_ReconSourceInterface(t *testing.T) {
	var _ recon.ReconSource = (*KaggleSource)(nil)

	s := NewKaggleSource("u", "k", nil, nil)
	if s.Name() != "kaggle" {
		t.Errorf("Name = %q, want kaggle", s.Name())
	}
	if s.Burst() != 1 {
		t.Errorf("Burst = %d, want 1", s.Burst())
	}
	if s.RespectsRobots() {
		t.Error("RespectsRobots should be false")
	}
	if s.RateLimit() <= 0 {
		t.Error("RateLimit should be > 0")
	}
}