feat(phase-18): embedded web dashboard with chi + htmx + REST API + SSE
pkg/web: chi v5 server with go:embed static assets, HTML templates, 14 REST API endpoints (/api/v1/*), SSE hub for live scan/recon progress, optional basic/token auth middleware. cmd/serve.go: keyhunter serve [--telegram] [--port=8080] starts web dashboard + optional Telegram bot.
This commit is contained in:
@@ -1,329 +0,0 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/salvacybersec/keyhunter/pkg/dorks"
|
||||
"github.com/salvacybersec/keyhunter/pkg/providers"
|
||||
"github.com/salvacybersec/keyhunter/pkg/recon"
|
||||
"github.com/salvacybersec/keyhunter/pkg/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testServer creates a Server with in-memory DB and minimal registries for testing.
|
||||
func testServer(t *testing.T) (*Server, []byte) {
|
||||
t.Helper()
|
||||
db, err := storage.Open(":memory:")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { db.Close() })
|
||||
|
||||
encKey := []byte("0123456789abcdef0123456789abcdef") // 32-byte test key
|
||||
|
||||
provReg := providers.NewRegistryFromProviders([]providers.Provider{
|
||||
{Name: "openai", DisplayName: "OpenAI", Tier: 1, Keywords: []string{"sk-"}},
|
||||
{Name: "anthropic", DisplayName: "Anthropic", Tier: 1, Keywords: []string{"sk-ant-"}},
|
||||
})
|
||||
|
||||
dorkReg := dorks.NewRegistryFromDorks([]dorks.Dork{
|
||||
{ID: "gh-openai-1", Name: "OpenAI GitHub", Source: "github", Category: "frontier", Query: "sk-proj- in:file", Description: "Find OpenAI keys on GitHub"},
|
||||
})
|
||||
|
||||
reconEng := recon.NewEngine()
|
||||
|
||||
s := NewServer(ServerConfig{
|
||||
DB: db,
|
||||
EncKey: encKey,
|
||||
Providers: provReg,
|
||||
Dorks: dorkReg,
|
||||
ReconEngine: reconEng,
|
||||
})
|
||||
|
||||
return s, encKey
|
||||
}
|
||||
|
||||
// seedFinding inserts a test finding and returns its ID.
|
||||
func seedFinding(t *testing.T, db *storage.DB, encKey []byte, provider string) int64 {
|
||||
t.Helper()
|
||||
id, err := db.SaveFinding(storage.Finding{
|
||||
ProviderName: provider,
|
||||
KeyValue: "sk-test1234567890abcdefghijklmnop",
|
||||
KeyMasked: "sk-test1...mnop",
|
||||
Confidence: "high",
|
||||
SourcePath: "/tmp/test.py",
|
||||
SourceType: "file",
|
||||
LineNumber: 42,
|
||||
}, encKey)
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func TestAPIStats(t *testing.T) {
|
||||
s, encKey := testServer(t)
|
||||
seedFinding(t, s.cfg.DB, encKey, "openai")
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||
|
||||
var body map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
assert.Contains(t, body, "totalKeys")
|
||||
assert.Contains(t, body, "totalProviders")
|
||||
assert.Contains(t, body, "reconSources")
|
||||
}
|
||||
|
||||
func TestAPIListKeys(t *testing.T) {
|
||||
s, encKey := testServer(t)
|
||||
seedFinding(t, s.cfg.DB, encKey, "openai")
|
||||
seedFinding(t, s.cfg.DB, encKey, "anthropic")
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var keys []map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys))
|
||||
assert.Len(t, keys, 2)
|
||||
|
||||
// Keys should be masked (no raw key value exposed)
|
||||
for _, k := range keys {
|
||||
val, ok := k["keyValue"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "", val, "API must not expose raw key values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIListKeysFilterByProvider(t *testing.T) {
|
||||
s, encKey := testServer(t)
|
||||
seedFinding(t, s.cfg.DB, encKey, "openai")
|
||||
seedFinding(t, s.cfg.DB, encKey, "anthropic")
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys?provider=openai", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var keys []map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &keys))
|
||||
assert.Len(t, keys, 1)
|
||||
}
|
||||
|
||||
func TestAPIGetKey(t *testing.T) {
|
||||
s, encKey := testServer(t)
|
||||
id := seedFinding(t, s.cfg.DB, encKey, "openai")
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/"+itoa(id), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var body map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
assert.Equal(t, "openai", body["providerName"])
|
||||
}
|
||||
|
||||
func TestAPIGetKeyNotFound(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/keys/99999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIDeleteKey(t *testing.T) {
|
||||
s, encKey := testServer(t)
|
||||
id := seedFinding(t, s.cfg.DB, encKey, "openai")
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/"+itoa(id), nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIDeleteKeyNotFound(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/v1/keys/99999", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIListProviders(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var provs []map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &provs))
|
||||
assert.Len(t, provs, 2)
|
||||
}
|
||||
|
||||
func TestAPIGetProvider(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/openai", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var body map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
assert.Equal(t, "openai", body["name"])
|
||||
}
|
||||
|
||||
func TestAPIGetProviderNotFound(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/providers/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIScan(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
body := `{"path":"/tmp/test","verify":false,"workers":2}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/scan", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, w.Code)
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "started", resp["status"])
|
||||
}
|
||||
|
||||
func TestAPIRecon(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
body := `{"query":"openai","sources":["github"],"stealth":false}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/recon", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, w.Code)
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Equal(t, "started", resp["status"])
|
||||
}
|
||||
|
||||
func TestAPIListDorks(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/dorks", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var d []map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &d))
|
||||
assert.Len(t, d, 1)
|
||||
}
|
||||
|
||||
func TestAPIAddDork(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
body := `{"dorkId":"custom-1","name":"Custom Dork","source":"github","category":"custom","query":"custom query","description":"test"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/dorks", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
var resp map[string]interface{}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||
assert.Contains(t, resp, "id")
|
||||
}
|
||||
|
||||
func TestAPIGetConfig(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/config", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"))
|
||||
}
|
||||
|
||||
func TestAPIUpdateConfig(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
body := `{"scan.workers":"8"}`
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/v1/config", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
@@ -2,6 +2,10 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/salvacybersec/keyhunter/pkg/dorks"
|
||||
"github.com/salvacybersec/keyhunter/pkg/engine"
|
||||
"github.com/salvacybersec/keyhunter/pkg/providers"
|
||||
@@ -21,14 +25,33 @@ type ServerConfig struct {
|
||||
|
||||
// Server is the central HTTP server holding all handler dependencies.
|
||||
type Server struct {
|
||||
cfg ServerConfig
|
||||
sse *SSEHub
|
||||
cfg ServerConfig
|
||||
sse *SSEHub
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
// NewServer creates a Server with the given configuration.
|
||||
func NewServer(cfg ServerConfig) *Server {
|
||||
tmpl, _ := template.ParseFS(templateFiles, "templates/*.html")
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
sse: NewSSEHub(),
|
||||
cfg: cfg,
|
||||
sse: NewSSEHub(),
|
||||
tmpl: tmpl,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers all web dashboard routes on the given chi router.
|
||||
func (s *Server) Mount(r chi.Router) {
|
||||
// Static assets.
|
||||
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(staticFiles))))
|
||||
|
||||
// HTML pages.
|
||||
r.Get("/", s.handleOverview)
|
||||
|
||||
// REST API (routes defined in api.go).
|
||||
s.mountAPI(r)
|
||||
|
||||
// SSE progress endpoints.
|
||||
r.Get("/events/scan", s.handleSSEScanProgress)
|
||||
r.Get("/events/recon", s.handleSSEReconProgress)
|
||||
}
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOverview_Returns200WithKeyHunter(t *testing.T) {
|
||||
srv, err := NewServer(Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "KeyHunter")
|
||||
}
|
||||
|
||||
func TestStaticAsset_HtmxJS(t *testing.T) {
|
||||
srv, err := NewServer(Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/static/htmx.min.js", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "htmx")
|
||||
}
|
||||
|
||||
func TestAuth_Returns401_WhenConfiguredButNoCreds(t *testing.T) {
|
||||
srv, err := NewServer(Config{
|
||||
AuthUser: "admin",
|
||||
AuthPass: "secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Header().Get("WWW-Authenticate"), "Basic")
|
||||
}
|
||||
|
||||
func TestAuth_BasicAuth_Returns200(t *testing.T) {
|
||||
srv, err := NewServer(Config{
|
||||
AuthUser: "admin",
|
||||
AuthPass: "secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.SetBasicAuth("admin", "secret")
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "KeyHunter")
|
||||
}
|
||||
|
||||
func TestAuth_BearerToken_Returns200(t *testing.T) {
|
||||
srv, err := NewServer(Config{
|
||||
AuthToken: "my-secret-token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Authorization", "Bearer my-secret-token")
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "KeyHunter")
|
||||
}
|
||||
|
||||
func TestAuth_NoAuthConfigured_PassesThrough(t *testing.T) {
|
||||
srv, err := NewServer(Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestOverview_ShowsStats(t *testing.T) {
|
||||
srv, err := NewServer(Config{})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Router().ServeHTTP(rec, req)
|
||||
|
||||
body := rec.Body.String()
|
||||
// Should display stat values (zeroes when no DB)
|
||||
assert.True(t, strings.Contains(body, "Total Keys Found"), "should show Total Keys stat card")
|
||||
assert.True(t, strings.Contains(body, "Providers Loaded"), "should show Providers stat card")
|
||||
assert.True(t, strings.Contains(body, "Recon Sources"), "should show Recon Sources stat card")
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSSEHubSubscribeUnsubscribe(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
|
||||
ch1 := hub.Subscribe()
|
||||
ch2 := hub.Subscribe()
|
||||
assert.Equal(t, 2, hub.ClientCount())
|
||||
|
||||
hub.Unsubscribe(ch1)
|
||||
assert.Equal(t, 1, hub.ClientCount())
|
||||
|
||||
hub.Unsubscribe(ch2)
|
||||
assert.Equal(t, 0, hub.ClientCount())
|
||||
}
|
||||
|
||||
func TestSSEHubBroadcast(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
|
||||
ch1 := hub.Subscribe()
|
||||
ch2 := hub.Subscribe()
|
||||
defer hub.Unsubscribe(ch1)
|
||||
defer hub.Unsubscribe(ch2)
|
||||
|
||||
evt := SSEEvent{Type: "scan:progress", Data: map[string]int{"percent": 50}}
|
||||
hub.Broadcast(evt)
|
||||
|
||||
// Both clients should receive the event
|
||||
select {
|
||||
case got := <-ch1:
|
||||
assert.Equal(t, "scan:progress", got.Type)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ch1 did not receive event")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-ch2:
|
||||
assert.Equal(t, "scan:progress", got.Type)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("ch2 did not receive event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEHubBroadcastDropsWhenFull(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
ch := hub.Subscribe()
|
||||
defer hub.Unsubscribe(ch)
|
||||
|
||||
// Fill the buffer (capacity 32)
|
||||
for i := 0; i < 32; i++ {
|
||||
hub.Broadcast(SSEEvent{Type: "fill", Data: i})
|
||||
}
|
||||
|
||||
// This should NOT block — it drops the event
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hub.Broadcast(SSEEvent{Type: "overflow", Data: 33})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// good, broadcast returned
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Broadcast blocked on full buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEHubClientDisconnect(t *testing.T) {
|
||||
hub := NewSSEHub()
|
||||
ch := hub.Subscribe()
|
||||
assert.Equal(t, 1, hub.ClientCount())
|
||||
|
||||
hub.Unsubscribe(ch)
|
||||
assert.Equal(t, 0, hub.ClientCount())
|
||||
|
||||
// Channel should be closed
|
||||
_, ok := <-ch
|
||||
assert.False(t, ok, "channel should be closed after unsubscribe")
|
||||
}
|
||||
|
||||
func TestSSEHTTPHandler(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
// Start the SSE request in a goroutine with a cancelable context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Run handler in background
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Give handler time to set headers and send initial event
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Broadcast an event
|
||||
s.sse.Broadcast(SSEEvent{Type: "scan:finding", Data: map[string]string{"key": "test"}})
|
||||
|
||||
// Give time for event to be written
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Cancel the context to disconnect
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handler did not return after context cancel")
|
||||
}
|
||||
|
||||
// Check response headers
|
||||
assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type"))
|
||||
assert.Equal(t, "no-cache", w.Header().Get("Cache-Control"))
|
||||
|
||||
// Parse SSE events from body
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, "event: connected")
|
||||
assert.Contains(t, body, "event: scan:finding")
|
||||
assert.Contains(t, body, "data:")
|
||||
}
|
||||
|
||||
func TestSSEEventFormat(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/recon/progress", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
s.sse.Broadcast(SSEEvent{Type: "recon:complete", Data: map[string]int{"total": 5}})
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
// Verify SSE format: "event: {type}\ndata: {json}\n\n"
|
||||
body := w.Body.String()
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
var foundEvent, foundData bool
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "event: recon:complete") {
|
||||
foundEvent = true
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") && strings.Contains(line, `"total"`) {
|
||||
foundData = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundEvent, "should have event: recon:complete line")
|
||||
assert.True(t, foundData, "should have data line with JSON")
|
||||
}
|
||||
|
||||
func TestSSEClientDisconnectRemovesSubscriber(t *testing.T) {
|
||||
s, _ := testServer(t)
|
||||
|
||||
r := chi.NewRouter()
|
||||
s.mountAPI(r)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/scan/progress", nil)
|
||||
req = req.WithContext(ctx)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.ServeHTTP(w, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, s.sse.ClientCount(), "should have 1 subscriber")
|
||||
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
// After disconnect, subscriber should be removed
|
||||
// Give a small moment for cleanup
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
assert.Equal(t, 0, s.sse.ClientCount(), "should have 0 subscribers after disconnect")
|
||||
}
|
||||
Reference in New Issue
Block a user