From 6f4eb21ad0651843e86f0c88078594e5cef9b7db Mon Sep 17 00:00:00 2001 From: Tristan Lee Date: Mon, 7 Aug 2023 19:15:39 -0500 Subject: [PATCH] started addressing mypy issues, updated several method type annotation signatures to be consistent with changes --- cisticola/base.py | 38 +++++++++++----------- cisticola/scraper/base.py | 20 ++++++++---- cisticola/scraper/bitchute.py | 4 +-- cisticola/scraper/gettr.py | 4 +-- cisticola/scraper/rumble.py | 4 +-- cisticola/scraper/telegram_telethon.py | 11 ++++--- cisticola/transformer/base.py | 19 +++++++---- cisticola/transformer/bitchute.py | 16 +++++---- cisticola/transformer/gettr.py | 13 ++++++-- cisticola/transformer/rumble.py | 20 ++++++++---- cisticola/transformer/telegram_telethon.py | 13 ++++++-- docs/source/conf.py | 2 +- 12 files changed, 101 insertions(+), 63 deletions(-) diff --git a/cisticola/base.py b/cisticola/base.py index e29b7bc..665762d 100644 --- a/cisticola/base.py +++ b/cisticola/base.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from dataclasses import dataclass, field from datetime import datetime import tempfile @@ -62,7 +62,7 @@ class ScraperResult: archived_urls: dict #: What date was the media archived? (None if not archived) - media_archived: datetime + media_archived: Optional[datetime] @dataclass @@ -70,10 +70,10 @@ class Channel: """Information about a specific channel to be scraped.""" #: Name of channel (different from username because it can be non-unique and contain emojis), e.g. ``T🕊Редакция Президент Гордон🕊"``. - name: str + name: Optional[str] #: String that uniquely identifies the channel on the given platform, e.g. ``"-1001101170442"``. - platform_id: str + platform_id: Optional[str] #: User-specified category for the channel, e.g. ``"explicit_qanon"``. category: str @@ -82,28 +82,28 @@ class Channel: platform: str #: URL for the given channel on the platform, e.g. ``"https://t.me/prezidentgordonteam"`` - url: str + url: Optional[str] #: Screen name/username of channel. screenname: str #: 2 digit country code for the country of origin for the channel, e.g. ``"RU"``. - country: str = None + country: Optional[str] = None #: Name of influencer, if channel belongs to an influencer that operates on multiple platforms. - influencer: str = None + influencer: Optional[str] = None #: Whether or not the channel is publicly-accessible. - public: bool = None + public: Optional[bool] = None #: Whether or not the channel is a chat (i.e. allows users who are not the channel creator to post/message) - chat: bool = None + chat: Optional[bool] = None #: Any other additional notes about the channel. notes: str = "" #: Did the channel come from a researcher or a scraping process? - source: str = None + source: Optional[str] = None def hydrate(self): pass @@ -177,7 +177,7 @@ class ChannelInfo: verified: bool #: Datetime at which the channel was created. - date_created: datetime + date_created: Optional[datetime] #: Datetime (relative to UTC) that the scraped channel info was archived at. date_archived: datetime @@ -260,28 +260,28 @@ class Post: normalized_content: str = "" #: The ID of the Channel that the post was forwarded or quoted from - forwarded_from: int = None + forwarded_from: Optional[int] = None #: The ID of the Post that this Post is a reply to - reply_to: int = None + reply_to: Optional[int] = None #: Other users mentioned in the post mentions: list = field(default_factory=list) #: Number of positive post reactions (e.g. likes, favorites, rumbles, upvotes, etc.) - likes: int = None + likes: Optional[int] = None #: Number of times the post was forwarded/retweeted/shared - forwards: int = None + forwards: Optional[int] = None #: Number of times the post was viewed - views: int = None + views: Optional[int] = None #: Video title, if post is a video - video_title: str = None + video_title: Optional[str] = None #: Video duration in seconds, if post is a video - video_duration: int = None + video_duration: Optional[int] = None def hydrate(self): """Populate additional fields from processed data, including language detection, named entity recognition, and extraction of outlinks, hashtags, and cryptocurrency addresses.""" @@ -404,7 +404,7 @@ class Media: date_transformed: datetime #: JSON dump of the dict containing metadata information for the media file. - exif: str = None + exif: Optional[str] = None def get_blob(self): """Download media file as bytes blob.""" diff --git a/cisticola/scraper/base.py b/cisticola/scraper/base.py index 33b4dbe..d784dce 100644 --- a/cisticola/scraper/base.py +++ b/cisticola/scraper/base.py @@ -1,4 +1,4 @@ -from typing import Generator, Tuple, List +from typing import Generator, Tuple, List, Optional import os from io import BytesIO from urllib.parse import urlparse @@ -90,7 +90,9 @@ class Scraper: key = urlparse(url).path.split("/")[-1] return key - def url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]: + def url_to_blob( + self, url: str, key: Optional[str] = None + ) -> Tuple[bytes, str, str]: """Download media file from a specified media file URL. Parameters @@ -122,7 +124,9 @@ class Scraper: return blob, content_type, key - def m3u8_url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]: + def m3u8_url_to_blob( + self, url: str, key: Optional[str] = None + ) -> Tuple[bytes, str, str]: """Download media file from a specified media URL, where the media file is formatted as an m3u8 playlist, which is then decoded to an mp4 file. @@ -164,7 +168,9 @@ class Scraper: return blob, content_type, key - def ytdlp_url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]: + def ytdlp_url_to_blob( + self, url: str, key: Optional[str] = None + ) -> Tuple[bytes, str, str]: """Download media file from a specified media URL, using a fork of youtube-dl that enables faster downloading. @@ -302,7 +308,7 @@ class Scraper: @logger.catch def get_posts( - self, channel: Channel, since: ScraperResult = None + self, channel: Channel, since: Optional[ScraperResult] = None ) -> Generator[ScraperResult, None, None]: """Scrape all posts from the specified Channel. @@ -428,7 +434,7 @@ class ScraperController: Parameters ---------- - channels: list + channels: list[Channel] List of Channel instances to be scraped fetch_old: bool If ``True``, scrape all posts from channels, regardless of when channel was last scraped. @@ -615,7 +621,7 @@ class ScraperController: Parameters ---------- - channels: list + channels: list[Channel] List of Channel instances to be scraped """ diff --git a/cisticola/scraper/bitchute.py b/cisticola/scraper/bitchute.py index b43a98d..e624e69 100644 --- a/cisticola/scraper/bitchute.py +++ b/cisticola/scraper/bitchute.py @@ -4,7 +4,7 @@ import re from html.parser import HTMLParser import dateparser import json -from typing import Generator +from typing import Generator, Optional from dateutil.relativedelta import relativedelta import requests @@ -28,7 +28,7 @@ class BitchuteScraper(Scraper): @logger.catch def get_posts( - self, channel: Channel, since: ScraperResult = None + self, channel: Channel, since: Optional[ScraperResult] = None ) -> Generator[ScraperResult, None, None]: session = requests.Session() session.headers.update(self.headers) diff --git a/cisticola/scraper/gettr.py b/cisticola/scraper/gettr.py index 937b9ce..56927bc 100644 --- a/cisticola/scraper/gettr.py +++ b/cisticola/scraper/gettr.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone import json -from typing import Generator +from typing import Generator, Optional from urllib.parse import urlparse from loguru import logger @@ -24,7 +24,7 @@ class GettrScraper(Scraper): @logger.catch def get_posts( - self, channel: Channel, since: ScraperResult = None + self, channel: Channel, since: Optional[ScraperResult] = None ) -> Generator[ScraperResult, None, None]: client = PublicClient() username = self.get_username_from_url(channel.url).lower() diff --git a/cisticola/scraper/rumble.py b/cisticola/scraper/rumble.py index 9f1799b..aeb7a45 100644 --- a/cisticola/scraper/rumble.py +++ b/cisticola/scraper/rumble.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone import json -from typing import Generator +from typing import Generator, Optional from urllib.parse import urlparse from loguru import logger @@ -25,7 +25,7 @@ class RumbleScraper(Scraper): @logger.catch def get_posts( - self, channel: Channel, since: ScraperResult = None + self, channel: Channel, since: Optional[ScraperResult] = None ) -> Generator[ScraperResult, None, None]: scraper = get_channel_videos(channel.url) diff --git a/cisticola/scraper/telegram_telethon.py b/cisticola/scraper/telegram_telethon.py index 91f41c8..8791227 100644 --- a/cisticola/scraper/telegram_telethon.py +++ b/cisticola/scraper/telegram_telethon.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Generator, Optional from datetime import datetime, timezone import os import json @@ -64,9 +64,9 @@ class TelegramTelethonScraper(Scraper): if len(result.archived_urls.keys()) == 0: return result - if len(list(result.archived_urls.keys())) != 1: + if len(result.archived_urls.keys()) != 1: logger.warning( - f"Expected 1 key in archived_urls, found {result.archived_keys}" + f"Expected 1 key in archived_urls, found {len(result.archived_urls.keys())}" ) else: key = list(result.archived_urls.keys())[0] @@ -147,7 +147,10 @@ class TelegramTelethonScraper(Scraper): # @logger.catch def get_posts( - self, channel: Channel, since: ScraperResult = None, until: ScraperResult = None + self, + channel: Channel, + since: Optional[ScraperResult] = None, + until: Optional[ScraperResult] = None, ) -> Generator[ScraperResult, None, None]: username = TelegramTelethonScraper.get_channel_identifier(channel) if until is not None: diff --git a/cisticola/transformer/base.py b/cisticola/transformer/base.py index 937f533..7910159 100644 --- a/cisticola/transformer/base.py +++ b/cisticola/transformer/base.py @@ -1,7 +1,7 @@ from typing import List, Generator, Union, Callable from loguru import logger from sqlalchemy import cast, String -from sqlalchemy.orm import sessionmaker, make_transient +from sqlalchemy.orm import sessionmaker, make_transient, Session from sqlalchemy.engine.base import Engine from sqlalchemy.sql.expression import func from collections import defaultdict @@ -29,7 +29,7 @@ class Transformer: def __init__(self): pass - def can_handle(data: ScraperResult) -> bool: + def can_handle(self, data: ScraperResult) -> bool: """Specifies whether or not a Transformer is capable of handling a particular piece of scraped data. @@ -44,11 +44,16 @@ class Transformer: ``True`` if it can be handled by this Transformer, false otherwise. """ - pass + raise NotImplementedError def transform( - data: ScraperResult, insert: Callable - ) -> Generator[Union[Post, Channel, Media], None, None]: + self, + data: ScraperResult, + insert: Callable, + session: Session, + insert_post: Callable, + flush_posts: Callable, + ): """Transform a ScraperResult into objects with additional parameters for analysis. This function can yield multiple objects, as it will find references to quoted/replied posts, media objects, and Channel objects and provide all of these to be inserted into the database. @@ -62,7 +67,7 @@ class Transformer: relevant unique constraints if applicable. """ - pass + raise NotImplementedError def transform_media(self, data: ScraperResult, transformed: Post, insert: Callable): """Transform a post's media attachment to standard form and insert into database. @@ -191,7 +196,7 @@ class ETLController: obj.hydrate() if flush: - self.flush_posts() + self.flush_posts(session=session) session.add(obj) session.flush() diff --git a/cisticola/transformer/bitchute.py b/cisticola/transformer/bitchute.py index fba56b4..617130a 100644 --- a/cisticola/transformer/bitchute.py +++ b/cisticola/transformer/bitchute.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone from dateutil.relativedelta import relativedelta from bs4 import BeautifulSoup +from sqlalchemy.orm import Session from cisticola.transformer.base import Transformer from cisticola.base import ( @@ -31,9 +32,7 @@ class BitchuteTransformer(Transformer): return False - def transform_media( - self, data: ScraperResult, transformed: Post, insert: Callable - ) -> Generator[Media, None, None]: + def transform_media(self, data: ScraperResult, transformed: Post, insert: Callable): raw = json.loads(data.raw_data) orig = raw["video_url"] @@ -56,7 +55,7 @@ class BitchuteTransformer(Transformer): def transform_info( self, data: RawChannelInfo, insert: Callable, session, channel=None - ) -> Generator[Union[Post, Channel, Media], None, None]: + ): raw = json.loads(data.raw_data) transformed = ChannelInfo( @@ -82,8 +81,13 @@ class BitchuteTransformer(Transformer): transformed = insert(transformed) def transform( - self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts - ) -> Generator[Union[Post, Channel, Media], None, None]: + self, + data: ScraperResult, + insert: Callable, + session: Session, + insert_post: Callable, + flush_posts: Callable, + ): raw = json.loads(data.raw_data) if raw["category"] == "comment": diff --git a/cisticola/transformer/gettr.py b/cisticola/transformer/gettr.py index e1c89cb..53d7c80 100644 --- a/cisticola/transformer/gettr.py +++ b/cisticola/transformer/gettr.py @@ -4,6 +4,8 @@ from typing import Generator, Union, Callable import dateutil.parser from datetime import datetime, timezone from sqlalchemy import func +from sqlalchemy.orm import Session + from gogettr import PublicClient from gogettr.api import GettrApiError @@ -34,7 +36,7 @@ class GettrTransformer(Transformer): def transform_info( self, data: RawChannelInfo, insert: Callable, session, channel=None - ) -> Generator[Union[Post, Channel, Media], None, None]: + ): raw = json.loads(data.raw_data) transformed = ChannelInfo( @@ -100,8 +102,13 @@ class GettrTransformer(Transformer): return channel.id def transform( - self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts - ) -> Generator[Union[Post, Channel, Media], None, None]: + self, + data: ScraperResult, + insert: Callable, + session: Session, + insert_post: Callable, + flush_posts: Callable, + ): raw = json.loads(data.raw_data) if raw["activity"]["action"] == "shares_pst": diff --git a/cisticola/transformer/rumble.py b/cisticola/transformer/rumble.py index cf4b428..ad6fdd5 100644 --- a/cisticola/transformer/rumble.py +++ b/cisticola/transformer/rumble.py @@ -1,9 +1,10 @@ import json from loguru import logger -from typing import Generator, Union, Callable +from typing import Generator, Union, Callable, Optional import dateutil.parser from datetime import datetime, timezone from sqlalchemy import func, JSON, String, cast, text +from sqlalchemy.orm import Session from cisticola.transformer.base import Transformer from cisticola.base import ( @@ -32,7 +33,7 @@ class RumbleTransformer(Transformer): def transform_info( self, data: RawChannelInfo, insert: Callable, session, channel=None - ) -> Generator[Union[Post, Channel, Media], None, None]: + ): raw = json.loads(data.raw_data) if "id" not in raw: @@ -78,8 +79,13 @@ class RumbleTransformer(Transformer): transformed = insert(transformed) def transform( - self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts - ) -> Generator[Union[Post, Channel, Media], None, None]: + self, + data: ScraperResult, + insert: Callable, + session: Session, + insert_post: Callable, + flush_posts: Callable, + ): raw = json.loads(data.raw_data) transformed = Post( @@ -106,9 +112,9 @@ class RumbleTransformer(Transformer): insert_post(transformed) -def _process_number(s): +def _process_number(s: str) -> int: if s is None: - return None + return -1 else: s = s.replace(" ", "").replace(",", "") if s.endswith("M"): @@ -118,7 +124,7 @@ def _process_number(s): return int(s) -def _parse_duration_str(duration_str: str) -> int: +def _parse_duration_str(duration_str: str) -> Optional[int]: """Convert duration string (e.g. '2:27:04') to the number of seconds (e.g. 8824).""" if not duration_str: return None diff --git a/cisticola/transformer/telegram_telethon.py b/cisticola/transformer/telegram_telethon.py index 2a6c194..f350de5 100644 --- a/cisticola/transformer/telegram_telethon.py +++ b/cisticola/transformer/telegram_telethon.py @@ -15,6 +15,8 @@ from itertools import takewhile import os from datetime import datetime, timezone from sqlalchemy import func +from sqlalchemy.orm import Session + from cisticola.transformer.base import Transformer from cisticola.base import ( @@ -148,7 +150,7 @@ class TelegramTelethonTransformer(Transformer): def transform_info( self, data: RawChannelInfo, insert: Callable, session, channel=None - ) -> Generator[Union[Post, Channel, Media], None, None]: + ): raw = json.loads(data.raw_data) chat_raw = raw["chats"][0] @@ -208,8 +210,13 @@ class TelegramTelethonTransformer(Transformer): # TODO this method API is chaotic and could be cleaned up def transform( - self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts - ) -> Generator[Union[Post, Channel, Media], None, None]: + self, + data: ScraperResult, + insert: Callable, + session: Session, + insert_post: Callable, + flush_posts: Callable, + ): raw = json.loads(data.raw_data) if raw["_"] != "Message": diff --git a/docs/source/conf.py b/docs/source/conf.py index 0f26bd3..bdd6d68 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../")) # -- Project information ----------------------------------------------------- project = "Cisticola" -copyright = "2022, Bellingcat" +copyright = "2023, Bellingcat" author = "Bellingcat"