feat: Implement diff-scope functionality for pull requests and CI integration
This commit is contained in:
10
README.md
10
README.md
@@ -161,6 +161,9 @@ strix --target api.your-app.com --instruction "Focus on business logic flaws and
|
||||
|
||||
# Provide detailed instructions through file (e.g., rules of engagement, scope, exclusions)
|
||||
strix --target api.your-app.com --instruction-file ./instruction.md
|
||||
|
||||
# Force PR diff-scope against a specific base branch
|
||||
strix -n --target ./ --scan-mode quick --scope-mode diff --diff-base origin/main
|
||||
```
|
||||
|
||||
### Headless Mode
|
||||
@@ -186,6 +189,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install Strix
|
||||
run: curl -sSL https://strix.ai/install | bash
|
||||
@@ -198,6 +203,11 @@ jobs:
|
||||
run: strix -n -t ./ --scan-mode quick
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> In CI pull request runs, Strix automatically scopes quick reviews to changed files.
|
||||
> If diff-scope cannot resolve, ensure checkout uses full history (`fetch-depth: 0`) or pass
|
||||
> `--diff-base` explicitly.
|
||||
|
||||
### Configuration
|
||||
|
||||
```bash
|
||||
|
||||
@@ -13,6 +13,12 @@ Use the `-n` or `--non-interactive` flag:
|
||||
strix -n --target ./app --scan-mode quick
|
||||
```
|
||||
|
||||
For pull-request style CI runs, Strix automatically scopes quick scans to changed files. You can force this behavior and set a base ref explicitly:
|
||||
|
||||
```bash
|
||||
strix -n --target ./app --scan-mode quick --scope-mode diff --diff-base origin/main
|
||||
```
|
||||
|
||||
## Exit Codes
|
||||
|
||||
| Code | Meaning |
|
||||
@@ -78,3 +84,7 @@ jobs:
|
||||
<Note>
|
||||
All CI platforms require Docker access. Ensure your runner has Docker available.
|
||||
</Note>
|
||||
|
||||
<Tip>
|
||||
If diff-scope fails in CI, fetch full git history (for example, `fetch-depth: 0` in GitHub Actions) so merge-base and branch comparison can be resolved.
|
||||
</Tip>
|
||||
|
||||
@@ -18,6 +18,8 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install Strix
|
||||
run: curl -sSL https://strix.ai/install | bash
|
||||
@@ -58,3 +60,7 @@ The workflow fails when vulnerabilities are found:
|
||||
<Tip>
|
||||
Use `quick` mode for PRs to keep feedback fast. Schedule `deep` scans nightly.
|
||||
</Tip>
|
||||
|
||||
<Note>
|
||||
For pull_request workflows, Strix automatically uses changed-files diff-scope in CI/headless runs. If diff resolution fails, ensure full history is fetched (`fetch-depth: 0`) or set `--diff-base`.
|
||||
</Note>
|
||||
|
||||
@@ -27,6 +27,14 @@ strix --target <target> [options]
|
||||
Scan depth: `quick`, `standard`, or `deep`.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="--scope-mode" type="string" default="auto">
|
||||
Code scope mode: `auto` (enable PR diff-scope in CI/headless runs), `diff` (force changed-files scope), or `full` (disable diff-scope).
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="--diff-base" type="string">
|
||||
Target branch or commit to compare against (e.g., `origin/main`). Defaults to the repository's default branch.
|
||||
</ParamField>
|
||||
|
||||
<ParamField path="--non-interactive, -n" type="boolean">
|
||||
Run in headless mode without TUI. Ideal for CI/CD.
|
||||
</ParamField>
|
||||
@@ -50,6 +58,9 @@ strix --target api.example.com --instruction "Focus on IDOR and auth bypass"
|
||||
# CI/CD mode
|
||||
strix -n --target ./ --scan-mode quick
|
||||
|
||||
# Force diff-scope against a specific base ref
|
||||
strix -n --target ./ --scan-mode quick --scope-mode diff --diff-base origin/main
|
||||
|
||||
# Multi-target white-box testing
|
||||
strix -t https://github.com/org/app -t https://staging.example.com
|
||||
```
|
||||
|
||||
@@ -21,6 +21,7 @@ class StrixAgent(BaseAgent):
|
||||
async def execute_scan(self, scan_config: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0912
|
||||
user_instructions = scan_config.get("user_instructions", "")
|
||||
targets = scan_config.get("targets", [])
|
||||
diff_scope = scan_config.get("diff_scope", {}) or {}
|
||||
|
||||
repositories = []
|
||||
local_code = []
|
||||
@@ -81,6 +82,28 @@ class StrixAgent(BaseAgent):
|
||||
task_parts.append("\n\nIP Addresses:")
|
||||
task_parts.extend(f"- {ip}" for ip in ip_addresses)
|
||||
|
||||
if diff_scope.get("active"):
|
||||
task_parts.append("\n\nScope Constraints:")
|
||||
task_parts.append(
|
||||
"- Pull request diff-scope mode is active. Prioritize changed files "
|
||||
"and use other files only for context."
|
||||
)
|
||||
for repo_scope in diff_scope.get("repos", []):
|
||||
repo_label = (
|
||||
repo_scope.get("workspace_subdir")
|
||||
or repo_scope.get("source_path")
|
||||
or "repository"
|
||||
)
|
||||
changed_count = repo_scope.get("analyzable_files_count", 0)
|
||||
deleted_count = repo_scope.get("deleted_files_count", 0)
|
||||
task_parts.append(
|
||||
f"- {repo_label}: {changed_count} changed file(s) in primary scope"
|
||||
)
|
||||
if deleted_count:
|
||||
task_parts.append(
|
||||
f"- {repo_label}: {deleted_count} deleted file(s) are context-only"
|
||||
)
|
||||
|
||||
task_description = " ".join(task_parts)
|
||||
|
||||
if user_instructions:
|
||||
|
||||
@@ -72,6 +72,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915
|
||||
"targets": args.targets_info,
|
||||
"user_instructions": args.instruction or "",
|
||||
"run_name": args.run_name,
|
||||
"diff_scope": getattr(args, "diff_scope", {"active": False}),
|
||||
}
|
||||
|
||||
llm_config = LLMConfig(scan_mode=scan_mode)
|
||||
|
||||
@@ -34,6 +34,7 @@ from strix.interface.utils import ( # noqa: E402
|
||||
image_exists,
|
||||
infer_target_type,
|
||||
process_pull_line,
|
||||
resolve_diff_scope_context,
|
||||
rewrite_localhost_targets,
|
||||
validate_config_file,
|
||||
validate_llm_response,
|
||||
@@ -357,6 +358,28 @@ Examples:
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scope-mode",
|
||||
type=str,
|
||||
choices=["auto", "diff", "full"],
|
||||
default="auto",
|
||||
help=(
|
||||
"Scope mode for code targets: "
|
||||
"'auto' enables PR diff-scope in CI/headless runs, "
|
||||
"'diff' forces changed-files scope, "
|
||||
"'full' disables diff-scope."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diff-base",
|
||||
type=str,
|
||||
help=(
|
||||
"Target branch or commit to compare against (e.g., origin/main). "
|
||||
"Defaults to the repository's default branch."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
@@ -517,7 +540,7 @@ def persist_config() -> None:
|
||||
save_current_config()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
def main() -> None: # noqa: PLR0912, PLR0915
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
@@ -544,6 +567,38 @@ def main() -> None:
|
||||
target_info["details"]["cloned_repo_path"] = cloned_path
|
||||
|
||||
args.local_sources = collect_local_sources(args.targets_info)
|
||||
try:
|
||||
diff_scope = resolve_diff_scope_context(
|
||||
local_sources=args.local_sources,
|
||||
scope_mode=args.scope_mode,
|
||||
diff_base=args.diff_base,
|
||||
non_interactive=args.non_interactive,
|
||||
)
|
||||
except ValueError as e:
|
||||
console = Console()
|
||||
error_text = Text()
|
||||
error_text.append("DIFF SCOPE RESOLUTION FAILED", style="bold red")
|
||||
error_text.append("\n\n", style="white")
|
||||
error_text.append(str(e), style="white")
|
||||
|
||||
panel = Panel(
|
||||
error_text,
|
||||
title="[bold white]STRIX",
|
||||
title_align="left",
|
||||
border_style="red",
|
||||
padding=(1, 2),
|
||||
)
|
||||
console.print("\n")
|
||||
console.print(panel)
|
||||
console.print()
|
||||
sys.exit(1)
|
||||
|
||||
args.diff_scope = diff_scope.metadata
|
||||
if diff_scope.instruction_block:
|
||||
if args.instruction:
|
||||
args.instruction = f"{diff_scope.instruction_block}\n\n{args.instruction}"
|
||||
else:
|
||||
args.instruction = diff_scope.instruction_block
|
||||
|
||||
is_whitebox = bool(args.local_sources)
|
||||
|
||||
|
||||
@@ -743,6 +743,7 @@ class StrixTUIApp(App): # type: ignore[misc]
|
||||
"targets": args.targets_info,
|
||||
"user_instructions": args.instruction or "",
|
||||
"run_name": args.run_name,
|
||||
"diff_scope": getattr(args, "diff_scope", {"active": False}),
|
||||
}
|
||||
|
||||
def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.error import HTTPError, URLError
|
||||
@@ -455,6 +457,595 @@ def generate_run_name(targets_info: list[dict[str, Any]] | None = None) -> str:
|
||||
|
||||
# Target processing utilities
|
||||
|
||||
_SUPPORTED_SCOPE_MODES = {"auto", "diff", "full"}
|
||||
_MAX_FILES_PER_SECTION = 120
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffEntry:
|
||||
status: str
|
||||
path: str
|
||||
old_path: str | None = None
|
||||
similarity: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepoDiffScope:
|
||||
source_path: str
|
||||
workspace_subdir: str | None
|
||||
base_ref: str
|
||||
merge_base: str
|
||||
added_files: list[str]
|
||||
modified_files: list[str]
|
||||
renamed_files: list[dict[str, Any]]
|
||||
deleted_files: list[str]
|
||||
analyzable_files: list[str]
|
||||
truncated_sections: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def to_metadata(self) -> dict[str, Any]:
|
||||
return {
|
||||
"source_path": self.source_path,
|
||||
"workspace_subdir": self.workspace_subdir,
|
||||
"base_ref": self.base_ref,
|
||||
"merge_base": self.merge_base,
|
||||
"added_files": self.added_files,
|
||||
"modified_files": self.modified_files,
|
||||
"renamed_files": self.renamed_files,
|
||||
"deleted_files": self.deleted_files,
|
||||
"analyzable_files": self.analyzable_files,
|
||||
"added_files_count": len(self.added_files),
|
||||
"modified_files_count": len(self.modified_files),
|
||||
"renamed_files_count": len(self.renamed_files),
|
||||
"deleted_files_count": len(self.deleted_files),
|
||||
"analyzable_files_count": len(self.analyzable_files),
|
||||
"truncated_sections": self.truncated_sections,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffScopeResult:
|
||||
active: bool
|
||||
mode: str
|
||||
instruction_block: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _run_git_command(
|
||||
repo_path: Path, args: list[str], check: bool = True
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
return subprocess.run( # noqa: S603
|
||||
["git", "-C", str(repo_path), *args], # noqa: S607
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=check,
|
||||
)
|
||||
|
||||
|
||||
def _run_git_command_raw(
|
||||
repo_path: Path, args: list[str], check: bool = True
|
||||
) -> subprocess.CompletedProcess[bytes]:
|
||||
return subprocess.run( # noqa: S603
|
||||
["git", "-C", str(repo_path), *args], # noqa: S607
|
||||
capture_output=True,
|
||||
check=check,
|
||||
)
|
||||
|
||||
|
||||
def _is_ci_environment(env: dict[str, str]) -> bool:
|
||||
return any(
|
||||
env.get(key)
|
||||
for key in (
|
||||
"CI",
|
||||
"GITHUB_ACTIONS",
|
||||
"GITLAB_CI",
|
||||
"JENKINS_URL",
|
||||
"BUILDKITE",
|
||||
"CIRCLECI",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_pr_environment(env: dict[str, str]) -> bool:
|
||||
return any(
|
||||
env.get(key)
|
||||
for key in (
|
||||
"GITHUB_BASE_REF",
|
||||
"GITHUB_HEAD_REF",
|
||||
"CI_MERGE_REQUEST_TARGET_BRANCH_NAME",
|
||||
"GITLAB_MERGE_REQUEST_TARGET_BRANCH_NAME",
|
||||
"SYSTEM_PULLREQUEST_TARGETBRANCH",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_git_repo(repo_path: Path) -> bool:
|
||||
result = _run_git_command(repo_path, ["rev-parse", "--is-inside-work-tree"], check=False)
|
||||
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
||||
|
||||
|
||||
def _is_repo_shallow(repo_path: Path) -> bool:
|
||||
result = _run_git_command(repo_path, ["rev-parse", "--is-shallow-repository"], check=False)
|
||||
if result.returncode == 0:
|
||||
value = result.stdout.strip().lower()
|
||||
if value in {"true", "false"}:
|
||||
return value == "true"
|
||||
|
||||
git_meta = repo_path / ".git"
|
||||
if git_meta.is_dir():
|
||||
return (git_meta / "shallow").exists()
|
||||
if git_meta.is_file():
|
||||
try:
|
||||
content = git_meta.read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
return False
|
||||
if content.startswith("gitdir:"):
|
||||
git_dir = content.split(":", 1)[1].strip()
|
||||
resolved = (repo_path / git_dir).resolve()
|
||||
return (resolved / "shallow").exists()
|
||||
return False
|
||||
|
||||
|
||||
def _git_ref_exists(repo_path: Path, ref: str) -> bool:
|
||||
result = _run_git_command(repo_path, ["rev-parse", "--verify", "--quiet", ref], check=False)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def _resolve_origin_head_ref(repo_path: Path) -> str | None:
|
||||
result = _run_git_command(
|
||||
repo_path, ["symbolic-ref", "--quiet", "refs/remotes/origin/HEAD"], check=False
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
ref = result.stdout.strip()
|
||||
return ref or None
|
||||
|
||||
|
||||
def _extract_branch_name(ref: str | None) -> str | None:
|
||||
if not ref:
|
||||
return None
|
||||
value = ref.strip()
|
||||
if not value:
|
||||
return None
|
||||
return value.split("/")[-1]
|
||||
|
||||
|
||||
def _extract_github_base_sha(env: dict[str, str]) -> str | None:
|
||||
event_path = env.get("GITHUB_EVENT_PATH", "").strip()
|
||||
if not event_path:
|
||||
return None
|
||||
|
||||
path = Path(event_path)
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
return None
|
||||
|
||||
base_sha = payload.get("pull_request", {}).get("base", {}).get("sha")
|
||||
if isinstance(base_sha, str) and base_sha.strip():
|
||||
return base_sha.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_default_branch_name(repo_path: Path, env: dict[str, str]) -> str | None:
|
||||
github_base_ref = env.get("GITHUB_BASE_REF", "").strip()
|
||||
if github_base_ref:
|
||||
return github_base_ref
|
||||
|
||||
origin_head = _resolve_origin_head_ref(repo_path)
|
||||
if origin_head:
|
||||
branch = _extract_branch_name(origin_head)
|
||||
if branch:
|
||||
return branch
|
||||
|
||||
if _git_ref_exists(repo_path, "refs/remotes/origin/main"):
|
||||
return "main"
|
||||
if _git_ref_exists(repo_path, "refs/remotes/origin/master"):
|
||||
return "master"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_base_ref(repo_path: Path, diff_base: str | None, env: dict[str, str]) -> str:
|
||||
if diff_base and diff_base.strip():
|
||||
return diff_base.strip()
|
||||
|
||||
github_base_ref = env.get("GITHUB_BASE_REF", "").strip()
|
||||
if github_base_ref:
|
||||
github_candidate = f"refs/remotes/origin/{github_base_ref}"
|
||||
if _git_ref_exists(repo_path, github_candidate):
|
||||
return github_candidate
|
||||
|
||||
github_base_sha = _extract_github_base_sha(env)
|
||||
if github_base_sha and _git_ref_exists(repo_path, github_base_sha):
|
||||
return github_base_sha
|
||||
|
||||
origin_head = _resolve_origin_head_ref(repo_path)
|
||||
if origin_head and _git_ref_exists(repo_path, origin_head):
|
||||
return origin_head
|
||||
|
||||
if _git_ref_exists(repo_path, "refs/remotes/origin/main"):
|
||||
return "refs/remotes/origin/main"
|
||||
|
||||
if _git_ref_exists(repo_path, "refs/remotes/origin/master"):
|
||||
return "refs/remotes/origin/master"
|
||||
|
||||
raise ValueError(
|
||||
"Unable to resolve a base ref for diff-scope. Pass --diff-base explicitly "
|
||||
"(for example: --diff-base origin/main)."
|
||||
)
|
||||
|
||||
|
||||
def _get_current_branch_name(repo_path: Path) -> str | None:
|
||||
result = _run_git_command(repo_path, ["rev-parse", "--abbrev-ref", "HEAD"], check=False)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
branch_name = result.stdout.strip()
|
||||
if not branch_name or branch_name == "HEAD":
|
||||
return None
|
||||
return branch_name
|
||||
|
||||
|
||||
def _parse_name_status_z(raw_output: bytes) -> list[DiffEntry]:
|
||||
if not raw_output:
|
||||
return []
|
||||
|
||||
tokens = [
|
||||
token.decode("utf-8", errors="replace") for token in raw_output.split(b"\x00") if token
|
||||
]
|
||||
entries: list[DiffEntry] = []
|
||||
index = 0
|
||||
|
||||
while index < len(tokens):
|
||||
token = tokens[index]
|
||||
status_raw = token
|
||||
status_code = status_raw[:1]
|
||||
similarity: int | None = None
|
||||
if len(status_raw) > 1 and status_raw[1:].isdigit():
|
||||
similarity = int(status_raw[1:])
|
||||
|
||||
# Git's -z output for --name-status is:
|
||||
# - non-rename/copy: <status>\0<path>\0
|
||||
# - rename/copy: <statusN>\0<old_path>\0<new_path>\0
|
||||
if status_code in {"R", "C"} and index + 2 < len(tokens):
|
||||
old_path = tokens[index + 1]
|
||||
new_path = tokens[index + 2]
|
||||
entries.append(
|
||||
DiffEntry(
|
||||
status=status_code,
|
||||
path=new_path,
|
||||
old_path=old_path,
|
||||
similarity=similarity,
|
||||
)
|
||||
)
|
||||
index += 3
|
||||
continue
|
||||
|
||||
if index + 1 < len(tokens):
|
||||
path = tokens[index + 1]
|
||||
entries.append(DiffEntry(status=status_code, path=path, similarity=similarity))
|
||||
index += 2
|
||||
continue
|
||||
|
||||
# Backward-compat fallback if output is tab-delimited unexpectedly.
|
||||
status_fallback, has_tab, first_path = token.partition("\t")
|
||||
if not has_tab:
|
||||
break
|
||||
fallback_code = status_fallback[:1]
|
||||
fallback_similarity: int | None = None
|
||||
if len(status_fallback) > 1 and status_fallback[1:].isdigit():
|
||||
fallback_similarity = int(status_fallback[1:])
|
||||
entries.append(
|
||||
DiffEntry(status=fallback_code, path=first_path, similarity=fallback_similarity)
|
||||
)
|
||||
index += 1
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def _append_unique(container: list[str], seen: set[str], path: str) -> None:
|
||||
if path and path not in seen:
|
||||
seen.add(path)
|
||||
container.append(path)
|
||||
|
||||
|
||||
def _classify_diff_entries(entries: list[DiffEntry]) -> dict[str, Any]:
|
||||
added_files: list[str] = []
|
||||
modified_files: list[str] = []
|
||||
deleted_files: list[str] = []
|
||||
renamed_files: list[dict[str, Any]] = []
|
||||
analyzable_files: list[str] = []
|
||||
analyzable_seen: set[str] = set()
|
||||
modified_seen: set[str] = set()
|
||||
|
||||
for entry in entries:
|
||||
path = entry.path
|
||||
if not path:
|
||||
continue
|
||||
|
||||
if entry.status == "D":
|
||||
deleted_files.append(path)
|
||||
continue
|
||||
|
||||
if entry.status == "A":
|
||||
added_files.append(path)
|
||||
_append_unique(analyzable_files, analyzable_seen, path)
|
||||
continue
|
||||
|
||||
if entry.status == "M":
|
||||
_append_unique(modified_files, modified_seen, path)
|
||||
_append_unique(analyzable_files, analyzable_seen, path)
|
||||
continue
|
||||
|
||||
if entry.status == "R":
|
||||
renamed_files.append(
|
||||
{
|
||||
"old_path": entry.old_path,
|
||||
"new_path": path,
|
||||
"similarity": entry.similarity,
|
||||
}
|
||||
)
|
||||
_append_unique(analyzable_files, analyzable_seen, path)
|
||||
if entry.similarity is None or entry.similarity < 100:
|
||||
_append_unique(modified_files, modified_seen, path)
|
||||
continue
|
||||
|
||||
if entry.status == "C":
|
||||
_append_unique(modified_files, modified_seen, path)
|
||||
_append_unique(analyzable_files, analyzable_seen, path)
|
||||
continue
|
||||
|
||||
_append_unique(modified_files, modified_seen, path)
|
||||
_append_unique(analyzable_files, analyzable_seen, path)
|
||||
|
||||
return {
|
||||
"added_files": added_files,
|
||||
"modified_files": modified_files,
|
||||
"deleted_files": deleted_files,
|
||||
"renamed_files": renamed_files,
|
||||
"analyzable_files": analyzable_files,
|
||||
}
|
||||
|
||||
|
||||
def _truncate_file_list(
|
||||
files: list[str], max_files: int = _MAX_FILES_PER_SECTION
|
||||
) -> tuple[list[str], bool]:
|
||||
if len(files) <= max_files:
|
||||
return files, False
|
||||
return files[:max_files], True
|
||||
|
||||
|
||||
def build_diff_scope_instruction(scopes: list[RepoDiffScope]) -> str: # noqa: PLR0912
|
||||
lines = [
|
||||
"The user is requesting a review of a Pull Request.",
|
||||
"Instruction: Direct your analysis primarily at the changes in the listed files. "
|
||||
"You may reference other files in the repository for context (imports, definitions, "
|
||||
"usage), but report findings only if they relate to the listed changes.",
|
||||
"For Added files, review the entire file content.",
|
||||
"For Modified files, focus primarily on the changed areas.",
|
||||
]
|
||||
|
||||
for scope in scopes:
|
||||
repo_name = scope.workspace_subdir or Path(scope.source_path).name or "repository"
|
||||
lines.append("")
|
||||
lines.append(f"Repository Scope: {repo_name}")
|
||||
lines.append(f"Base reference: {scope.base_ref}")
|
||||
lines.append(f"Merge base: {scope.merge_base}")
|
||||
|
||||
focus_files, focus_truncated = _truncate_file_list(scope.analyzable_files)
|
||||
scope.truncated_sections["analyzable_files"] = focus_truncated
|
||||
if focus_files:
|
||||
lines.append("Primary Focus (changed files to analyze):")
|
||||
lines.extend(f"- {path}" for path in focus_files)
|
||||
if focus_truncated:
|
||||
lines.append(f"- ... ({len(scope.analyzable_files) - len(focus_files)} more files)")
|
||||
else:
|
||||
lines.append("Primary Focus: No analyzable changed files detected.")
|
||||
|
||||
added_files, added_truncated = _truncate_file_list(scope.added_files)
|
||||
scope.truncated_sections["added_files"] = added_truncated
|
||||
if added_files:
|
||||
lines.append("Added files (review entire file):")
|
||||
lines.extend(f"- {path}" for path in added_files)
|
||||
if added_truncated:
|
||||
lines.append(f"- ... ({len(scope.added_files) - len(added_files)} more files)")
|
||||
|
||||
modified_files, modified_truncated = _truncate_file_list(scope.modified_files)
|
||||
scope.truncated_sections["modified_files"] = modified_truncated
|
||||
if modified_files:
|
||||
lines.append("Modified files (focus on changes):")
|
||||
lines.extend(f"- {path}" for path in modified_files)
|
||||
if modified_truncated:
|
||||
lines.append(
|
||||
f"- ... ({len(scope.modified_files) - len(modified_files)} more files)"
|
||||
)
|
||||
|
||||
if scope.renamed_files:
|
||||
rename_lines = []
|
||||
for rename in scope.renamed_files:
|
||||
old_path = rename.get("old_path") or "unknown"
|
||||
new_path = rename.get("new_path") or "unknown"
|
||||
similarity = rename.get("similarity")
|
||||
if isinstance(similarity, int):
|
||||
rename_lines.append(f"- {old_path} -> {new_path} (similarity {similarity}%)")
|
||||
else:
|
||||
rename_lines.append(f"- {old_path} -> {new_path}")
|
||||
lines.append("Renamed files:")
|
||||
lines.extend(rename_lines)
|
||||
|
||||
deleted_files, deleted_truncated = _truncate_file_list(scope.deleted_files)
|
||||
scope.truncated_sections["deleted_files"] = deleted_truncated
|
||||
if deleted_files:
|
||||
lines.append("Note: These files were deleted (context only, not analyzable):")
|
||||
lines.extend(f"- {path}" for path in deleted_files)
|
||||
if deleted_truncated:
|
||||
lines.append(f"- ... ({len(scope.deleted_files) - len(deleted_files)} more files)")
|
||||
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def _should_activate_auto_scope(
|
||||
local_sources: list[dict[str, str]], non_interactive: bool, env: dict[str, str]
|
||||
) -> bool:
|
||||
if not local_sources:
|
||||
return False
|
||||
if not non_interactive:
|
||||
return False
|
||||
if not _is_ci_environment(env):
|
||||
return False
|
||||
if _is_pr_environment(env):
|
||||
return True
|
||||
|
||||
for source in local_sources:
|
||||
source_path = source.get("source_path")
|
||||
if not source_path:
|
||||
continue
|
||||
repo_path = Path(source_path)
|
||||
if not _is_git_repo(repo_path):
|
||||
continue
|
||||
current_branch = _get_current_branch_name(repo_path)
|
||||
default_branch = _resolve_default_branch_name(repo_path, env)
|
||||
if current_branch and default_branch and current_branch != default_branch:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_repo_diff_scope(
|
||||
source: dict[str, str], diff_base: str | None, env: dict[str, str]
|
||||
) -> RepoDiffScope:
|
||||
source_path = source.get("source_path", "")
|
||||
workspace_subdir = source.get("workspace_subdir")
|
||||
repo_path = Path(source_path)
|
||||
|
||||
if not _is_git_repo(repo_path):
|
||||
raise ValueError(f"Source is not a git repository: {source_path}")
|
||||
|
||||
if _is_repo_shallow(repo_path):
|
||||
raise ValueError(
|
||||
"Strix requires full git history for diff-scope. Please set fetch-depth: 0 "
|
||||
"in your CI config."
|
||||
)
|
||||
|
||||
base_ref = _resolve_base_ref(repo_path, diff_base, env)
|
||||
merge_base_result = _run_git_command(repo_path, ["merge-base", base_ref, "HEAD"], check=False)
|
||||
if merge_base_result.returncode != 0:
|
||||
stderr = merge_base_result.stderr.strip()
|
||||
raise ValueError(
|
||||
f"Unable to compute merge-base against '{base_ref}' for '{source_path}'. "
|
||||
f"{stderr or 'Ensure the base branch history is fetched and reachable.'}"
|
||||
)
|
||||
|
||||
merge_base = merge_base_result.stdout.strip()
|
||||
if not merge_base:
|
||||
raise ValueError(
|
||||
f"Unable to compute merge-base against '{base_ref}' for '{source_path}'. "
|
||||
"Ensure the base branch history is fetched and reachable."
|
||||
)
|
||||
|
||||
diff_result = _run_git_command_raw(
|
||||
repo_path,
|
||||
[
|
||||
"diff",
|
||||
"--name-status",
|
||||
"-z",
|
||||
"--find-renames",
|
||||
"--find-copies",
|
||||
f"{merge_base}...HEAD",
|
||||
],
|
||||
check=False,
|
||||
)
|
||||
if diff_result.returncode != 0:
|
||||
stderr = diff_result.stderr.decode("utf-8", errors="replace").strip()
|
||||
raise ValueError(
|
||||
f"Unable to resolve changed files for '{source_path}'. "
|
||||
f"{stderr or 'Ensure the repository has enough history for diff-scope.'}"
|
||||
)
|
||||
|
||||
entries = _parse_name_status_z(diff_result.stdout)
|
||||
classified = _classify_diff_entries(entries)
|
||||
|
||||
return RepoDiffScope(
|
||||
source_path=source_path,
|
||||
workspace_subdir=workspace_subdir,
|
||||
base_ref=base_ref,
|
||||
merge_base=merge_base,
|
||||
added_files=classified["added_files"],
|
||||
modified_files=classified["modified_files"],
|
||||
renamed_files=classified["renamed_files"],
|
||||
deleted_files=classified["deleted_files"],
|
||||
analyzable_files=classified["analyzable_files"],
|
||||
)
|
||||
|
||||
|
||||
def resolve_diff_scope_context(
|
||||
local_sources: list[dict[str, str]],
|
||||
scope_mode: str,
|
||||
diff_base: str | None,
|
||||
non_interactive: bool,
|
||||
env: dict[str, str] | None = None,
|
||||
) -> DiffScopeResult:
|
||||
if scope_mode not in _SUPPORTED_SCOPE_MODES:
|
||||
raise ValueError(f"Unsupported scope mode: {scope_mode}")
|
||||
|
||||
env_map = dict(os.environ if env is None else env)
|
||||
|
||||
if scope_mode == "full":
|
||||
return DiffScopeResult(
|
||||
active=False,
|
||||
mode=scope_mode,
|
||||
metadata={"active": False, "mode": scope_mode},
|
||||
)
|
||||
|
||||
if scope_mode == "auto":
|
||||
should_activate = _should_activate_auto_scope(local_sources, non_interactive, env_map)
|
||||
if not should_activate:
|
||||
return DiffScopeResult(
|
||||
active=False,
|
||||
mode=scope_mode,
|
||||
metadata={"active": False, "mode": scope_mode},
|
||||
)
|
||||
|
||||
if not local_sources:
|
||||
raise ValueError("Diff-scope is active, but no local repository targets were provided.")
|
||||
|
||||
repo_scopes: list[RepoDiffScope] = []
|
||||
skipped_non_git: list[str] = []
|
||||
for source in local_sources:
|
||||
source_path = source.get("source_path")
|
||||
if not source_path:
|
||||
continue
|
||||
if not _is_git_repo(Path(source_path)):
|
||||
skipped_non_git.append(source_path)
|
||||
continue
|
||||
repo_scopes.append(_resolve_repo_diff_scope(source, diff_base, env_map))
|
||||
|
||||
if not repo_scopes:
|
||||
raise ValueError(
|
||||
"Diff-scope is active, but no Git repositories were found. "
|
||||
"Use --scope-mode full to disable diff-scope for this run."
|
||||
)
|
||||
|
||||
instruction_block = build_diff_scope_instruction(repo_scopes)
|
||||
metadata: dict[str, Any] = {
|
||||
"active": True,
|
||||
"mode": scope_mode,
|
||||
"repos": [scope.to_metadata() for scope in repo_scopes],
|
||||
"total_repositories": len(repo_scopes),
|
||||
"total_analyzable_files": sum(len(scope.analyzable_files) for scope in repo_scopes),
|
||||
"total_deleted_files": sum(len(scope.deleted_files) for scope in repo_scopes),
|
||||
}
|
||||
if skipped_non_git:
|
||||
metadata["skipped_non_git_sources"] = skipped_non_git
|
||||
|
||||
return DiffScopeResult(
|
||||
active=True,
|
||||
mode=scope_mode,
|
||||
instruction_block=instruction_block,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def _is_http_git_repo(url: str) -> bool:
|
||||
check_url = f"{url.rstrip('/')}/info/refs?service=git-upload-pack"
|
||||
|
||||
98
tests/interface/test_diff_scope.py
Normal file
98
tests/interface/test_diff_scope.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_utils_module():
|
||||
module_path = Path(__file__).resolve().parents[2] / "strix" / "interface" / "utils.py"
|
||||
spec = importlib.util.spec_from_file_location("strix_interface_utils_test", module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise RuntimeError("Failed to load strix.interface.utils for tests")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
utils = _load_utils_module()
|
||||
|
||||
|
||||
def test_parse_name_status_uses_rename_destination_path() -> None:
|
||||
raw = (
|
||||
b"R100\x00old/path.py\x00new/path.py\x00"
|
||||
b"R75\x00legacy/module.py\x00modern/module.py\x00"
|
||||
b"M\x00src/app.py\x00"
|
||||
b"A\x00src/new_file.py\x00"
|
||||
b"D\x00src/deleted.py\x00"
|
||||
)
|
||||
|
||||
entries = utils._parse_name_status_z(raw)
|
||||
classified = utils._classify_diff_entries(entries)
|
||||
|
||||
assert "new/path.py" in classified["analyzable_files"]
|
||||
assert "old/path.py" not in classified["analyzable_files"]
|
||||
assert "modern/module.py" in classified["analyzable_files"]
|
||||
assert classified["renamed_files"][0]["old_path"] == "old/path.py"
|
||||
assert classified["renamed_files"][0]["new_path"] == "new/path.py"
|
||||
assert "src/deleted.py" in classified["deleted_files"]
|
||||
assert "src/deleted.py" not in classified["analyzable_files"]
|
||||
|
||||
|
||||
def test_build_diff_scope_instruction_includes_added_modified_and_deleted_guidance() -> None:
|
||||
scope = utils.RepoDiffScope(
|
||||
source_path="/tmp/repo",
|
||||
workspace_subdir="repo",
|
||||
base_ref="refs/remotes/origin/main",
|
||||
merge_base="abc123",
|
||||
added_files=["src/added.py"],
|
||||
modified_files=["src/changed.py"],
|
||||
renamed_files=[{"old_path": "src/old.py", "new_path": "src/new.py", "similarity": 90}],
|
||||
deleted_files=["src/deleted.py"],
|
||||
analyzable_files=["src/added.py", "src/changed.py", "src/new.py"],
|
||||
)
|
||||
|
||||
instruction = utils.build_diff_scope_instruction([scope])
|
||||
|
||||
assert "For Added files, review the entire file content." in instruction
|
||||
assert "For Modified files, focus primarily on the changed areas." in instruction
|
||||
assert "Note: These files were deleted" in instruction
|
||||
assert "src/deleted.py" in instruction
|
||||
assert "src/old.py -> src/new.py" in instruction
|
||||
|
||||
|
||||
def test_resolve_base_ref_prefers_github_base_ref(monkeypatch) -> None:
|
||||
calls: list[str] = []
|
||||
|
||||
def fake_ref_exists(_repo_path: Path, ref: str) -> bool:
|
||||
calls.append(ref)
|
||||
return ref == "refs/remotes/origin/release-2026"
|
||||
|
||||
monkeypatch.setattr(utils, "_git_ref_exists", fake_ref_exists)
|
||||
monkeypatch.setattr(utils, "_extract_github_base_sha", lambda _env: None)
|
||||
monkeypatch.setattr(utils, "_resolve_origin_head_ref", lambda _repo_path: None)
|
||||
|
||||
base_ref = utils._resolve_base_ref(
|
||||
Path("/tmp/repo"),
|
||||
diff_base=None,
|
||||
env={"GITHUB_BASE_REF": "release-2026"},
|
||||
)
|
||||
|
||||
assert base_ref == "refs/remotes/origin/release-2026"
|
||||
assert calls[0] == "refs/remotes/origin/release-2026"
|
||||
|
||||
|
||||
def test_resolve_base_ref_falls_back_to_remote_main(monkeypatch) -> None:
|
||||
calls: list[str] = []
|
||||
|
||||
def fake_ref_exists(_repo_path: Path, ref: str) -> bool:
|
||||
calls.append(ref)
|
||||
return ref == "refs/remotes/origin/main"
|
||||
|
||||
monkeypatch.setattr(utils, "_git_ref_exists", fake_ref_exists)
|
||||
monkeypatch.setattr(utils, "_extract_github_base_sha", lambda _env: None)
|
||||
monkeypatch.setattr(utils, "_resolve_origin_head_ref", lambda _repo_path: None)
|
||||
|
||||
base_ref = utils._resolve_base_ref(Path("/tmp/repo"), diff_base=None, env={})
|
||||
|
||||
assert base_ref == "refs/remotes/origin/main"
|
||||
assert "refs/remotes/origin/main" in calls
|
||||
assert "origin/main" not in calls
|
||||
Reference in New Issue
Block a user