package sources

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/salvacybersec/keyhunter/pkg/types"
)

const (
	// MaxURLContentLength is the hard cap on URLSource response bodies.
	MaxURLContentLength int64 = 50 * 1024 * 1024 // 50 MB

	// DefaultURLTimeout is the overall request timeout (connect + read + body).
	DefaultURLTimeout = 30 * time.Second
)

// allowedContentTypes is the whitelist of Content-Type prefixes URLSource
// will accept. Binary types (images, archives, executables) are rejected.
var allowedContentTypes = []string{
	"text/",
	"application/json",
	"application/javascript",
	"application/xml",
	"application/x-yaml",
	"application/yaml",
}

// URLSource fetches a remote resource over HTTP(S) and emits its body as chunks.
type URLSource struct {
	URL       string       // target to fetch; must use the http or https scheme
	Client    *http.Client // HTTP client; when nil a default one is built per request
	UserAgent string       // value sent in the User-Agent header
	Insecure  bool         // skip TLS verification (default false); NOTE(review): not read by any code in view — confirm it is honored somewhere
	ChunkSize int          // size of the chunks emitted from the body
}

// NewURLSource creates a URLSource with sane defaults.
func NewURLSource(rawURL string) *URLSource {
	src := &URLSource{URL: rawURL}
	src.Client = defaultHTTPClient()
	src.UserAgent = "keyhunter/dev"
	src.ChunkSize = defaultChunkSize
	return src
}

// defaultHTTPClient builds the fallback client: a 30-second overall timeout
// and at most five redirects per request.
func defaultHTTPClient() *http.Client {
	limitRedirects := func(_ *http.Request, via []*http.Request) error {
		if len(via) >= 5 {
			return errors.New("stopped after 5 redirects")
		}
		return nil
	}
	return &http.Client{
		Timeout:       DefaultURLTimeout,
		CheckRedirect: limitRedirects,
	}
}

// Chunks validates the URL, issues a GET, and emits the response body as chunks.
func (u *URLSource) Chunks(ctx context.Context, out chan<- types.Chunk) error { parsed, err := url.Parse(u.URL) if err != nil { return fmt.Errorf("URLSource: parse %q: %w", u.URL, err) } if parsed.Scheme != "http" && parsed.Scheme != "https" { return fmt.Errorf("URLSource: unsupported scheme %q (only http/https)", parsed.Scheme) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.URL, nil) if err != nil { return fmt.Errorf("URLSource: new request: %w", err) } req.Header.Set("User-Agent", u.UserAgent) client := u.Client if client == nil { client = defaultHTTPClient() } resp, err := client.Do(req) if err != nil { return fmt.Errorf("URLSource: fetch: %w", err) } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("URLSource: non-2xx status %d from %s", resp.StatusCode, u.URL) } ct := resp.Header.Get("Content-Type") if !isAllowedContentType(ct) { return fmt.Errorf("URLSource: disallowed Content-Type %q", ct) } if resp.ContentLength > MaxURLContentLength { return fmt.Errorf("URLSource: Content-Length %d exceeds cap %d", resp.ContentLength, MaxURLContentLength) } // LimitReader cap + 1 to detect overflow even if ContentLength was missing/wrong. limited := io.LimitReader(resp.Body, MaxURLContentLength+1) data, err := io.ReadAll(limited) if err != nil { return fmt.Errorf("URLSource: read body: %w", err) } if int64(len(data)) > MaxURLContentLength { return fmt.Errorf("URLSource: body exceeds %d bytes", MaxURLContentLength) } if len(data) == 0 { return nil } source := "url:" + u.URL return emitByteChunks(ctx, data, source, u.ChunkSize, out) } func isAllowedContentType(ct string) bool { if ct == "" { return true // some servers omit; trust and scan } // Strip parameters like "; charset=utf-8". if idx := strings.Index(ct, ";"); idx >= 0 { ct = ct[:idx] } ct = strings.TrimSpace(strings.ToLower(ct)) for _, prefix := range allowedContentTypes { if strings.HasPrefix(ct, prefix) { return true } } return false }