diff --git a/src/auto_archiver/core/base_module.py b/src/auto_archiver/core/base_module.py index 2f6ab20..c38db3b 100644 --- a/src/auto_archiver/core/base_module.py +++ b/src/auto_archiver/core/base_module.py @@ -56,13 +56,6 @@ class BaseModule(ABC): config = deepcopy(config) authentication = deepcopy(config.pop('authentication', {})) - # extract out concatenated sites - for key, val in copy(authentication).items(): - if "," in key: - for site in key.split(","): - authentication[site] = val - del authentication[key] - self.authentication = authentication self.config = config for key, val in config.get(self.name, {}).items(): diff --git a/src/auto_archiver/core/orchestrator.py b/src/auto_archiver/core/orchestrator.py index 54f5c32..c9980ff 100644 --- a/src/auto_archiver/core/orchestrator.py +++ b/src/auto_archiver/core/orchestrator.py @@ -6,13 +6,12 @@ from __future__ import annotations from typing import Generator, Union, List, Type, TYPE_CHECKING -from urllib.parse import urlparse -from ipaddress import ip_address import argparse import os import sys from tempfile import TemporaryDirectory import traceback +from copy import copy from rich_argparse import RichHelpFormatter from loguru import logger @@ -24,6 +23,7 @@ from .config import read_yaml, store_yaml, to_dot_notation, merge_dicts, is_vali from .module import ModuleFactory, LazyBaseModule from . import validators, Feeder, Extractor, Database, Storage, Formatter, Enricher from .consts import MODULE_TYPES +from auto_archiver.utils.url import check_url_or_raise from loguru import logger if TYPE_CHECKING: @@ -135,6 +135,9 @@ class ArchivingOrchestrator: # merge the new config with the old one config = merge_dicts(vars(parsed), yaml_config) + # set up the authentication dict as needed + config = self.setup_authentication(config) + # clean out args from the base_parser that we don't want in the config for key in vars(basic_config): config.pop(key, None) @@ -287,6 +290,7 @@ class ArchivingOrchestrator: if module in invalid_modules: continue + try: loaded_module: BaseModule = self.module_factory.get_module(module, self.config) except (KeyboardInterrupt, Exception) as e: @@ -442,8 +446,8 @@ class ArchivingOrchestrator: original_url = result.get_url().strip() try: - self.assert_valid_url(original_url) - except AssertionError as e: + check_url_or_raise(original_url) + except ValueError as e: logger.error(f"Error archiving URL {original_url}: {e}") raise e @@ -503,26 +507,26 @@ class ArchivingOrchestrator: logger.error(f"ERROR database {d.name}: {e}: {traceback.format_exc()}") return result + - def assert_valid_url(self, url: str) -> bool: + def setup_authentication(self, config: dict) -> dict: """ - Blocks localhost, private, reserved, and link-local IPs and all non-http/https schemes. + Setup authentication for all modules that require it + + Split up strings into multiple sites if they are comma separated """ - assert url.startswith("http://") or url.startswith("https://"), f"Invalid URL scheme" - parsed = urlparse(url) - assert parsed.scheme in ["http", "https"], f"Invalid URL scheme" - assert parsed.hostname, f"Invalid URL hostname" - assert parsed.hostname != "localhost", f"Invalid URL" + authentication = config.get('authentication', {}) - try: # special rules for IP addresses - ip = ip_address(parsed.hostname) - except ValueError: pass - else: - assert ip.is_global, f"Invalid IP used" - assert not ip.is_reserved, f"Invalid IP used" - assert not ip.is_link_local, f"Invalid IP used" - assert not ip.is_private, f"Invalid IP used" + # extract out concatenated sites + for key, val in copy(authentication).items(): + if "," in key: + for site in key.split(","): + authentication[site] = val + del authentication[key] + + config['authentication'] = authentication + return config # Helper Properties diff --git a/src/auto_archiver/utils/url.py b/src/auto_archiver/utils/url.py index 40884da..061f4aa 100644 --- a/src/auto_archiver/utils/url.py +++ b/src/auto_archiver/utils/url.py @@ -1,5 +1,6 @@ import re from urllib.parse import urlparse, urlunparse +from ipaddress import ip_address AUTHWALL_URLS = [ @@ -7,6 +8,43 @@ AUTHWALL_URLS = [ re.compile(r"https:\/\/www\.instagram\.com"), # instagram ] + +def check_url_or_raise(url: str) -> bool | ValueError: + """ + Blocks localhost, private, reserved, and link-local IPs and all non-http/https schemes. + """ + + + if not (url.startswith("http://") or url.startswith("https://")): + raise ValueError(f"Invalid URL scheme for url {url}") + + parsed = urlparse(url) + if not parsed.hostname: + raise ValueError(f"Invalid URL hostname for url {url}") + + if parsed.hostname == "localhost": + raise ValueError(f"Localhost URLs cannot be parsed for security reasons (for url {url})") + + if parsed.scheme not in ["http", "https"]: + raise ValueError(f"Invalid URL scheme, only http and https supported (for url {url})") + + try: # special rules for IP addresses + ip = ip_address(parsed.hostname) + except ValueError: + pass + + else: + if not ip.is_global: + raise ValueError(f"IP address {ip} is not globally reachable") + if ip.is_reserved: + raise ValueError(f"Reserved IP address {ip} used") + if ip.is_link_local: + raise ValueError(f"Link-local IP address {ip} used") + if ip.is_private: + raise ValueError(f"Private IP address {ip} used") + + return True + def domain_for_url(url: str) -> str: """ SECURITY: parse the domain using urllib to avoid any potential security issues