Started addressing mypy issues; updated the type annotations in several method signatures to be consistent with those changes

This commit is contained in:
Tristan Lee
2023-08-07 19:15:39 -05:00
parent 89b5068108
commit 6f4eb21ad0
12 changed files with 101 additions and 63 deletions

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from dataclasses import dataclass, field
from datetime import datetime
import tempfile
@@ -62,7 +62,7 @@ class ScraperResult:
archived_urls: dict
#: What date was the media archived? (None if not archived)
media_archived: datetime
media_archived: Optional[datetime]
@dataclass
@@ -70,10 +70,10 @@ class Channel:
"""Information about a specific channel to be scraped."""
#: Name of channel (different from username because it can be non-unique and contain emojis), e.g. ``"T🕊Редакция Президент Гордон🕊"``.
name: str
name: Optional[str]
#: String that uniquely identifies the channel on the given platform, e.g. ``"-1001101170442"``.
platform_id: str
platform_id: Optional[str]
#: User-specified category for the channel, e.g. ``"explicit_qanon"``.
category: str
@@ -82,28 +82,28 @@ class Channel:
platform: str
#: URL for the given channel on the platform, e.g. ``"https://t.me/prezidentgordonteam"``
url: str
url: Optional[str]
#: Screen name/username of channel.
screenname: str
#: 2 digit country code for the country of origin for the channel, e.g. ``"RU"``.
country: str = None
country: Optional[str] = None
#: Name of influencer, if channel belongs to an influencer that operates on multiple platforms.
influencer: str = None
influencer: Optional[str] = None
#: Whether or not the channel is publicly-accessible.
public: bool = None
public: Optional[bool] = None
#: Whether or not the channel is a chat (i.e. allows users who are not the channel creator to post/message)
chat: bool = None
chat: Optional[bool] = None
#: Any other additional notes about the channel.
notes: str = ""
#: Did the channel come from a researcher or a scraping process?
source: str = None
source: Optional[str] = None
def hydrate(self):
pass
@@ -177,7 +177,7 @@ class ChannelInfo:
verified: bool
#: Datetime at which the channel was created.
date_created: datetime
date_created: Optional[datetime]
#: Datetime (relative to UTC) that the scraped channel info was archived at.
date_archived: datetime
@@ -260,28 +260,28 @@ class Post:
normalized_content: str = ""
#: The ID of the Channel that the post was forwarded or quoted from
forwarded_from: int = None
forwarded_from: Optional[int] = None
#: The ID of the Post that this Post is a reply to
reply_to: int = None
reply_to: Optional[int] = None
#: Other users mentioned in the post
mentions: list = field(default_factory=list)
#: Number of positive post reactions (e.g. likes, favorites, rumbles, upvotes, etc.)
likes: int = None
likes: Optional[int] = None
#: Number of times the post was forwarded/retweeted/shared
forwards: int = None
forwards: Optional[int] = None
#: Number of times the post was viewed
views: int = None
views: Optional[int] = None
#: Video title, if post is a video
video_title: str = None
video_title: Optional[str] = None
#: Video duration in seconds, if post is a video
video_duration: int = None
video_duration: Optional[int] = None
def hydrate(self):
"""Populate additional fields from processed data, including language detection, named entity recognition, and extraction of outlinks, hashtags, and cryptocurrency addresses."""
@@ -404,7 +404,7 @@ class Media:
date_transformed: datetime
#: JSON dump of the dict containing metadata information for the media file.
exif: str = None
exif: Optional[str] = None
def get_blob(self):
"""Download media file as bytes blob."""

View File

@@ -1,4 +1,4 @@
from typing import Generator, Tuple, List
from typing import Generator, Tuple, List, Optional
import os
from io import BytesIO
from urllib.parse import urlparse
@@ -90,7 +90,9 @@ class Scraper:
key = urlparse(url).path.split("/")[-1]
return key
def url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]:
def url_to_blob(
self, url: str, key: Optional[str] = None
) -> Tuple[bytes, str, str]:
"""Download media file from a specified media file URL.
Parameters
@@ -122,7 +124,9 @@ class Scraper:
return blob, content_type, key
def m3u8_url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]:
def m3u8_url_to_blob(
self, url: str, key: Optional[str] = None
) -> Tuple[bytes, str, str]:
"""Download media file from a specified media URL, where the media file
is formatted as an m3u8 playlist, which is then decoded to an mp4 file.
@@ -164,7 +168,9 @@ class Scraper:
return blob, content_type, key
def ytdlp_url_to_blob(self, url: str, key: str = None) -> Tuple[bytes, str, str]:
def ytdlp_url_to_blob(
self, url: str, key: Optional[str] = None
) -> Tuple[bytes, str, str]:
"""Download media file from a specified media URL, using a fork of
youtube-dl that enables faster downloading.
@@ -302,7 +308,7 @@ class Scraper:
@logger.catch
def get_posts(
self, channel: Channel, since: ScraperResult = None
self, channel: Channel, since: Optional[ScraperResult] = None
) -> Generator[ScraperResult, None, None]:
"""Scrape all posts from the specified Channel.
@@ -428,7 +434,7 @@ class ScraperController:
Parameters
----------
channels: list<Channel>
channels: list[Channel]
List of Channel instances to be scraped
fetch_old: bool
If ``True``, scrape all posts from channels, regardless of when channel was last scraped.
@@ -615,7 +621,7 @@ class ScraperController:
Parameters
----------
channels: list<Channel>
channels: list[Channel]
List of Channel instances to be scraped
"""

View File

@@ -4,7 +4,7 @@ import re
from html.parser import HTMLParser
import dateparser
import json
from typing import Generator
from typing import Generator, Optional
from dateutil.relativedelta import relativedelta
import requests
@@ -28,7 +28,7 @@ class BitchuteScraper(Scraper):
@logger.catch
def get_posts(
self, channel: Channel, since: ScraperResult = None
self, channel: Channel, since: Optional[ScraperResult] = None
) -> Generator[ScraperResult, None, None]:
session = requests.Session()
session.headers.update(self.headers)

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
import json
from typing import Generator
from typing import Generator, Optional
from urllib.parse import urlparse
from loguru import logger
@@ -24,7 +24,7 @@ class GettrScraper(Scraper):
@logger.catch
def get_posts(
self, channel: Channel, since: ScraperResult = None
self, channel: Channel, since: Optional[ScraperResult] = None
) -> Generator[ScraperResult, None, None]:
client = PublicClient()
username = self.get_username_from_url(channel.url).lower()

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
import json
from typing import Generator
from typing import Generator, Optional
from urllib.parse import urlparse
from loguru import logger
@@ -25,7 +25,7 @@ class RumbleScraper(Scraper):
@logger.catch
def get_posts(
self, channel: Channel, since: ScraperResult = None
self, channel: Channel, since: Optional[ScraperResult] = None
) -> Generator[ScraperResult, None, None]:
scraper = get_channel_videos(channel.url)

View File

@@ -1,4 +1,4 @@
from typing import Generator
from typing import Generator, Optional
from datetime import datetime, timezone
import os
import json
@@ -64,9 +64,9 @@ class TelegramTelethonScraper(Scraper):
if len(result.archived_urls.keys()) == 0:
return result
if len(list(result.archived_urls.keys())) != 1:
if len(result.archived_urls.keys()) != 1:
logger.warning(
f"Expected 1 key in archived_urls, found {result.archived_keys}"
f"Expected 1 key in archived_urls, found {len(result.archived_urls.keys())}"
)
else:
key = list(result.archived_urls.keys())[0]
@@ -147,7 +147,10 @@ class TelegramTelethonScraper(Scraper):
# @logger.catch
def get_posts(
self, channel: Channel, since: ScraperResult = None, until: ScraperResult = None
self,
channel: Channel,
since: Optional[ScraperResult] = None,
until: Optional[ScraperResult] = None,
) -> Generator[ScraperResult, None, None]:
username = TelegramTelethonScraper.get_channel_identifier(channel)
if until is not None:

View File

@@ -1,7 +1,7 @@
from typing import List, Generator, Union, Callable
from loguru import logger
from sqlalchemy import cast, String
from sqlalchemy.orm import sessionmaker, make_transient
from sqlalchemy.orm import sessionmaker, make_transient, Session
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.expression import func
from collections import defaultdict
@@ -29,7 +29,7 @@ class Transformer:
def __init__(self):
pass
def can_handle(data: ScraperResult) -> bool:
def can_handle(self, data: ScraperResult) -> bool:
"""Specifies whether or not a Transformer is capable of handling a particular
piece of scraped data.
@@ -44,11 +44,16 @@ class Transformer:
``True`` if it can be handled by this Transformer, ``False`` otherwise.
"""
pass
raise NotImplementedError
def transform(
data: ScraperResult, insert: Callable
) -> Generator[Union[Post, Channel, Media], None, None]:
self,
data: ScraperResult,
insert: Callable,
session: Session,
insert_post: Callable,
flush_posts: Callable,
):
"""Transform a ScraperResult into objects with additional parameters for analysis. This function can
yield multiple objects, as it will find references to quoted/replied posts, media objects, and Channel
objects and provide all of these to be inserted into the database.
@@ -62,7 +67,7 @@ class Transformer:
relevant unique constraints if applicable.
"""
pass
raise NotImplementedError
def transform_media(self, data: ScraperResult, transformed: Post, insert: Callable):
"""Transform a post's media attachment to standard form and insert into database.
@@ -191,7 +196,7 @@ class ETLController:
obj.hydrate()
if flush:
self.flush_posts()
self.flush_posts(session=session)
session.add(obj)
session.flush()

View File

@@ -5,6 +5,7 @@ from datetime import datetime, timezone
from dateutil.relativedelta import relativedelta
from bs4 import BeautifulSoup
from sqlalchemy.orm import Session
from cisticola.transformer.base import Transformer
from cisticola.base import (
@@ -31,9 +32,7 @@ class BitchuteTransformer(Transformer):
return False
def transform_media(
self, data: ScraperResult, transformed: Post, insert: Callable
) -> Generator[Media, None, None]:
def transform_media(self, data: ScraperResult, transformed: Post, insert: Callable):
raw = json.loads(data.raw_data)
orig = raw["video_url"]
@@ -56,7 +55,7 @@ class BitchuteTransformer(Transformer):
def transform_info(
self, data: RawChannelInfo, insert: Callable, session, channel=None
) -> Generator[Union[Post, Channel, Media], None, None]:
):
raw = json.loads(data.raw_data)
transformed = ChannelInfo(
@@ -82,8 +81,13 @@ class BitchuteTransformer(Transformer):
transformed = insert(transformed)
def transform(
self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts
) -> Generator[Union[Post, Channel, Media], None, None]:
self,
data: ScraperResult,
insert: Callable,
session: Session,
insert_post: Callable,
flush_posts: Callable,
):
raw = json.loads(data.raw_data)
if raw["category"] == "comment":

View File

@@ -4,6 +4,8 @@ from typing import Generator, Union, Callable
import dateutil.parser
from datetime import datetime, timezone
from sqlalchemy import func
from sqlalchemy.orm import Session
from gogettr import PublicClient
from gogettr.api import GettrApiError
@@ -34,7 +36,7 @@ class GettrTransformer(Transformer):
def transform_info(
self, data: RawChannelInfo, insert: Callable, session, channel=None
) -> Generator[Union[Post, Channel, Media], None, None]:
):
raw = json.loads(data.raw_data)
transformed = ChannelInfo(
@@ -100,8 +102,13 @@ class GettrTransformer(Transformer):
return channel.id
def transform(
self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts
) -> Generator[Union[Post, Channel, Media], None, None]:
self,
data: ScraperResult,
insert: Callable,
session: Session,
insert_post: Callable,
flush_posts: Callable,
):
raw = json.loads(data.raw_data)
if raw["activity"]["action"] == "shares_pst":

View File

@@ -1,9 +1,10 @@
import json
from loguru import logger
from typing import Generator, Union, Callable
from typing import Generator, Union, Callable, Optional
import dateutil.parser
from datetime import datetime, timezone
from sqlalchemy import func, JSON, String, cast, text
from sqlalchemy.orm import Session
from cisticola.transformer.base import Transformer
from cisticola.base import (
@@ -32,7 +33,7 @@ class RumbleTransformer(Transformer):
def transform_info(
self, data: RawChannelInfo, insert: Callable, session, channel=None
) -> Generator[Union[Post, Channel, Media], None, None]:
):
raw = json.loads(data.raw_data)
if "id" not in raw:
@@ -78,8 +79,13 @@ class RumbleTransformer(Transformer):
transformed = insert(transformed)
def transform(
self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts
) -> Generator[Union[Post, Channel, Media], None, None]:
self,
data: ScraperResult,
insert: Callable,
session: Session,
insert_post: Callable,
flush_posts: Callable,
):
raw = json.loads(data.raw_data)
transformed = Post(
@@ -106,9 +112,9 @@ class RumbleTransformer(Transformer):
insert_post(transformed)
def _process_number(s):
def _process_number(s: str) -> int:
if s is None:
return None
return -1
else:
s = s.replace(" ", "").replace(",", "")
if s.endswith("M"):
@@ -118,7 +124,7 @@ def _process_number(s):
return int(s)
def _parse_duration_str(duration_str: str) -> int:
def _parse_duration_str(duration_str: str) -> Optional[int]:
"""Convert duration string (e.g. '2:27:04') to the number of seconds (e.g. 8824)."""
if not duration_str:
return None

View File

@@ -15,6 +15,8 @@ from itertools import takewhile
import os
from datetime import datetime, timezone
from sqlalchemy import func
from sqlalchemy.orm import Session
from cisticola.transformer.base import Transformer
from cisticola.base import (
@@ -148,7 +150,7 @@ class TelegramTelethonTransformer(Transformer):
def transform_info(
self, data: RawChannelInfo, insert: Callable, session, channel=None
) -> Generator[Union[Post, Channel, Media], None, None]:
):
raw = json.loads(data.raw_data)
chat_raw = raw["chats"][0]
@@ -208,8 +210,13 @@ class TelegramTelethonTransformer(Transformer):
# TODO this method API is chaotic and could be cleaned up
def transform(
self, data: ScraperResult, insert: Callable, session, insert_post, flush_posts
) -> Generator[Union[Post, Channel, Media], None, None]:
self,
data: ScraperResult,
insert: Callable,
session: Session,
insert_post: Callable,
flush_posts: Callable,
):
raw = json.loads(data.raw_data)
if raw["_"] != "Message":

View File

@@ -19,7 +19,7 @@ sys.path.insert(0, os.path.abspath("../../"))
# -- Project information -----------------------------------------------------
project = "Cisticola"
copyright = "2022, Bellingcat"
copyright = "2023, Bellingcat"
author = "Bellingcat"