- VerifyAll(ctx, findings, reg, workers) returns a result channel that is closed after all findings are processed or ctx is cancelled. - The worker count defaults to 10 when workers <= 0. - Findings whose provider is missing from the registry yield StatusUnknown with a 'provider not found in registry' error. - Context cancellation stops dispatching new findings; verifications already in flight are still drained before the channel is closed.
229 lines
6.0 KiB
Go
229 lines
6.0 KiB
Go
package verify
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/panjf2000/ants/v2"
|
|
"github.com/salvacybersec/keyhunter/pkg/engine"
|
|
"github.com/salvacybersec/keyhunter/pkg/providers"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// Defaults and limits shared by the verifier in this file.
const (
	// DefaultTimeout is the per-call verification timeout applied when the
	// caller does not configure one.
	DefaultTimeout = 10 * time.Second

	// DefaultWorkers is the worker-pool size VerifyAll falls back to.
	DefaultWorkers = 10

	// maxMetadataBody is the upper bound, in bytes, on how much of a JSON
	// response body is read for metadata extraction (1 MiB).
	maxMetadataBody = 1 << 20
)
|
|
|
|
// HTTPVerifier issues one HTTP request per finding, built from the provider's
// VerifySpec, and classifies the response. All provider-specific behaviour
// comes from that spec (YAML-driven); there are no per-provider switches here.
type HTTPVerifier struct {
	// Client is the HTTP client used to send verification requests.
	Client *http.Client

	// Timeout is the per-call deadline applied to each verification request.
	Timeout time.Duration
}
|
|
|
|
// NewHTTPVerifier returns an HTTPVerifier with a TLS 1.2+ HTTP client and the
|
|
// given per-call timeout (falling back to DefaultTimeout when timeout <= 0).
|
|
func NewHTTPVerifier(timeout time.Duration) *HTTPVerifier {
|
|
if timeout <= 0 {
|
|
timeout = DefaultTimeout
|
|
}
|
|
return &HTTPVerifier{
|
|
Client: &http.Client{
|
|
Timeout: timeout,
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
|
},
|
|
},
|
|
Timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// Verify runs a single verification against a provider's verify endpoint.
|
|
// It never returns a Go error — transport/classification failures are encoded
|
|
// in the Result. Callers classify via Result.Status against the Status* constants.
|
|
func (v *HTTPVerifier) Verify(ctx context.Context, f engine.Finding, p providers.Provider) Result {
|
|
start := time.Now()
|
|
res := Result{
|
|
ProviderName: f.ProviderName,
|
|
KeyMasked: f.KeyMasked,
|
|
Status: StatusUnknown,
|
|
}
|
|
|
|
spec := p.Verify
|
|
if spec.URL == "" {
|
|
return res // StatusUnknown: provider has no verify endpoint
|
|
}
|
|
if strings.HasPrefix(strings.ToLower(spec.URL), "http://") {
|
|
res.Status = StatusError
|
|
res.Error = "verify URL must be HTTPS"
|
|
return res
|
|
}
|
|
|
|
// Substitute {{KEY}} (and legacy {KEY}) in URL, headers, and body.
|
|
url := substituteKey(spec.URL, f.KeyValue)
|
|
|
|
method := spec.Method
|
|
if method == "" {
|
|
method = http.MethodGet
|
|
}
|
|
|
|
var bodyReader io.Reader
|
|
if spec.Body != "" {
|
|
bodyReader = bytes.NewBufferString(substituteKey(spec.Body, f.KeyValue))
|
|
}
|
|
|
|
reqCtx, cancel := context.WithTimeout(ctx, v.Timeout)
|
|
defer cancel()
|
|
|
|
req, err := http.NewRequestWithContext(reqCtx, method, url, bodyReader)
|
|
if err != nil {
|
|
res.Status = StatusError
|
|
res.Error = err.Error()
|
|
return res
|
|
}
|
|
for k, val := range spec.Headers {
|
|
req.Header.Set(k, substituteKey(val, f.KeyValue))
|
|
}
|
|
|
|
resp, err := v.Client.Do(req)
|
|
res.ResponseTime = time.Since(start)
|
|
if err != nil {
|
|
res.Status = StatusError
|
|
res.Error = err.Error()
|
|
return res
|
|
}
|
|
defer resp.Body.Close()
|
|
res.HTTPCode = resp.StatusCode
|
|
|
|
// Classify. Success codes take precedence, then failure, then rate-limit.
|
|
switch {
|
|
case containsInt(spec.EffectiveSuccessCodes(), resp.StatusCode):
|
|
res.Status = StatusLive
|
|
case containsInt(spec.EffectiveFailureCodes(), resp.StatusCode):
|
|
res.Status = StatusDead
|
|
case containsInt(spec.EffectiveRateLimitCodes(), resp.StatusCode):
|
|
res.Status = StatusRateLimited
|
|
if ra := resp.Header.Get("Retry-After"); ra != "" {
|
|
if secs, convErr := strconv.Atoi(ra); convErr == nil {
|
|
res.RetryAfter = time.Duration(secs) * time.Second
|
|
}
|
|
}
|
|
default:
|
|
res.Status = StatusUnknown
|
|
}
|
|
|
|
// Metadata extraction only for live responses with JSON body and configured paths.
|
|
if res.Status == StatusLive && len(spec.MetadataPaths) > 0 {
|
|
if strings.Contains(resp.Header.Get("Content-Type"), "application/json") {
|
|
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, maxMetadataBody))
|
|
meta := make(map[string]string, len(spec.MetadataPaths))
|
|
for displayName, path := range spec.MetadataPaths {
|
|
if r := gjson.GetBytes(bodyBytes, path); r.Exists() {
|
|
meta[displayName] = r.String()
|
|
}
|
|
}
|
|
if len(meta) > 0 {
|
|
res.Metadata = meta
|
|
}
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
// VerifyAll runs verification for every finding through an ants worker pool
|
|
// of the given size (or DefaultWorkers when workers <= 0). The returned
|
|
// channel is closed once every finding has been processed or ctx is cancelled.
|
|
//
|
|
// Findings whose provider is not present in reg are emitted as
|
|
// Result{Status: StatusUnknown, Error: "provider not found in registry"}
|
|
// rather than silently dropped.
|
|
func (v *HTTPVerifier) VerifyAll(ctx context.Context, findings []engine.Finding, reg *providers.Registry, workers int) <-chan Result {
|
|
if workers <= 0 {
|
|
workers = DefaultWorkers
|
|
}
|
|
out := make(chan Result, len(findings))
|
|
|
|
pool, err := ants.NewPool(workers)
|
|
if err != nil {
|
|
go func() {
|
|
defer close(out)
|
|
for _, f := range findings {
|
|
out <- Result{
|
|
ProviderName: f.ProviderName,
|
|
KeyMasked: f.KeyMasked,
|
|
Status: StatusError,
|
|
Error: "pool init: " + err.Error(),
|
|
}
|
|
}
|
|
}()
|
|
return out
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
go func() {
|
|
defer close(out)
|
|
defer pool.Release()
|
|
|
|
for i := range findings {
|
|
if ctx.Err() != nil {
|
|
break
|
|
}
|
|
f := findings[i]
|
|
wg.Add(1)
|
|
submitErr := pool.Submit(func() {
|
|
defer wg.Done()
|
|
prov, ok := reg.Get(f.ProviderName)
|
|
if !ok {
|
|
out <- Result{
|
|
ProviderName: f.ProviderName,
|
|
KeyMasked: f.KeyMasked,
|
|
Status: StatusUnknown,
|
|
Error: "provider not found in registry",
|
|
}
|
|
return
|
|
}
|
|
out <- v.Verify(ctx, f, prov)
|
|
})
|
|
if submitErr != nil {
|
|
wg.Done()
|
|
out <- Result{
|
|
ProviderName: f.ProviderName,
|
|
KeyMasked: f.KeyMasked,
|
|
Status: StatusError,
|
|
Error: submitErr.Error(),
|
|
}
|
|
}
|
|
}
|
|
wg.Wait()
|
|
}()
|
|
return out
|
|
}
|
|
|
|
// substituteKey returns s with every {{KEY}} placeholder, and every legacy
// {KEY} placeholder, replaced by key.
//
// The replacement is done in a single pass over s, so text inserted from key
// is never rescanned: a key that itself contains "{KEY}" is inserted verbatim
// rather than being expanded again.
func substituteKey(s, key string) string {
	return strings.NewReplacer(
		"{{KEY}}", key,
		"{KEY}", key,
	).Replace(s)
}
|
|
|
|
// containsInt reports whether needle occurs anywhere in haystack.
// A nil or empty haystack never contains needle.
func containsInt(haystack []int, needle int) bool {
	for i := 0; i < len(haystack); i++ {
		if haystack[i] == needle {
			return true
		}
	}
	return false
}
|