feat: Redesign vulnerability reporting with nested XML code locations and CVSS

Replace 12 flat parameters (code_file, code_before, code_after, code_diff,
and 8 CVSS fields) with structured nested XML fields: code_locations with
co-located fix_before/fix_after per location, cvss_breakdown, and cwe.

This enables multi-file vulnerability locations, per-location fixes with
precise line numbers, data flow representation (source/sink), CWE
classification, and compatibility with GitHub/GitLab PR review APIs.
This commit is contained in:
0xallam
2026-02-15 16:40:26 -08:00
committed by Ahmed Allam
parent 2b94633212
commit d6e9b3b7cf
6 changed files with 404 additions and 210 deletions

View File

@@ -1,8 +1,120 @@
import contextlib
import re
from pathlib import PurePosixPath
from typing import Any
from strix.tools.registry import register_tool
_CVSS_FIELDS = (
"attack_vector",
"attack_complexity",
"privileges_required",
"user_interaction",
"scope",
"confidentiality",
"integrity",
"availability",
)
def parse_cvss_xml(xml_str: str) -> dict[str, str] | None:
if not xml_str or not xml_str.strip():
return None
result = {}
for field in _CVSS_FIELDS:
match = re.search(rf"<{field}>(.*?)</{field}>", xml_str, re.DOTALL)
if match:
result[field] = match.group(1).strip()
return result if result else None
def parse_code_locations_xml(xml_str: str) -> list[dict[str, Any]] | None:
if not xml_str or not xml_str.strip():
return None
locations = []
for loc_match in re.finditer(r"<location>(.*?)</location>", xml_str, re.DOTALL):
loc: dict[str, Any] = {}
loc_content = loc_match.group(1)
for field in (
"file",
"start_line",
"end_line",
"snippet",
"label",
"fix_before",
"fix_after",
):
field_match = re.search(rf"<{field}>(.*?)</{field}>", loc_content, re.DOTALL)
if field_match:
raw = field_match.group(1)
value = (
raw.strip("\n")
if field in ("snippet", "fix_before", "fix_after")
else raw.strip()
)
if field in ("start_line", "end_line"):
with contextlib.suppress(ValueError, TypeError):
loc[field] = int(value)
elif value:
loc[field] = value
if loc.get("file") and loc.get("start_line") is not None:
locations.append(loc)
return locations if locations else None
def _validate_file_path(path: str) -> str | None:
if not path or not path.strip():
return "file path cannot be empty"
p = PurePosixPath(path)
if p.is_absolute():
return f"file path must be relative, got absolute: '{path}'"
if ".." in p.parts:
return f"file path must not contain '..': '{path}'"
return None
def _validate_code_locations(locations: list[dict[str, Any]]) -> list[str]:
    """Validate the file path and line range of every code location.

    For each entry the ``file`` value is checked with ``_validate_file_path``,
    ``start_line`` must be an integer >= 1, and ``end_line`` must be present,
    an integer >= 1, and not smaller than ``start_line``.

    Args:
        locations: Location dicts, read via the keys ``file``, ``start_line``
            and ``end_line``.

    Returns:
        Error messages, each prefixed with ``code_locations[<index>]: ``, in
        the order the problems were found. Empty when everything is valid.
    """
    problems: list[str] = []
    for index, entry in enumerate(locations):
        issues: list[str] = []

        path_issue = _validate_file_path(entry.get("file", ""))
        if path_issue:
            issues.append(path_issue)

        first = entry.get("start_line")
        if not isinstance(first, int) or first < 1:
            issues.append("start_line must be a positive integer")

        last = entry.get("end_line")
        if last is None:
            issues.append("end_line is required")
        elif not isinstance(last, int) or last < 1:
            issues.append("end_line must be a positive integer")
        elif isinstance(first, int) and last < first:
            # Only comparable when start_line itself is an integer.
            issues.append(f"end_line ({last}) must be >= start_line ({first})")

        problems.extend(f"code_locations[{index}]: {issue}" for issue in issues)
    return problems
def _extract_cve(cve: str) -> str:
    """Pull a CVE identifier out of free-form text and normalise its case.

    The first ``CVE-YYYY-NNNN...`` token found anywhere in *cve* is returned
    in upper case. Matching is case-insensitive so that input such as
    ``"cve-2021-44228"`` is recognised instead of being handed on unchanged,
    and digits are restricted to ASCII.

    Args:
        cve: Raw text that may contain a CVE identifier.

    Returns:
        The upper-cased identifier when one is found, otherwise the input
        with surrounding whitespace stripped.
    """
    match = re.search(r"CVE-\d{4}-\d{4,}", cve, re.IGNORECASE | re.ASCII)
    return match.group(0).upper() if match else cve.strip()
def _validate_cve(cve: str) -> str | None:
if not re.match(r"^CVE-\d{4}-\d{4,}$", cve):
return f"invalid CVE format: '{cve}' (expected 'CVE-YYYY-NNNNN')"
return None
def _extract_cwe(cwe: str) -> str:
    """Pull a CWE identifier out of free-form text and normalise its case.

    The first ``CWE-<digits>`` token found anywhere in *cwe* is returned in
    upper case. Matching is case-insensitive so that input such as
    ``"cwe-79"`` is recognised instead of being handed on unchanged, and
    digits are restricted to ASCII.

    Args:
        cwe: Raw text that may contain a CWE identifier.

    Returns:
        The upper-cased identifier when one is found, otherwise the input
        with surrounding whitespace stripped.
    """
    match = re.search(r"CWE-\d+", cwe, re.IGNORECASE | re.ASCII)
    return match.group(0).upper() if match else cwe.strip()
def _validate_cwe(cwe: str) -> str | None:
if not re.match(r"^CWE-\d+$", cwe):
return f"invalid CWE format: '{cwe}' (expected 'CWE-NNN')"
return None
def calculate_cvss_and_severity(
attack_vector: str,
attack_complexity: str,
@@ -87,7 +199,7 @@ def _validate_cvss_parameters(**kwargs: str) -> list[str]:
@register_tool(sandbox_execution=False)
def create_vulnerability_report(
def create_vulnerability_report( # noqa: PLR0912
title: str,
description: str,
impact: str,
@@ -96,23 +208,12 @@ def create_vulnerability_report(
poc_description: str,
poc_script_code: str,
remediation_steps: str,
# CVSS Breakdown Components
attack_vector: str,
attack_complexity: str,
privileges_required: str,
user_interaction: str,
scope: str,
confidentiality: str,
integrity: str,
availability: str,
# Optional fields
cvss_breakdown: str,
endpoint: str | None = None,
method: str | None = None,
cve: str | None = None,
code_file: str | None = None,
code_before: str | None = None,
code_after: str | None = None,
code_diff: str | None = None,
cwe: str | None = None,
code_locations: str | None = None,
) -> dict[str, Any]:
validation_errors = _validate_required_fields(
title=title,
@@ -125,32 +226,32 @@ def create_vulnerability_report(
remediation_steps=remediation_steps,
)
validation_errors.extend(
_validate_cvss_parameters(
attack_vector=attack_vector,
attack_complexity=attack_complexity,
privileges_required=privileges_required,
user_interaction=user_interaction,
scope=scope,
confidentiality=confidentiality,
integrity=integrity,
availability=availability,
)
)
parsed_cvss = parse_cvss_xml(cvss_breakdown)
if not parsed_cvss:
validation_errors.append("cvss: could not parse CVSS breakdown XML")
else:
validation_errors.extend(_validate_cvss_parameters(**parsed_cvss))
parsed_locations = parse_code_locations_xml(code_locations) if code_locations else None
if parsed_locations:
validation_errors.extend(_validate_code_locations(parsed_locations))
if cve:
cve = _extract_cve(cve)
cve_err = _validate_cve(cve)
if cve_err:
validation_errors.append(cve_err)
if cwe:
cwe = _extract_cwe(cwe)
cwe_err = _validate_cwe(cwe)
if cwe_err:
validation_errors.append(cwe_err)
if validation_errors:
return {"success": False, "message": "Validation failed", "errors": validation_errors}
cvss_score, severity, cvss_vector = calculate_cvss_and_severity(
attack_vector,
attack_complexity,
privileges_required,
user_interaction,
scope,
confidentiality,
integrity,
availability,
)
assert parsed_cvss is not None
cvss_score, severity, cvss_vector = calculate_cvss_and_severity(**parsed_cvss)
try:
from strix.telemetry.tracer import get_global_tracer
@@ -196,17 +297,6 @@ def create_vulnerability_report(
"reason": dedupe_result.get("reason", ""),
}
cvss_breakdown = {
"attack_vector": attack_vector,
"attack_complexity": attack_complexity,
"privileges_required": privileges_required,
"user_interaction": user_interaction,
"scope": scope,
"confidentiality": confidentiality,
"integrity": integrity,
"availability": availability,
}
report_id = tracer.add_vulnerability_report(
title=title,
description=description,
@@ -218,14 +308,12 @@ def create_vulnerability_report(
poc_script_code=poc_script_code,
remediation_steps=remediation_steps,
cvss=cvss_score,
cvss_breakdown=cvss_breakdown,
cvss_breakdown=parsed_cvss,
endpoint=endpoint,
method=method,
cve=cve,
code_file=code_file,
code_before=code_before,
code_after=code_after,
code_diff=code_diff,
cwe=cwe,
code_locations=parsed_locations,
)
return {