diff --git a/strix/interface/main.py b/strix/interface/main.py index 5714938..52e2e84 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -30,9 +30,10 @@ from strix.interface.utils import ( image_exists, infer_target_type, process_pull_line, + rewrite_localhost_targets, validate_llm_response, ) -from strix.runtime.docker_runtime import STRIX_IMAGE +from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME, STRIX_IMAGE from strix.telemetry.tracer import get_global_tracer @@ -377,6 +378,7 @@ Examples: parser.error(f"Invalid target '{target}'") assign_workspace_subdirs(args.targets_info) + rewrite_localhost_targets(args.targets_info, HOST_GATEWAY_HOSTNAME) return args diff --git a/strix/interface/utils.py b/strix/interface/utils.py index efc1fc9..bd128a7 100644 --- a/strix/interface/utils.py +++ b/strix/interface/utils.py @@ -404,6 +404,47 @@ def collect_local_sources(targets_info: list[dict[str, Any]]) -> list[dict[str, return local_sources +def _is_localhost_host(host: str) -> bool: + host_lower = host.lower().strip("[]") + + if host_lower in ("localhost", "0.0.0.0", "::1"): # nosec B104 + return True + + try: + ip = ipaddress.ip_address(host_lower) + if isinstance(ip, ipaddress.IPv4Address): + return ip.is_loopback # 127.0.0.0/8 + if isinstance(ip, ipaddress.IPv6Address): + return ip.is_loopback # ::1 + except ValueError: + pass + + return False + + +def rewrite_localhost_targets(targets_info: list[dict[str, Any]], host_gateway: str) -> None: + from yarl import URL # type: ignore[import-not-found] + + for target_info in targets_info: + target_type = target_info.get("type") + details = target_info.get("details", {}) + + if target_type == "web_application": + target_url = details.get("target_url", "") + try: + url = URL(target_url) + except (ValueError, TypeError): + continue + + if url.host and _is_localhost_host(url.host): + details["target_url"] = str(url.with_host(host_gateway)) + + elif target_type == "ip_address": + target_ip = details.get("target_ip", "") + if target_ip and _is_localhost_host(target_ip): + details["target_ip"] = host_gateway + + # Repository utilities def clone_repository(repo_url: str, run_name: str, dest_name: str | None = None) -> str: console = Console() diff --git a/strix/runtime/docker_runtime.py b/strix/runtime/docker_runtime.py index 7ba04f8..5abaa71 100644 --- a/strix/runtime/docker_runtime.py +++ b/strix/runtime/docker_runtime.py @@ -15,6 +15,7 @@ from .runtime import AbstractRuntime, SandboxInfo STRIX_IMAGE = os.getenv("STRIX_IMAGE", "ghcr.io/usestrix/strix-sandbox:0.1.10") +HOST_GATEWAY_HOSTNAME = "host.docker.internal" logger = logging.getLogger(__name__) @@ -121,7 +122,9 @@ class DockerRuntime(AbstractRuntime): "CAIDO_PORT": str(caido_port), "TOOL_SERVER_PORT": str(tool_server_port), "TOOL_SERVER_TOKEN": tool_server_token, + "HOST_GATEWAY": HOST_GATEWAY_HOSTNAME, }, + extra_hosts=self._get_extra_hosts(), tty=True, ) @@ -381,6 +384,9 @@ class DockerRuntime(AbstractRuntime): return "127.0.0.1" + def _get_extra_hosts(self) -> dict[str, str]: + return {HOST_GATEWAY_HOSTNAME: "host-gateway"} + async def destroy_sandbox(self, container_id: str) -> None: logger.info("Destroying scan container %s", container_id) try: