From 8ac1fd3ea8f088f1404d6559e332ac505c57c1e6 Mon Sep 17 00:00:00 2001 From: JustAnotherArchivist Date: Mon, 7 Feb 2022 04:43:19 +0000 Subject: [PATCH] Refactor Pushshift code to separate the general things from the search --- snscrape/modules/reddit.py | 132 ++++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 61 deletions(-) diff --git a/snscrape/modules/reddit.py b/snscrape/modules/reddit.py index e477901..07114d7 100644 --- a/snscrape/modules/reddit.py +++ b/snscrape/modules/reddit.py @@ -46,19 +46,35 @@ class Comment(snscrape.base.Item): return self.url -class _RedditPushshiftScraper(snscrape.base.Scraper): - def __init__(self, name, *, submissions = True, comments = True, before = None, after = None, **kwargs): - super().__init__(**kwargs) - self._name = name - self._submissions = submissions - self._comments = comments - self._before = before - self._after = after +def _cmp_id(id1, id2): + '''Compare two Reddit IDs. Returns -1 if id1 is less than id2, 0 if they are equal, and 1 if id1 is greater than id2. - if not type(self)._validationFunc(self._name): - raise ValueError(f'invalid {type(self).name.split("-", 1)[1]} name') - if not self._submissions and not self._comments: - raise ValueError('At least one of submissions and comments must be True') + id1 and id2 may have prefixes like t1_, but if included, they must be present on both and equal.''' + + if id1.startswith('t') and '_' in id1: + prefix, id1 = id1.split('_', 1) + if not id2.startswith(f'{prefix}_'): + raise ValueError('id2 must have the same prefix as id1') + _, id2 = id2.split('_', 1) + if id1.strip(string.ascii_lowercase + string.digits) != '': + raise ValueError('invalid characters in id1') + if id2.strip(string.ascii_lowercase + string.digits) != '': + raise ValueError('invalid characters in id2') + if len(id1) < len(id2): + return -1 + if len(id1) > len(id2): + return 1 + if id1 < id2: + return -1 + if id1 > id2: + return 1 + return 0 + + +class _RedditPushshiftScraper(snscrape.base.Scraper): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._headers = {'User-Agent': f'snscrape/{snscrape.version.__version__}'} def _handle_rate_limiting(self, r): if r.status_code == 429: @@ -69,52 +85,11 @@ class _RedditPushshiftScraper(snscrape.base.Scraper): return False, 'non-200 status code' return True, None - def _cmp_id(self, id1, id2): - '''Compare two Reddit IDs. Returns -1 if id1 is less than id2, 0 if they are equal, and 1 if id1 is greater than id2. - - id1 and id2 may have prefixes like t1_, but if included, they must be present on both and equal.''' - - if id1.startswith('t') and '_' in id1: - prefix, id1 = id1.split('_', 1) - if not id2.startswith(f'{prefix}_'): - raise ValueError('id2 must have the same prefix as id1') - _, id2 = id2.split('_', 1) - if id1.strip(string.ascii_lowercase + string.digits) != '': - raise ValueError('invalid characters in id1') - if id2.strip(string.ascii_lowercase + string.digits) != '': - raise ValueError('invalid characters in id2') - if len(id1) < len(id2): - return -1 - if len(id1) > len(id2): - return 1 - if id1 < id2: - return -1 - if id1 > id2: - return 1 - return 0 - - def _iter_api(self, url, params = None): - '''Iterate through the Pushshift API using the 'before' parameter and yield the items.''' - lowestIdSeen = None - if params is None: - params = {} - if self._before is not None: - params['before'] = self._before - if self._after is not None: - params['after'] = self._after - params['sort'] = 'desc' - while True: - r = self._get(url, params = params, headers = {'User-Agent': f'snscrape/{snscrape.version.__version__}'}, responseOkCallback = self._handle_rate_limiting) - if r.status_code != 200: - raise snscrape.base.ScraperException(f'Got status code {r.status_code}') - obj = r.json() - if not obj['data'] or (lowestIdSeen is not None and all(self._cmp_id(d['id'], lowestIdSeen) >= 0 for d in obj['data'])): # end of pagination - break - for d in obj['data']: - if lowestIdSeen is None or self._cmp_id(d['id'], lowestIdSeen) == -1: - yield self._api_obj_to_item(d) - lowestIdSeen = d['id'] - params['before'] = obj["data"][-1]["created_utc"] + 1 + def _get_api(self, url, params = None): + r = self._get(url, params = params, headers = self._headers, responseOkCallback = self._handle_rate_limiting) + if r.status_code != 200: + raise snscrape.base.ScraperException(f'Got status code {r.status_code}') + return r.json() def _api_obj_to_item(self, d): cls = Submission if 'title' in d else Comment @@ -154,6 +129,41 @@ class _RedditPushshiftScraper(snscrape.base.Scraper): return cls(**kwargs) + +class _RedditPushshiftSearchScraper(_RedditPushshiftScraper): + def __init__(self, name, *, submissions = True, comments = True, before = None, after = None, **kwargs): + super().__init__(**kwargs) + self._name = name + self._submissions = submissions + self._comments = comments + self._before = before + self._after = after + + if not type(self)._validationFunc(self._name): + raise ValueError(f'invalid {type(self).name.split("-", 1)[1]} name') + if not self._submissions and not self._comments: + raise ValueError('At least one of submissions and comments must be True') + + def _iter_api(self, url, params = None): + '''Iterate through the Pushshift API using the 'before' parameter and yield the items.''' + lowestIdSeen = None + if params is None: + params = {} + if self._before is not None: + params['before'] = self._before + if self._after is not None: + params['after'] = self._after + params['sort'] = 'desc' + while True: + obj = self._get_api(url, params = params) + if not obj['data'] or (lowestIdSeen is not None and all(_cmp_id(d['id'], lowestIdSeen) >= 0 for d in obj['data'])): # end of pagination + break + for d in obj['data']: + if lowestIdSeen is None or _cmp_id(d['id'], lowestIdSeen) == -1: + yield self._api_obj_to_item(d) + lowestIdSeen = d['id'] + params['before'] = obj["data"][-1]["created_utc"] + 1 + def _iter_api_submissions_and_comments(self, params: dict): # Retrieve both submissions and comments, interleave the results to get a reverse-chronological order params['size'] = '1000' @@ -218,19 +228,19 @@ class _RedditPushshiftScraper(snscrape.base.Scraper): return cls._cli_construct(args, getattr(args, name), submissions = not args.noSubmissions, comments = not args.noComments, before = args.before, after = args.after) -class RedditUserScraper(_RedditPushshiftScraper): +class RedditUserScraper(_RedditPushshiftSearchScraper): name = 'reddit-user' _validationFunc = lambda x: re.match('^[A-Za-z0-9_-]{3,20}$', x) _apiField = 'author' -class RedditSubredditScraper(_RedditPushshiftScraper): +class RedditSubredditScraper(_RedditPushshiftSearchScraper): name = 'reddit-subreddit' _validationFunc = lambda x: re.match('^[A-Za-z0-9][A-Za-z0-9_]{2,20}$', x) _apiField = 'subreddit' -class RedditSearchScraper(_RedditPushshiftScraper): +class RedditSearchScraper(_RedditPushshiftSearchScraper): name = 'reddit-search' _validationFunc = lambda x: True _apiField = 'q'