mirror of
https://github.com/bellingcat/snscrape.git
synced 2026-06-08 02:28:29 +03:00
Refactor Pushshift code to separate the general things from the search
This commit is contained in:
@@ -46,19 +46,35 @@ class Comment(snscrape.base.Item):
|
||||
return self.url
|
||||
|
||||
|
||||
class _RedditPushshiftScraper(snscrape.base.Scraper):
|
||||
def __init__(self, name, *, submissions = True, comments = True, before = None, after = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._name = name
|
||||
self._submissions = submissions
|
||||
self._comments = comments
|
||||
self._before = before
|
||||
self._after = after
|
||||
def _cmp_id(id1, id2):
|
||||
'''Compare two Reddit IDs. Returns -1 if id1 is less than id2, 0 if they are equal, and 1 if id1 is greater than id2.
|
||||
|
||||
if not type(self)._validationFunc(self._name):
|
||||
raise ValueError(f'invalid {type(self).name.split("-", 1)[1]} name')
|
||||
if not self._submissions and not self._comments:
|
||||
raise ValueError('At least one of submissions and comments must be True')
|
||||
id1 and id2 may have prefixes like t1_, but if included, they must be present on both and equal.'''
|
||||
|
||||
if id1.startswith('t') and '_' in id1:
|
||||
prefix, id1 = id1.split('_', 1)
|
||||
if not id2.startswith(f'{prefix}_'):
|
||||
raise ValueError('id2 must have the same prefix as id1')
|
||||
_, id2 = id2.split('_', 1)
|
||||
if id1.strip(string.ascii_lowercase + string.digits) != '':
|
||||
raise ValueError('invalid characters in id1')
|
||||
if id2.strip(string.ascii_lowercase + string.digits) != '':
|
||||
raise ValueError('invalid characters in id2')
|
||||
if len(id1) < len(id2):
|
||||
return -1
|
||||
if len(id1) > len(id2):
|
||||
return 1
|
||||
if id1 < id2:
|
||||
return -1
|
||||
if id1 > id2:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
class _RedditPushshiftScraper(snscrape.base.Scraper):
|
||||
def __init__(self, **kwargs):
    '''Initialise the base scraper and prepare the request headers.

    All keyword arguments are forwarded unchanged to the parent class.'''

    super().__init__(**kwargs)
    # Stored once on the instance so request helpers can reuse the same User-Agent.
    userAgent = f'snscrape/{snscrape.version.__version__}'
    self._headers = {'User-Agent': userAgent}
|
||||
|
||||
def _handle_rate_limiting(self, r):
|
||||
if r.status_code == 429:
|
||||
@@ -69,52 +85,11 @@ class _RedditPushshiftScraper(snscrape.base.Scraper):
|
||||
return False, 'non-200 status code'
|
||||
return True, None
|
||||
|
||||
def _cmp_id(self, id1, id2):
|
||||
'''Compare two Reddit IDs. Returns -1 if id1 is less than id2, 0 if they are equal, and 1 if id1 is greater than id2.
|
||||
|
||||
id1 and id2 may have prefixes like t1_, but if included, they must be present on both and equal.'''
|
||||
|
||||
if id1.startswith('t') and '_' in id1:
|
||||
prefix, id1 = id1.split('_', 1)
|
||||
if not id2.startswith(f'{prefix}_'):
|
||||
raise ValueError('id2 must have the same prefix as id1')
|
||||
_, id2 = id2.split('_', 1)
|
||||
if id1.strip(string.ascii_lowercase + string.digits) != '':
|
||||
raise ValueError('invalid characters in id1')
|
||||
if id2.strip(string.ascii_lowercase + string.digits) != '':
|
||||
raise ValueError('invalid characters in id2')
|
||||
if len(id1) < len(id2):
|
||||
return -1
|
||||
if len(id1) > len(id2):
|
||||
return 1
|
||||
if id1 < id2:
|
||||
return -1
|
||||
if id1 > id2:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def _iter_api(self, url, params = None):
    '''Iterate through the Pushshift API using the 'before' parameter and yield the items.

    url is requested repeatedly with sort=desc; self._before and self._after, when set, are sent as the 'before' and 'after' parameters. Each response's 'data' list is converted item by item with self._api_obj_to_item.

    Pagination ends when a page is empty or contains no ID lower than the lowest one already yielded.
    Raises snscrape.base.ScraperException if a response does not have status code 200.'''

    # Work on a private copy: 'before' is rewritten after every page and must not leak into the caller's dict.
    params = dict(params) if params is not None else {}
    if self._before is not None:
        params['before'] = self._before
    if self._after is not None:
        params['after'] = self._after
    params['sort'] = 'desc'

    # Constant for the whole iteration, so built once instead of per request.
    headers = {'User-Agent': f'snscrape/{snscrape.version.__version__}'}
    lowestIdSeen = None
    while True:
        r = self._get(url, params = params, headers = headers, responseOkCallback = self._handle_rate_limiting)
        if r.status_code != 200:
            raise snscrape.base.ScraperException(f'Got status code {r.status_code}')
        data = r.json()['data']
        if not data:
            break # end of pagination
        if lowestIdSeen is not None and all(self._cmp_id(d['id'], lowestIdSeen) >= 0 for d in data):
            break # nothing new on this page: end of pagination
        for d in data:
            # Only IDs below the lowest one seen are new; anything else was already yielded from an earlier page.
            if lowestIdSeen is None or self._cmp_id(d['id'], lowestIdSeen) == -1:
                yield self._api_obj_to_item(d)
                lowestIdSeen = d['id']
        # The +1 presumably re-requests items sharing the last timestamp; the ID check above filters the repeats.
        params['before'] = data[-1]['created_utc'] + 1
|
||||
def _get_api(self, url, params = None):
    '''Request url with the instance's headers and return the decoded JSON body.

    self._handle_rate_limiting is passed to self._get as the response-OK callback.
    Raises snscrape.base.ScraperException if the final response does not have status code 200.'''

    response = self._get(
        url,
        params = params,
        headers = self._headers,
        responseOkCallback = self._handle_rate_limiting,
    )
    if response.status_code == 200:
        return response.json()
    raise snscrape.base.ScraperException(f'Got status code {response.status_code}')
|
||||
|
||||
def _api_obj_to_item(self, d):
|
||||
cls = Submission if 'title' in d else Comment
|
||||
@@ -154,6 +129,41 @@ class _RedditPushshiftScraper(snscrape.base.Scraper):
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class _RedditPushshiftSearchScraper(_RedditPushshiftScraper):
|
||||
def __init__(self, name, *, submissions = True, comments = True, before = None, after = None, **kwargs):
    '''Store the search settings and validate them.

    name is checked with the class's _validationFunc; submissions and comments select which kinds of items are wanted; before and after are stored as given. Remaining keyword arguments go to the parent class.

    Raises ValueError if name fails validation or if both submissions and comments are false.'''

    super().__init__(**kwargs)
    self._name, self._submissions, self._comments = name, submissions, comments
    self._before, self._after = before, after

    # Looked up on the class so the plain function is called without binding self.
    validator = type(self)._validationFunc
    if not validator(self._name):
        raise ValueError(f'invalid {type(self).name.split("-", 1)[1]} name')
    if not (self._submissions or self._comments):
        raise ValueError('At least one of submissions and comments must be True')
|
||||
|
||||
def _iter_api(self, url, params = None):
    '''Iterate through the Pushshift API using the 'before' parameter and yield the items.

    url is fetched repeatedly through self._get_api with sort=desc; self._before and self._after, when set, are sent as the 'before' and 'after' parameters. Each response's 'data' list is converted item by item with self._api_obj_to_item.

    Pagination ends when a page is empty or contains no ID lower than the lowest one already yielded.'''

    # Work on a private copy: 'before' is rewritten after every page and must not leak into the caller's dict.
    params = dict(params) if params is not None else {}
    if self._before is not None:
        params['before'] = self._before
    if self._after is not None:
        params['after'] = self._after
    params['sort'] = 'desc'

    lowestIdSeen = None
    while True:
        data = self._get_api(url, params = params)['data']
        if not data:
            break # end of pagination
        if lowestIdSeen is not None and all(_cmp_id(d['id'], lowestIdSeen) >= 0 for d in data):
            break # nothing new on this page: end of pagination
        for d in data:
            # Only IDs below the lowest one seen are new; anything else was already yielded from an earlier page.
            if lowestIdSeen is None or _cmp_id(d['id'], lowestIdSeen) == -1:
                yield self._api_obj_to_item(d)
                lowestIdSeen = d['id']
        # The +1 presumably re-requests items sharing the last timestamp; the ID check above filters the repeats.
        params['before'] = data[-1]['created_utc'] + 1
|
||||
|
||||
def _iter_api_submissions_and_comments(self, params: dict):
|
||||
# Retrieve both submissions and comments, interleave the results to get a reverse-chronological order
|
||||
params['size'] = '1000'
|
||||
@@ -218,19 +228,19 @@ class _RedditPushshiftScraper(snscrape.base.Scraper):
|
||||
return cls._cli_construct(args, getattr(args, name), submissions = not args.noSubmissions, comments = not args.noComments, before = args.before, after = args.after)
|
||||
|
||||
|
||||
class RedditUserScraper(_RedditPushshiftScraper):
|
||||
class RedditUserScraper(_RedditPushshiftSearchScraper):
    '''Search scraper registered under the name 'reddit-user'.'''

    name = 'reddit-user'
    # Accepts 3 to 20 characters from A-Z, a-z, 0-9, '_' and '-'.
    # fullmatch is used instead of match with '^...$' because '$' also matches just before a trailing newline, which let such names pass.
    _validationFunc = lambda x: re.fullmatch('[A-Za-z0-9_-]{3,20}', x)
    # Presumably the API query field the name is sent as; confirm against the code that reads _apiField.
    _apiField = 'author'
|
||||
|
||||
|
||||
class RedditSubredditScraper(_RedditPushshiftScraper):
|
||||
class RedditSubredditScraper(_RedditPushshiftSearchScraper):
    '''Search scraper registered under the name 'reddit-subreddit'.'''

    name = 'reddit-subreddit'
    # Accepts an alphanumeric first character followed by 2 to 20 characters from A-Z, a-z, 0-9 and '_'.
    # fullmatch is used instead of match with '^...$' because '$' also matches just before a trailing newline, which let such names pass.
    _validationFunc = lambda x: re.fullmatch('[A-Za-z0-9][A-Za-z0-9_]{2,20}', x)
    # Presumably the API query field the name is sent as; confirm against the code that reads _apiField.
    _apiField = 'subreddit'
|
||||
|
||||
|
||||
class RedditSearchScraper(_RedditPushshiftScraper):
|
||||
class RedditSearchScraper(_RedditPushshiftSearchScraper):
    '''Search scraper registered under the name 'reddit-search'.'''

    # Presumably the API query field the search term is sent as; confirm against the code that reads _apiField.
    _apiField = 'q'
    # Every search term is accepted as-is.
    _validationFunc = lambda x: True
    name = 'reddit-search'
|
||||
|
||||
Reference in New Issue
Block a user