package sources import ( "context" "errors" "io" "net/http" "net/http/httptest" "strings" "sync/atomic" "testing" "time" ) func newTestClient() *Client { c := NewClient() c.MaxRetries = 2 return c } func TestClient_Do_OKPassThrough(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) _, _ = w.Write([]byte("hello")) })) defer srv.Close() req, _ := http.NewRequest("GET", srv.URL, nil) resp, err := newTestClient().Do(context.Background(), req) if err != nil { t.Fatalf("expected nil err, got %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) if string(body) != "hello" { t.Fatalf("unexpected body: %q", body) } } func TestClient_Do_RetryOn429(t *testing.T) { var calls int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := atomic.AddInt32(&calls, 1) if n == 1 { w.Header().Set("Retry-After", "1") w.WriteHeader(429) _, _ = w.Write([]byte("slow down")) return } w.WriteHeader(200) _, _ = w.Write([]byte("ok")) })) defer srv.Close() req, _ := http.NewRequest("GET", srv.URL, nil) resp, err := newTestClient().Do(context.Background(), req) if err != nil { t.Fatalf("expected success after retry, got %v", err) } defer resp.Body.Close() if atomic.LoadInt32(&calls) != 2 { t.Fatalf("expected 2 calls, got %d", calls) } } func TestClient_Do_RetryOn403(t *testing.T) { var calls int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := atomic.AddInt32(&calls, 1) if n == 1 { w.Header().Set("Retry-After", "1") w.WriteHeader(403) return } w.WriteHeader(200) })) defer srv.Close() req, _ := http.NewRequest("GET", srv.URL, nil) resp, err := newTestClient().Do(context.Background(), req) if err != nil { t.Fatalf("expected success after 403 retry, got %v", err) } defer resp.Body.Close() if atomic.LoadInt32(&calls) != 2 { t.Fatalf("expected 2 calls, got %d", calls) } } func TestClient_Do_UnauthorizedNoRetry(t *testing.T) { var calls int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&calls, 1) w.WriteHeader(401) _, _ = w.Write([]byte("bad token")) })) defer srv.Close() req, _ := http.NewRequest("GET", srv.URL, nil) _, err := newTestClient().Do(context.Background(), req) if err == nil { t.Fatal("expected error on 401") } if !errors.Is(err, ErrUnauthorized) { t.Fatalf("expected ErrUnauthorized, got %v", err) } if atomic.LoadInt32(&calls) != 1 { t.Fatalf("expected 1 call (no retry), got %d", calls) } if !strings.Contains(err.Error(), "bad token") { t.Fatalf("expected body in error, got %v", err) } } func TestClient_Do_CtxCancelDuringRetrySleep(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Retry-After", "10") w.WriteHeader(429) })) defer srv.Close() ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(100 * time.Millisecond) cancel() }() req, _ := http.NewRequest("GET", srv.URL, nil) start := time.Now() _, err := newTestClient().Do(ctx, req) elapsed := time.Since(start) if err == nil { t.Fatal("expected ctx error") } if !errors.Is(err, context.Canceled) { t.Fatalf("expected ctx.Canceled, got %v", err) } if elapsed > 500*time.Millisecond { t.Fatalf("ctx cancellation too slow: %v", elapsed) } } func TestClient_Do_RetriesExhausted(t *testing.T) { var calls int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&calls, 1) w.Header().Set("Retry-After", "0") w.WriteHeader(500) _, _ = w.Write([]byte("boom")) })) defer srv.Close() c := NewClient() c.MaxRetries = 1 req, _ := http.NewRequest("GET", srv.URL, nil) _, err := c.Do(context.Background(), req) if err == nil { t.Fatal("expected error after exhausted retries") } if atomic.LoadInt32(&calls) != 2 { t.Fatalf("expected 2 calls (1 + 1 retry), got %d", calls) } if !strings.Contains(err.Error(), "500") { t.Fatalf("expected 500 in error, got %v", err) } } func TestClient_Do_DefaultUserAgent(t *testing.T) { var gotUA string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotUA = r.Header.Get("User-Agent") w.WriteHeader(200) })) defer srv.Close() req, _ := http.NewRequest("GET", srv.URL, nil) resp, err := newTestClient().Do(context.Background(), req) if err != nil { t.Fatal(err) } resp.Body.Close() if gotUA != "keyhunter-recon/1.0" { t.Fatalf("expected default UA, got %q", gotUA) } } func TestParseRetryAfter(t *testing.T) { cases := map[string]time.Duration{ "": 1 * time.Second, "0": 1 * time.Second, "2": 2 * time.Second, "60": 60 * time.Second, "abc": 1 * time.Second, } for in, want := range cases { if got := ParseRetryAfter(in); got != want { t.Errorf("ParseRetryAfter(%q) = %v, want %v", in, got, want) } } }