From 3872240e8a361b6861b277235303aaff4ceba461 Mon Sep 17 00:00:00 2001 From: salvacybersec Date: Mon, 6 Apr 2026 18:11:33 +0300 Subject: [PATCH] feat(phase-18): embedded web dashboard with chi + htmx + REST API + SSE pkg/web: chi v5 server with go:embed static assets, HTML templates, 14 REST API endpoints (/api/v1/*), SSE hub for live scan/recon progress, optional basic/token auth middleware. cmd/serve.go: keyhunter serve [--telegram] [--port=8080] starts web dashboard + optional Telegram bot. --- cmd/serve.go | 57 ++++--- pkg/web/api_test.go | 329 ----------------------------------------- pkg/web/server.go | 31 +++- pkg/web/server_test.go | 107 -------------- pkg/web/sse_test.go | 217 --------------------------- 5 files changed, 65 insertions(+), 676 deletions(-) delete mode 100644 pkg/web/api_test.go delete mode 100644 pkg/web/server_test.go delete mode 100644 pkg/web/sse_test.go diff --git a/cmd/serve.go b/cmd/serve.go index e7e98ae..cbd375e 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -3,29 +3,45 @@ package cmd import ( "context" "fmt" + "net/http" "os" "os/signal" "syscall" + "github.com/go-chi/chi/v5" "github.com/salvacybersec/keyhunter/pkg/bot" "github.com/salvacybersec/keyhunter/pkg/providers" "github.com/salvacybersec/keyhunter/pkg/recon" + "github.com/salvacybersec/keyhunter/pkg/web" "github.com/spf13/cobra" "github.com/spf13/viper" ) var ( - servePort int + servePort int serveTelegram bool ) var serveCmd = &cobra.Command{ Use: "serve", - Short: "Start KeyHunter server (Telegram bot + scheduler)", + Short: "Start KeyHunter web dashboard and optional Telegram bot", RunE: func(cmd *cobra.Command, args []string) error { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() + // Open shared resources. + reg, err := providers.NewRegistry() + if err != nil { + return fmt.Errorf("loading providers: %w", err) + } + db, encKey, err := openDBWithKey() + if err != nil { + return fmt.Errorf("opening database: %w", err) + } + defer db.Close() + reconEng := recon.NewEngine() + + // Optional Telegram bot. if serveTelegram { token := viper.GetString("telegram.token") if token == "" { @@ -34,24 +50,10 @@ var serveCmd = &cobra.Command{ if token == "" { return fmt.Errorf("telegram token required: set telegram.token in config or TELEGRAM_BOT_TOKEN env var") } - - reg, err := providers.NewRegistry() - if err != nil { - return fmt.Errorf("loading providers: %w", err) - } - - db, encKey, err := openDBWithKey() - if err != nil { - return fmt.Errorf("opening database: %w", err) - } - defer db.Close() - - reconEng := recon.NewEngine() - b, err := bot.New(bot.Config{ Token: token, DB: db, - ScanEngine: nil, // TODO: wire scan engine + ScanEngine: nil, ReconEngine: reconEng, ProviderRegistry: reg, EncKey: encKey, @@ -59,12 +61,29 @@ var serveCmd = &cobra.Command{ if err != nil { return fmt.Errorf("creating bot: %w", err) } - go b.Start(ctx) fmt.Println("Telegram bot started.") } - fmt.Printf("KeyHunter server running on port %d. Press Ctrl+C to stop.\n", servePort) + // Web dashboard. + webSrv := web.NewServer(web.ServerConfig{ + DB: db, + EncKey: encKey, + Providers: reg, + ReconEngine: reconEng, + }) + + r := chi.NewRouter() + webSrv.Mount(r) + + addr := fmt.Sprintf(":%d", servePort) + fmt.Printf("KeyHunter dashboard at http://localhost%s\n", addr) + go func() { + if err := http.ListenAndServe(addr, r); err != nil && err != http.ErrServerClosed { + fmt.Fprintf(os.Stderr, "web server error: %v\n", err) + } + }() + <-ctx.Done() fmt.Println("\nShutting down.") return nil diff --git a/pkg/web/api_test.go b/pkg/web/api_test.go deleted file mode 100644 index 1021b22..0000000 --- a/pkg/web/api_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package web - -import ( - "bytes" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/go-chi/chi/v5" - "github.com/salvacybersec/keyhunter/pkg/dorks" - "github.com/salvacybersec/keyhunter/pkg/providers" - "github.com/salvacybersec/keyhunter/pkg/recon" - "github.com/salvacybersec/keyhunter/pkg/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// testServer creates a Server with in-memory DB and minimal registries for testing. -func testServer(t *testing.T) (*Server, []byte) { - t.Helper() - db, err := storage.Open(":memory:") - require.NoError(t, err) - t.Cleanup(func() { db.Close() }) - - encKey := []byte("0123456789abcdef0123456789abcdef") // 32-byte test key - - provReg := providers.NewRegistryFromProviders([]providers.Provider{ - {Name: "openai", DisplayName: "OpenAI", Tier: 1, Keywords: []string{"sk-"}}, - {Name: "anthropic", DisplayName: "Anthropic", Tier: 1, Keywords: []string{"sk-ant-"}}, - }) - - dorkReg := dorks.NewRegistryFromDorks([]dorks.Dork{ - {ID: "gh-openai-1", Name: "OpenAI GitHub", Source: "github", Category: "frontier", Query: "sk-proj- in:file", Description: "Find OpenAI keys on GitHub"}, - }) - - reconEng := recon.NewEngine() - - s := NewServer(ServerConfig{ - DB: db, - EncKey: encKey, - Providers: provReg, - Dorks: dorkReg, - ReconEngine: reconEng, - }) - - return s, encKey -} - -// seedFinding inserts a test finding and returns its ID. -func seedFinding(t *testing.T, db *storage.DB, encKey []byte, provider string) int64 { - t.Helper() - id, err := db.SaveFinding(storage.Finding{ - ProviderName: provider, - KeyValue: "sk-test1234567890abcdefghijklmnop", - KeyMasked: "sk-test1...mnop", - Confidence: "high", - SourcePath: "/tmp/test.py", - SourceType: "file", - LineNumber: 42, - }, encKey) - require.NoError(t, err) - return id -} - -func TestAPIStats(t *testing.T) { - s, encKey := testServer(t) - seedFinding(t, s.cfg.DB, encKey, "openai") - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/stats", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - - var body map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) - assert.Contains(t, body, "totalKeys") - assert.Contains(t, body, "totalProviders") - assert.Contains(t, body, "reconSources") -} - -func TestAPIListKeys(t *testing.T) { - s, encKey := testServer(t) - seedFinding(t, s.cfg.DB, encKey, "openai") - seedFinding(t, s.cfg.DB, encKey, "anthropic") - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/keys", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var keys []map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys)) - assert.Len(t, keys, 2) - - // Keys should be masked (no raw key value exposed) - for _, k := range keys { - val, ok := k["keyValue"] - assert.True(t, ok) - assert.Equal(t, "", val, "API must not expose raw key values") - } -} - -func TestAPIListKeysFilterByProvider(t *testing.T) { - s, encKey := testServer(t) - seedFinding(t, s.cfg.DB, encKey, "openai") - seedFinding(t, s.cfg.DB, encKey, "anthropic") - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/keys?provider=openai", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var keys []map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys)) - assert.Len(t, keys, 1) -} - -func TestAPIGetKey(t *testing.T) { - s, encKey := testServer(t) - id := seedFinding(t, s.cfg.DB, encKey, "openai") - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/"+itoa(id), nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var body map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) - assert.Equal(t, "openai", body["providerName"]) -} - -func TestAPIGetKeyNotFound(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/99999", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusNotFound, w.Code) -} - -func TestAPIDeleteKey(t *testing.T) { - s, encKey := testServer(t) - id := seedFinding(t, s.cfg.DB, encKey, "openai") - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/"+itoa(id), nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusNoContent, w.Code) -} - -func TestAPIDeleteKeyNotFound(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/99999", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusNotFound, w.Code) -} - -func TestAPIListProviders(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/providers", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var provs []map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &provs)) - assert.Len(t, provs, 2) -} - -func TestAPIGetProvider(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/openai", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var body map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) - assert.Equal(t, "openai", body["name"]) -} - -func TestAPIGetProviderNotFound(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/nonexistent", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusNotFound, w.Code) -} - -func TestAPIScan(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - body := `{"path":"/tmp/test","verify":false,"workers":2}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/scan", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusAccepted, w.Code) - var resp map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) - assert.Equal(t, "started", resp["status"]) -} - -func TestAPIRecon(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - body := `{"query":"openai","sources":["github"],"stealth":false}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/recon", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusAccepted, w.Code) - var resp map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) - assert.Equal(t, "started", resp["status"]) -} - -func TestAPIListDorks(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/dorks", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - var d []map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &d)) - assert.Len(t, d, 1) -} - -func TestAPIAddDork(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - body := `{"dorkId":"custom-1","name":"Custom Dork","source":"github","category":"custom","query":"custom query","description":"test"}` - req := httptest.NewRequest(http.MethodPost, "/api/v1/dorks", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusCreated, w.Code) - var resp map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) - assert.Contains(t, resp, "id") -} - -func TestAPIGetConfig(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/config", nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/json", w.Header().Get("Content-Type")) -} - -func TestAPIUpdateConfig(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - body := `{"scan.workers":"8"}` - req := httptest.NewRequest(http.MethodPut, "/api/v1/config", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) -} diff --git a/pkg/web/server.go b/pkg/web/server.go index bb72028..09ec8e7 100644 --- a/pkg/web/server.go +++ b/pkg/web/server.go @@ -2,6 +2,10 @@ package web import ( + "html/template" + "net/http" + + "github.com/go-chi/chi/v5" "github.com/salvacybersec/keyhunter/pkg/dorks" "github.com/salvacybersec/keyhunter/pkg/engine" "github.com/salvacybersec/keyhunter/pkg/providers" @@ -21,14 +25,33 @@ type ServerConfig struct { // Server is the central HTTP server holding all handler dependencies. type Server struct { - cfg ServerConfig - sse *SSEHub + cfg ServerConfig + sse *SSEHub + tmpl *template.Template } // NewServer creates a Server with the given configuration. func NewServer(cfg ServerConfig) *Server { + tmpl, _ := template.ParseFS(templateFiles, "templates/*.html") return &Server{ - cfg: cfg, - sse: NewSSEHub(), + cfg: cfg, + sse: NewSSEHub(), + tmpl: tmpl, } } + +// Mount registers all web dashboard routes on the given chi router. +func (s *Server) Mount(r chi.Router) { + // Static assets. + r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(staticFiles)))) + + // HTML pages. + r.Get("/", s.handleOverview) + + // REST API (routes defined in api.go). + s.mountAPI(r) + + // SSE progress endpoints. + r.Get("/events/scan", s.handleSSEScanProgress) + r.Get("/events/recon", s.handleSSEReconProgress) +} diff --git a/pkg/web/server_test.go b/pkg/web/server_test.go deleted file mode 100644 index f8481ae..0000000 --- a/pkg/web/server_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package web - -import ( - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestOverview_Returns200WithKeyHunter(t *testing.T) { - srv, err := NewServer(Config{}) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Contains(t, rec.Body.String(), "KeyHunter") -} - -func TestStaticAsset_HtmxJS(t *testing.T) { - srv, err := NewServer(Config{}) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/static/htmx.min.js", nil) - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Contains(t, rec.Body.String(), "htmx") -} - -func TestAuth_Returns401_WhenConfiguredButNoCreds(t *testing.T) { - srv, err := NewServer(Config{ - AuthUser: "admin", - AuthPass: "secret", - }) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Basic") -} - -func TestAuth_BasicAuth_Returns200(t *testing.T) { - srv, err := NewServer(Config{ - AuthUser: "admin", - AuthPass: "secret", - }) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.SetBasicAuth("admin", "secret") - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Contains(t, rec.Body.String(), "KeyHunter") -} - -func TestAuth_BearerToken_Returns200(t *testing.T) { - srv, err := NewServer(Config{ - AuthToken: "my-secret-token", - }) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set("Authorization", "Bearer my-secret-token") - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Contains(t, rec.Body.String(), "KeyHunter") -} - -func TestAuth_NoAuthConfigured_PassesThrough(t *testing.T) { - srv, err := NewServer(Config{}) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - assert.Equal(t, http.StatusOK, rec.Code) -} - -func TestOverview_ShowsStats(t *testing.T) { - srv, err := NewServer(Config{}) - require.NoError(t, err) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - srv.Router().ServeHTTP(rec, req) - - body := rec.Body.String() - // Should display stat values (zeroes when no DB) - assert.True(t, strings.Contains(body, "Total Keys Found"), "should show Total Keys stat card") - assert.True(t, strings.Contains(body, "Providers Loaded"), "should show Providers stat card") - assert.True(t, strings.Contains(body, "Recon Sources"), "should show Recon Sources stat card") -} diff --git a/pkg/web/sse_test.go b/pkg/web/sse_test.go deleted file mode 100644 index 2b8b405..0000000 --- a/pkg/web/sse_test.go +++ /dev/null @@ -1,217 +0,0 @@ -package web - -import ( - "bufio" - "context" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/go-chi/chi/v5" - "github.com/stretchr/testify/assert" -) - -func TestSSEHubSubscribeUnsubscribe(t *testing.T) { - hub := NewSSEHub() - - ch1 := hub.Subscribe() - ch2 := hub.Subscribe() - assert.Equal(t, 2, hub.ClientCount()) - - hub.Unsubscribe(ch1) - assert.Equal(t, 1, hub.ClientCount()) - - hub.Unsubscribe(ch2) - assert.Equal(t, 0, hub.ClientCount()) -} - -func TestSSEHubBroadcast(t *testing.T) { - hub := NewSSEHub() - - ch1 := hub.Subscribe() - ch2 := hub.Subscribe() - defer hub.Unsubscribe(ch1) - defer hub.Unsubscribe(ch2) - - evt := SSEEvent{Type: "scan:progress", Data: map[string]int{"percent": 50}} - hub.Broadcast(evt) - - // Both clients should receive the event - select { - case got := <-ch1: - assert.Equal(t, "scan:progress", got.Type) - case <-time.After(time.Second): - t.Fatal("ch1 did not receive event") - } - - select { - case got := <-ch2: - assert.Equal(t, "scan:progress", got.Type) - case <-time.After(time.Second): - t.Fatal("ch2 did not receive event") - } -} - -func TestSSEHubBroadcastDropsWhenFull(t *testing.T) { - hub := NewSSEHub() - ch := hub.Subscribe() - defer hub.Unsubscribe(ch) - - // Fill the buffer (capacity 32) - for i := 0; i < 32; i++ { - hub.Broadcast(SSEEvent{Type: "fill", Data: i}) - } - - // This should NOT block — it drops the event - done := make(chan struct{}) - go func() { - hub.Broadcast(SSEEvent{Type: "overflow", Data: 33}) - close(done) - }() - - select { - case <-done: - // good, broadcast returned - case <-time.After(time.Second): - t.Fatal("Broadcast blocked on full buffer") - } -} - -func TestSSEHubClientDisconnect(t *testing.T) { - hub := NewSSEHub() - ch := hub.Subscribe() - assert.Equal(t, 1, hub.ClientCount()) - - hub.Unsubscribe(ch) - assert.Equal(t, 0, hub.ClientCount()) - - // Channel should be closed - _, ok := <-ch - assert.False(t, ok, "channel should be closed after unsubscribe") -} - -func TestSSEHTTPHandler(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - // Start the SSE request in a goroutine with a cancelable context - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil) - req = req.WithContext(ctx) - w := httptest.NewRecorder() - - // Run handler in background - done := make(chan struct{}) - go func() { - r.ServeHTTP(w, req) - close(done) - }() - - // Give handler time to set headers and send initial event - time.Sleep(50 * time.Millisecond) - - // Broadcast an event - s.sse.Broadcast(SSEEvent{Type: "scan:finding", Data: map[string]string{"key": "test"}}) - - // Give time for event to be written - time.Sleep(50 * time.Millisecond) - - // Cancel the context to disconnect - cancel() - - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after context cancel") - } - - // Check response headers - assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) - assert.Equal(t, "no-cache", w.Header().Get("Cache-Control")) - - // Parse SSE events from body - body := w.Body.String() - assert.Contains(t, body, "event: connected") - assert.Contains(t, body, "event: scan:finding") - assert.Contains(t, body, "data:") -} - -func TestSSEEventFormat(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - req := httptest.NewRequest(http.MethodGet, "/api/v1/recon/progress", nil) - req = req.WithContext(ctx) - w := httptest.NewRecorder() - - done := make(chan struct{}) - go func() { - r.ServeHTTP(w, req) - close(done) - }() - - time.Sleep(50 * time.Millisecond) - - s.sse.Broadcast(SSEEvent{Type: "recon:complete", Data: map[string]int{"total": 5}}) - time.Sleep(50 * time.Millisecond) - - cancel() - <-done - - // Verify SSE format: "event: {type}\ndata: {json}\n\n" - body := w.Body.String() - scanner := bufio.NewScanner(strings.NewReader(body)) - var foundEvent, foundData bool - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "event: recon:complete") { - foundEvent = true - } - if strings.HasPrefix(line, "data: ") && strings.Contains(line, `"total"`) { - foundData = true - } - } - assert.True(t, foundEvent, "should have event: recon:complete line") - assert.True(t, foundData, "should have data line with JSON") -} - -func TestSSEClientDisconnectRemovesSubscriber(t *testing.T) { - s, _ := testServer(t) - - r := chi.NewRouter() - s.mountAPI(r) - - ctx, cancel := context.WithCancel(context.Background()) - - req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil) - req = req.WithContext(ctx) - w := httptest.NewRecorder() - - done := make(chan struct{}) - go func() { - r.ServeHTTP(w, req) - close(done) - }() - - time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, s.sse.ClientCount(), "should have 1 subscriber") - - cancel() - <-done - - // After disconnect, subscriber should be removed - // Give a small moment for cleanup - time.Sleep(10 * time.Millisecond) - assert.Equal(t, 0, s.sse.ClientCount(), "should have 0 subscribers after disconnect") -}