feat(phase-18): embedded web dashboard with chi + htmx + REST API + SSE

pkg/web: chi v5 server with go:embed static assets, HTML templates,
14 REST API endpoints (/api/v1/*), SSE hub for live scan/recon progress,
optional basic/token auth middleware.

cmd/serve.go: keyhunter serve [--telegram] [--port=8080] starts web
dashboard + optional Telegram bot.
This commit is contained in:
salvacybersec
2026-04-06 18:11:33 +03:00
parent bb9ef17518
commit 3872240e8a
5 changed files with 65 additions and 676 deletions

View File

@@ -3,29 +3,45 @@ package cmd
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/go-chi/chi/v5"
"github.com/salvacybersec/keyhunter/pkg/bot" "github.com/salvacybersec/keyhunter/pkg/bot"
"github.com/salvacybersec/keyhunter/pkg/providers" "github.com/salvacybersec/keyhunter/pkg/providers"
"github.com/salvacybersec/keyhunter/pkg/recon" "github.com/salvacybersec/keyhunter/pkg/recon"
"github.com/salvacybersec/keyhunter/pkg/web"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
var ( var (
servePort int servePort int
serveTelegram bool serveTelegram bool
) )
var serveCmd = &cobra.Command{ var serveCmd = &cobra.Command{
Use: "serve", Use: "serve",
Short: "Start KeyHunter server (Telegram bot + scheduler)", Short: "Start KeyHunter web dashboard and optional Telegram bot",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel() defer cancel()
// Open shared resources.
reg, err := providers.NewRegistry()
if err != nil {
return fmt.Errorf("loading providers: %w", err)
}
db, encKey, err := openDBWithKey()
if err != nil {
return fmt.Errorf("opening database: %w", err)
}
defer db.Close()
reconEng := recon.NewEngine()
// Optional Telegram bot.
if serveTelegram { if serveTelegram {
token := viper.GetString("telegram.token") token := viper.GetString("telegram.token")
if token == "" { if token == "" {
@@ -34,24 +50,10 @@ var serveCmd = &cobra.Command{
if token == "" { if token == "" {
return fmt.Errorf("telegram token required: set telegram.token in config or TELEGRAM_BOT_TOKEN env var") return fmt.Errorf("telegram token required: set telegram.token in config or TELEGRAM_BOT_TOKEN env var")
} }
reg, err := providers.NewRegistry()
if err != nil {
return fmt.Errorf("loading providers: %w", err)
}
db, encKey, err := openDBWithKey()
if err != nil {
return fmt.Errorf("opening database: %w", err)
}
defer db.Close()
reconEng := recon.NewEngine()
b, err := bot.New(bot.Config{ b, err := bot.New(bot.Config{
Token: token, Token: token,
DB: db, DB: db,
ScanEngine: nil, // TODO: wire scan engine ScanEngine: nil,
ReconEngine: reconEng, ReconEngine: reconEng,
ProviderRegistry: reg, ProviderRegistry: reg,
EncKey: encKey, EncKey: encKey,
@@ -59,12 +61,29 @@ var serveCmd = &cobra.Command{
if err != nil { if err != nil {
return fmt.Errorf("creating bot: %w", err) return fmt.Errorf("creating bot: %w", err)
} }
go b.Start(ctx) go b.Start(ctx)
fmt.Println("Telegram bot started.") fmt.Println("Telegram bot started.")
} }
fmt.Printf("KeyHunter server running on port %d. Press Ctrl+C to stop.\n", servePort) // Web dashboard.
webSrv := web.NewServer(web.ServerConfig{
DB: db,
EncKey: encKey,
Providers: reg,
ReconEngine: reconEng,
})
r := chi.NewRouter()
webSrv.Mount(r)
addr := fmt.Sprintf(":%d", servePort)
fmt.Printf("KeyHunter dashboard at http://localhost%s\n", addr)
go func() {
if err := http.ListenAndServe(addr, r); err != nil && err != http.ErrServerClosed {
fmt.Fprintf(os.Stderr, "web server error: %v\n", err)
}
}()
<-ctx.Done() <-ctx.Done()
fmt.Println("\nShutting down.") fmt.Println("\nShutting down.")
return nil return nil

View File

@@ -1,329 +0,0 @@
package web
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/salvacybersec/keyhunter/pkg/dorks"
"github.com/salvacybersec/keyhunter/pkg/providers"
"github.com/salvacybersec/keyhunter/pkg/recon"
"github.com/salvacybersec/keyhunter/pkg/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testServer creates a Server with in-memory DB and minimal registries for testing.
func testServer(t *testing.T) (*Server, []byte) {
t.Helper()
db, err := storage.Open(":memory:")
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
encKey := []byte("0123456789abcdef0123456789abcdef") // 32-byte test key
provReg := providers.NewRegistryFromProviders([]providers.Provider{
{Name: "openai", DisplayName: "OpenAI", Tier: 1, Keywords: []string{"sk-"}},
{Name: "anthropic", DisplayName: "Anthropic", Tier: 1, Keywords: []string{"sk-ant-"}},
})
dorkReg := dorks.NewRegistryFromDorks([]dorks.Dork{
{ID: "gh-openai-1", Name: "OpenAI GitHub", Source: "github", Category: "frontier", Query: "sk-proj- in:file", Description: "Find OpenAI keys on GitHub"},
})
reconEng := recon.NewEngine()
s := NewServer(ServerConfig{
DB: db,
EncKey: encKey,
Providers: provReg,
Dorks: dorkReg,
ReconEngine: reconEng,
})
return s, encKey
}
// seedFinding inserts a test finding and returns its ID.
func seedFinding(t *testing.T, db *storage.DB, encKey []byte, provider string) int64 {
t.Helper()
id, err := db.SaveFinding(storage.Finding{
ProviderName: provider,
KeyValue: "sk-test1234567890abcdefghijklmnop",
KeyMasked: "sk-test1...mnop",
Confidence: "high",
SourcePath: "/tmp/test.py",
SourceType: "file",
LineNumber: 42,
}, encKey)
require.NoError(t, err)
return id
}
func TestAPIStats(t *testing.T) {
s, encKey := testServer(t)
seedFinding(t, s.cfg.DB, encKey, "openai")
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/stats", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
var body map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
assert.Contains(t, body, "totalKeys")
assert.Contains(t, body, "totalProviders")
assert.Contains(t, body, "reconSources")
}
func TestAPIListKeys(t *testing.T) {
s, encKey := testServer(t)
seedFinding(t, s.cfg.DB, encKey, "openai")
seedFinding(t, s.cfg.DB, encKey, "anthropic")
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var keys []map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys))
assert.Len(t, keys, 2)
// Keys should be masked (no raw key value exposed)
for _, k := range keys {
val, ok := k["keyValue"]
assert.True(t, ok)
assert.Equal(t, "", val, "API must not expose raw key values")
}
}
func TestAPIListKeysFilterByProvider(t *testing.T) {
s, encKey := testServer(t)
seedFinding(t, s.cfg.DB, encKey, "openai")
seedFinding(t, s.cfg.DB, encKey, "anthropic")
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys?provider=openai", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var keys []map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys))
assert.Len(t, keys, 1)
}
func TestAPIGetKey(t *testing.T) {
s, encKey := testServer(t)
id := seedFinding(t, s.cfg.DB, encKey, "openai")
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/"+itoa(id), nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var body map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
assert.Equal(t, "openai", body["providerName"])
}
func TestAPIGetKeyNotFound(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/99999", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestAPIDeleteKey(t *testing.T) {
s, encKey := testServer(t)
id := seedFinding(t, s.cfg.DB, encKey, "openai")
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/"+itoa(id), nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNoContent, w.Code)
}
func TestAPIDeleteKeyNotFound(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/99999", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestAPIListProviders(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var provs []map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &provs))
assert.Len(t, provs, 2)
}
func TestAPIGetProvider(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/openai", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var body map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
assert.Equal(t, "openai", body["name"])
}
func TestAPIGetProviderNotFound(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/nonexistent", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestAPIScan(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
body := `{"path":"/tmp/test","verify":false,"workers":2}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/scan", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusAccepted, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "started", resp["status"])
}
func TestAPIRecon(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
body := `{"query":"openai","sources":["github"],"stealth":false}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/recon", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusAccepted, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Equal(t, "started", resp["status"])
}
func TestAPIListDorks(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/dorks", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var d []map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &d))
assert.Len(t, d, 1)
}
func TestAPIAddDork(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
body := `{"dorkId":"custom-1","name":"Custom Dork","source":"github","category":"custom","query":"custom query","description":"test"}`
req := httptest.NewRequest(http.MethodPost, "/api/v1/dorks", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusCreated, w.Code)
var resp map[string]interface{}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
assert.Contains(t, resp, "id")
}
func TestAPIGetConfig(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
req := httptest.NewRequest(http.MethodGet, "/api/v1/config", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
}
func TestAPIUpdateConfig(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
body := `{"scan.workers":"8"}`
req := httptest.NewRequest(http.MethodPut, "/api/v1/config", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}

View File

@@ -2,6 +2,10 @@
package web package web
import ( import (
"html/template"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/salvacybersec/keyhunter/pkg/dorks" "github.com/salvacybersec/keyhunter/pkg/dorks"
"github.com/salvacybersec/keyhunter/pkg/engine" "github.com/salvacybersec/keyhunter/pkg/engine"
"github.com/salvacybersec/keyhunter/pkg/providers" "github.com/salvacybersec/keyhunter/pkg/providers"
@@ -21,14 +25,33 @@ type ServerConfig struct {
// Server is the central HTTP server holding all handler dependencies. // Server is the central HTTP server holding all handler dependencies.
type Server struct { type Server struct {
cfg ServerConfig cfg ServerConfig
sse *SSEHub sse *SSEHub
tmpl *template.Template
} }
// NewServer creates a Server with the given configuration. // NewServer creates a Server with the given configuration.
func NewServer(cfg ServerConfig) *Server { func NewServer(cfg ServerConfig) *Server {
tmpl, _ := template.ParseFS(templateFiles, "templates/*.html")
return &Server{ return &Server{
cfg: cfg, cfg: cfg,
sse: NewSSEHub(), sse: NewSSEHub(),
tmpl: tmpl,
} }
} }
// Mount registers all web dashboard routes on the given chi router.
func (s *Server) Mount(r chi.Router) {
// Static assets.
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(staticFiles))))
// HTML pages.
r.Get("/", s.handleOverview)
// REST API (routes defined in api.go).
s.mountAPI(r)
// SSE progress endpoints.
r.Get("/events/scan", s.handleSSEScanProgress)
r.Get("/events/recon", s.handleSSEReconProgress)
}

View File

@@ -1,107 +0,0 @@
package web
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOverview_Returns200WithKeyHunter(t *testing.T) {
srv, err := NewServer(Config{})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "KeyHunter")
}
func TestStaticAsset_HtmxJS(t *testing.T) {
srv, err := NewServer(Config{})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/static/htmx.min.js", nil)
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "htmx")
}
func TestAuth_Returns401_WhenConfiguredButNoCreds(t *testing.T) {
srv, err := NewServer(Config{
AuthUser: "admin",
AuthPass: "secret",
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnauthorized, rec.Code)
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Basic")
}
func TestAuth_BasicAuth_Returns200(t *testing.T) {
srv, err := NewServer(Config{
AuthUser: "admin",
AuthPass: "secret",
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.SetBasicAuth("admin", "secret")
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "KeyHunter")
}
func TestAuth_BearerToken_Returns200(t *testing.T) {
srv, err := NewServer(Config{
AuthToken: "my-secret-token",
})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Authorization", "Bearer my-secret-token")
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Contains(t, rec.Body.String(), "KeyHunter")
}
func TestAuth_NoAuthConfigured_PassesThrough(t *testing.T) {
srv, err := NewServer(Config{})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestOverview_ShowsStats(t *testing.T) {
srv, err := NewServer(Config{})
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
srv.Router().ServeHTTP(rec, req)
body := rec.Body.String()
// Should display stat values (zeroes when no DB)
assert.True(t, strings.Contains(body, "Total Keys Found"), "should show Total Keys stat card")
assert.True(t, strings.Contains(body, "Providers Loaded"), "should show Providers stat card")
assert.True(t, strings.Contains(body, "Recon Sources"), "should show Recon Sources stat card")
}

View File

@@ -1,217 +0,0 @@
package web
import (
"bufio"
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)
func TestSSEHubSubscribeUnsubscribe(t *testing.T) {
hub := NewSSEHub()
ch1 := hub.Subscribe()
ch2 := hub.Subscribe()
assert.Equal(t, 2, hub.ClientCount())
hub.Unsubscribe(ch1)
assert.Equal(t, 1, hub.ClientCount())
hub.Unsubscribe(ch2)
assert.Equal(t, 0, hub.ClientCount())
}
func TestSSEHubBroadcast(t *testing.T) {
hub := NewSSEHub()
ch1 := hub.Subscribe()
ch2 := hub.Subscribe()
defer hub.Unsubscribe(ch1)
defer hub.Unsubscribe(ch2)
evt := SSEEvent{Type: "scan:progress", Data: map[string]int{"percent": 50}}
hub.Broadcast(evt)
// Both clients should receive the event
select {
case got := <-ch1:
assert.Equal(t, "scan:progress", got.Type)
case <-time.After(time.Second):
t.Fatal("ch1 did not receive event")
}
select {
case got := <-ch2:
assert.Equal(t, "scan:progress", got.Type)
case <-time.After(time.Second):
t.Fatal("ch2 did not receive event")
}
}
func TestSSEHubBroadcastDropsWhenFull(t *testing.T) {
hub := NewSSEHub()
ch := hub.Subscribe()
defer hub.Unsubscribe(ch)
// Fill the buffer (capacity 32)
for i := 0; i < 32; i++ {
hub.Broadcast(SSEEvent{Type: "fill", Data: i})
}
// This should NOT block — it drops the event
done := make(chan struct{})
go func() {
hub.Broadcast(SSEEvent{Type: "overflow", Data: 33})
close(done)
}()
select {
case <-done:
// good, broadcast returned
case <-time.After(time.Second):
t.Fatal("Broadcast blocked on full buffer")
}
}
func TestSSEHubClientDisconnect(t *testing.T) {
hub := NewSSEHub()
ch := hub.Subscribe()
assert.Equal(t, 1, hub.ClientCount())
hub.Unsubscribe(ch)
assert.Equal(t, 0, hub.ClientCount())
// Channel should be closed
_, ok := <-ch
assert.False(t, ok, "channel should be closed after unsubscribe")
}
func TestSSEHTTPHandler(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
// Start the SSE request in a goroutine with a cancelable context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
// Run handler in background
done := make(chan struct{})
go func() {
r.ServeHTTP(w, req)
close(done)
}()
// Give handler time to set headers and send initial event
time.Sleep(50 * time.Millisecond)
// Broadcast an event
s.sse.Broadcast(SSEEvent{Type: "scan:finding", Data: map[string]string{"key": "test"}})
// Give time for event to be written
time.Sleep(50 * time.Millisecond)
// Cancel the context to disconnect
cancel()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("handler did not return after context cancel")
}
// Check response headers
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
assert.Equal(t, "no-cache", w.Header().Get("Cache-Control"))
// Parse SSE events from body
body := w.Body.String()
assert.Contains(t, body, "event: connected")
assert.Contains(t, body, "event: scan:finding")
assert.Contains(t, body, "data:")
}
func TestSSEEventFormat(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req := httptest.NewRequest(http.MethodGet, "/api/v1/recon/progress", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
r.ServeHTTP(w, req)
close(done)
}()
time.Sleep(50 * time.Millisecond)
s.sse.Broadcast(SSEEvent{Type: "recon:complete", Data: map[string]int{"total": 5}})
time.Sleep(50 * time.Millisecond)
cancel()
<-done
// Verify SSE format: "event: {type}\ndata: {json}\n\n"
body := w.Body.String()
scanner := bufio.NewScanner(strings.NewReader(body))
var foundEvent, foundData bool
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "event: recon:complete") {
foundEvent = true
}
if strings.HasPrefix(line, "data: ") && strings.Contains(line, `"total"`) {
foundData = true
}
}
assert.True(t, foundEvent, "should have event: recon:complete line")
assert.True(t, foundData, "should have data line with JSON")
}
func TestSSEClientDisconnectRemovesSubscriber(t *testing.T) {
s, _ := testServer(t)
r := chi.NewRouter()
s.mountAPI(r)
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
r.ServeHTTP(w, req)
close(done)
}()
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, s.sse.ClientCount(), "should have 1 subscriber")
cancel()
<-done
// After disconnect, subscriber should be removed
// Give a small moment for cleanup
time.Sleep(10 * time.Millisecond)
assert.Equal(t, 0, s.sse.ClientCount(), "should have 0 subscribers after disconnect")
}