- StdinSource reads from an injectable io.Reader (INPUT-03) - URLSource fetches http/https with 30s timeout, 50MB cap, scheme whitelist, and Content-Type filter (INPUT-04) - ClipboardSource wraps atotto/clipboard with graceful fallback for missing tooling (INPUT-05) - emitByteChunks local helper mirrors file.go windowing to stay independent of sibling wave-1 plans - Tests cover happy path, cancellation, redirects, oversize bodies, binary content types, scheme rejection, and clipboard error paths
136 lines
3.5 KiB
Go
136 lines
3.5 KiB
Go
package sources
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/salvacybersec/keyhunter/pkg/types"
|
|
)
|
|
|
|
// MaxURLContentLength is the hard cap on URLSource response bodies.
|
|
const MaxURLContentLength int64 = 50 * 1024 * 1024 // 50 MB
|
|
|
|
// DefaultURLTimeout is the overall request timeout (connect + read + body).
|
|
const DefaultURLTimeout = 30 * time.Second
|
|
|
|
// allowedContentTypes is the whitelist of Content-Type prefixes URLSource
|
|
// will accept. Binary types (images, archives, executables) are rejected.
|
|
var allowedContentTypes = []string{
|
|
"text/",
|
|
"application/json",
|
|
"application/javascript",
|
|
"application/xml",
|
|
"application/x-yaml",
|
|
"application/yaml",
|
|
}
|
|
|
|
// URLSource fetches a remote resource over HTTP(S) and emits its body as chunks.
|
|
type URLSource struct {
|
|
URL string
|
|
Client *http.Client
|
|
UserAgent string
|
|
Insecure bool // skip TLS verification (default false)
|
|
ChunkSize int
|
|
}
|
|
|
|
// NewURLSource creates a URLSource with sane defaults.
|
|
func NewURLSource(rawURL string) *URLSource {
|
|
return &URLSource{
|
|
URL: rawURL,
|
|
Client: defaultHTTPClient(),
|
|
UserAgent: "keyhunter/dev",
|
|
ChunkSize: defaultChunkSize,
|
|
}
|
|
}
|
|
|
|
func defaultHTTPClient() *http.Client {
|
|
return &http.Client{
|
|
Timeout: DefaultURLTimeout,
|
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
if len(via) >= 5 {
|
|
return errors.New("stopped after 5 redirects")
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
// Chunks validates the URL, issues a GET, and emits the response body as chunks.
|
|
func (u *URLSource) Chunks(ctx context.Context, out chan<- types.Chunk) error {
|
|
parsed, err := url.Parse(u.URL)
|
|
if err != nil {
|
|
return fmt.Errorf("URLSource: parse %q: %w", u.URL, err)
|
|
}
|
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
|
return fmt.Errorf("URLSource: unsupported scheme %q (only http/https)", parsed.Scheme)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.URL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("URLSource: new request: %w", err)
|
|
}
|
|
req.Header.Set("User-Agent", u.UserAgent)
|
|
|
|
client := u.Client
|
|
if client == nil {
|
|
client = defaultHTTPClient()
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("URLSource: fetch: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
return fmt.Errorf("URLSource: non-2xx status %d from %s", resp.StatusCode, u.URL)
|
|
}
|
|
|
|
ct := resp.Header.Get("Content-Type")
|
|
if !isAllowedContentType(ct) {
|
|
return fmt.Errorf("URLSource: disallowed Content-Type %q", ct)
|
|
}
|
|
|
|
if resp.ContentLength > MaxURLContentLength {
|
|
return fmt.Errorf("URLSource: Content-Length %d exceeds cap %d", resp.ContentLength, MaxURLContentLength)
|
|
}
|
|
|
|
// LimitReader cap + 1 to detect overflow even if ContentLength was missing/wrong.
|
|
limited := io.LimitReader(resp.Body, MaxURLContentLength+1)
|
|
data, err := io.ReadAll(limited)
|
|
if err != nil {
|
|
return fmt.Errorf("URLSource: read body: %w", err)
|
|
}
|
|
if int64(len(data)) > MaxURLContentLength {
|
|
return fmt.Errorf("URLSource: body exceeds %d bytes", MaxURLContentLength)
|
|
}
|
|
if len(data) == 0 {
|
|
return nil
|
|
}
|
|
|
|
source := "url:" + u.URL
|
|
return emitByteChunks(ctx, data, source, u.ChunkSize, out)
|
|
}
|
|
|
|
func isAllowedContentType(ct string) bool {
|
|
if ct == "" {
|
|
return true // some servers omit; trust and scan
|
|
}
|
|
// Strip parameters like "; charset=utf-8".
|
|
if idx := strings.Index(ct, ";"); idx >= 0 {
|
|
ct = ct[:idx]
|
|
}
|
|
ct = strings.TrimSpace(strings.ToLower(ct))
|
|
for _, prefix := range allowedContentTypes {
|
|
if strings.HasPrefix(ct, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|