mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
Format and lint shared directory (#64)
This commit is contained in:
@@ -11,23 +11,39 @@ from app.shared.db import models
|
|||||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
|
||||||
db_urls = []
|
db_urls = []
|
||||||
for m in result.media:
|
for m in result.media:
|
||||||
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
|
for i, url in enumerate(m.urls):
|
||||||
|
db_urls.append(
|
||||||
|
models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
|
||||||
|
)
|
||||||
for k, prop in m.properties.items():
|
for k, prop in m.properties.items():
|
||||||
if prop_converted := convert_if_media(prop):
|
if prop_converted := convert_if_media(prop):
|
||||||
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
|
for i, url in enumerate(prop_converted.urls):
|
||||||
|
db_urls.append(
|
||||||
|
models.ArchiveUrl(
|
||||||
|
url=url, key=prop_converted.get("id", f"{k}_{i}")
|
||||||
|
)
|
||||||
|
)
|
||||||
if isinstance(prop, list):
|
if isinstance(prop, list):
|
||||||
for i, prop_media in enumerate(prop):
|
for i, prop_media in enumerate(prop):
|
||||||
if prop_media := convert_if_media(prop_media):
|
if prop_media := convert_if_media(prop_media):
|
||||||
for j, url in enumerate(prop_media.urls):
|
for j, url in enumerate(prop_media.urls):
|
||||||
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
|
db_urls.append(
|
||||||
|
models.ArchiveUrl(
|
||||||
|
url=url,
|
||||||
|
key=prop_media.get(
|
||||||
|
"id", f"{k}{prop_media.key}_{i}.{j}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
return db_urls
|
return db_urls
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_if_media(media):
|
def convert_if_media(media):
|
||||||
if isinstance(media, Media): return media
|
if isinstance(media, Media):
|
||||||
|
return media
|
||||||
elif isinstance(media, dict):
|
elif isinstance(media, dict):
|
||||||
try: return Media.from_dict(media)
|
try:
|
||||||
|
return Media.from_dict(media)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"error parsing {media} : {e}")
|
logger.debug(f"error parsing {media} : {e}")
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,26 +1,35 @@
|
|||||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
|
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.shared.db import worker_crud
|
from app.shared.db import worker_crud
|
||||||
|
|
||||||
|
|
||||||
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
|
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do
|
||||||
|
# decide
|
||||||
|
|
||||||
|
|
||||||
|
def get_store_archive_until(
|
||||||
|
db: Session, group_id: str
|
||||||
|
) -> Union[datetime.datetime, None]:
|
||||||
group = worker_crud.get_group(db, group_id)
|
group = worker_crud.get_group(db, group_id)
|
||||||
assert group, f"Group {group_id} not found."
|
assert group, f"Group {group_id} not found."
|
||||||
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions."
|
assert group.permissions and isinstance(group.permissions, dict), (
|
||||||
|
f"Group {group_id} has no permissions."
|
||||||
|
)
|
||||||
|
|
||||||
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
|
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
|
||||||
if max_lifespan == -1: return None
|
if max_lifespan == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
||||||
|
|
||||||
|
|
||||||
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime:
|
def get_store_archive_until_or_never(
|
||||||
|
db: Session, group_id: str
|
||||||
|
) -> Union[datetime.datetime, None]:
|
||||||
try:
|
try:
|
||||||
return get_store_archive_until(db, group_id)
|
return get_store_archive_until(db, group_id)
|
||||||
except AssertionError as e:
|
except AssertionError:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ def make_engine(database_url: str):
|
|||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
database_url,
|
database_url,
|
||||||
connect_args={"check_same_thread": False},
|
connect_args={"check_same_thread": False},
|
||||||
pool_size=15, # Increase pool size
|
pool_size=15, # Increase pool size
|
||||||
max_overflow=20, # Allow more temporary connections
|
max_overflow=20, # Allow more temporary connections
|
||||||
pool_recycle=1800 # Recycle connections every 30 minutes
|
pool_recycle=1800, # Recycle connections every 30 minutes
|
||||||
)
|
)
|
||||||
|
|
||||||
@event.listens_for(engine, "connect")
|
@event.listens_for(engine, "connect")
|
||||||
@@ -40,8 +40,10 @@ def make_session_local(engine: Engine):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db():
|
def get_db():
|
||||||
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
|
||||||
try: yield session
|
try:
|
||||||
finally: session.close()
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
def get_db_dependency():
|
def get_db_dependency():
|
||||||
@@ -59,22 +61,32 @@ def wal_checkpoint():
|
|||||||
|
|
||||||
# ASYNC connections
|
# ASYNC connections
|
||||||
async def make_async_engine(database_url: str) -> AsyncEngine:
|
async def make_async_engine(database_url: str) -> AsyncEngine:
|
||||||
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
|
engine = create_async_engine(
|
||||||
|
database_url, connect_args={"check_same_thread": False}
|
||||||
|
)
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
|
await conn.run_sync(
|
||||||
|
lambda sync_conn: sync_conn.execute(
|
||||||
|
text("PRAGMA journal_mode=WAL;")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
|
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
|
||||||
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
|
return async_sessionmaker(
|
||||||
|
engine, expire_on_commit=False, autoflush=False, autocommit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_db_async():
|
async def get_db_async():
|
||||||
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
|
engine = await make_async_engine(get_settings().async_database_path)
|
||||||
async_session = await make_async_session_local(engine)
|
async_session = await make_async_session_local(engine)
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
try: yield session
|
try:
|
||||||
finally: await engine.dispose()
|
yield session
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
|||||||
@@ -42,7 +42,9 @@ class Archive(Base):
|
|||||||
id = Column(String, primary_key=True, index=True)
|
id = Column(String, primary_key=True, index=True)
|
||||||
url = Column(String, index=True)
|
url = Column(String, index=True)
|
||||||
result = Column(JSON, default=None)
|
result = Column(JSON, default=None)
|
||||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
public = Column(
|
||||||
|
Boolean, default=True
|
||||||
|
) # if public=false, access by group and author
|
||||||
deleted = Column(Boolean, default=False)
|
deleted = Column(Boolean, default=False)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
@@ -52,7 +54,11 @@ class Archive(Base):
|
|||||||
author_id = Column(String, ForeignKey("users.email"))
|
author_id = Column(String, ForeignKey("users.email"))
|
||||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||||
|
|
||||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
tags = relationship(
|
||||||
|
"Tag",
|
||||||
|
back_populates="archives",
|
||||||
|
secondary=association_table_archive_tags,
|
||||||
|
)
|
||||||
group = relationship("Group", back_populates="archives")
|
group = relationship("Group", back_populates="archives")
|
||||||
author = relationship("User", back_populates="archives")
|
author = relationship("User", back_populates="archives")
|
||||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||||
@@ -75,7 +81,11 @@ class Tag(Base):
|
|||||||
id = Column(String, primary_key=True, index=True)
|
id = Column(String, primary_key=True, index=True)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
archives = relationship(
|
||||||
|
"Archive",
|
||||||
|
back_populates="tags",
|
||||||
|
secondary=association_table_archive_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
@@ -85,7 +95,9 @@ class User(Base):
|
|||||||
|
|
||||||
archives = relationship("Archive", back_populates="author")
|
archives = relationship("Archive", back_populates="author")
|
||||||
sheets = relationship("Sheet", back_populates="author")
|
sheets = relationship("Sheet", back_populates="author")
|
||||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
groups = relationship(
|
||||||
|
"Group", back_populates="users", secondary=association_table_user_groups
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Group(Base):
|
class Group(Base):
|
||||||
@@ -101,7 +113,9 @@ class Group(Base):
|
|||||||
|
|
||||||
archives = relationship("Archive", back_populates="group")
|
archives = relationship("Archive", back_populates="group")
|
||||||
sheets = relationship("Sheet", back_populates="group")
|
sheets = relationship("Sheet", back_populates="group")
|
||||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
users = relationship(
|
||||||
|
"User", back_populates="groups", secondary=association_table_user_groups
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Sheet(Base):
|
class Sheet(Base):
|
||||||
@@ -110,11 +124,27 @@ class Sheet(Base):
|
|||||||
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
||||||
name = Column(String, default=None)
|
name = Column(String, default=None)
|
||||||
author_id = Column(String, ForeignKey("users.email"))
|
author_id = Column(String, ForeignKey("users.email"))
|
||||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
group_id = Column(
|
||||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
String,
|
||||||
|
ForeignKey("groups.id"),
|
||||||
|
doc="Group ID, user must be in a group to create a sheet.",
|
||||||
|
)
|
||||||
|
frequency = Column(
|
||||||
|
String,
|
||||||
|
default="daily",
|
||||||
|
doc="Frequency of archiving: hourly, daily, weekly.",
|
||||||
|
)
|
||||||
# TODO: stats is not being used, consider removing
|
# TODO: stats is not being used, consider removing
|
||||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
stats = Column(
|
||||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
JSON,
|
||||||
|
default={},
|
||||||
|
doc="Sheet statistics like total links, total rows, ...",
|
||||||
|
)
|
||||||
|
last_url_archived_at = Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
doc="Last time a new link was archived.",
|
||||||
|
)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ from app.shared.db import models
|
|||||||
# TODO: isolate database operations away from worker and into WEB
|
# TODO: isolate database operations away from worker and into WEB
|
||||||
# ONLY WORKER
|
# ONLY WORKER
|
||||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
db_sheet = (
|
||||||
|
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||||
|
)
|
||||||
if db_sheet:
|
if db_sheet:
|
||||||
db_sheet.last_url_archived_at = datetime.now()
|
db_sheet.last_url_archived_at = datetime.now()
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -19,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
|||||||
|
|
||||||
# ONLY WORKER and INTEROP
|
# ONLY WORKER and INTEROP
|
||||||
|
|
||||||
|
|
||||||
def get_group(db: Session, group_name: str) -> models.Group:
|
def get_group(db: Session, group_name: str) -> models.Group:
|
||||||
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||||
|
|
||||||
|
|
||||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
||||||
if type(author_id) == str: author_id = author_id.lower()
|
if isinstance(author_id, str):
|
||||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
author_id = author_id.lower()
|
||||||
|
db_user = (
|
||||||
|
db.query(models.User).filter(models.User.email == author_id).first()
|
||||||
|
)
|
||||||
if not db_user:
|
if not db_user:
|
||||||
db_user = models.User(email=author_id)
|
db_user = models.User(email=author_id)
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
@@ -43,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
|
|||||||
return db_tag
|
return db_tag
|
||||||
|
|
||||||
|
|
||||||
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
|
def create_archive(
|
||||||
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
|
db: Session,
|
||||||
|
archive: schemas.ArchiveCreate,
|
||||||
|
tags: list[models.Tag],
|
||||||
|
urls: list[models.ArchiveUrl],
|
||||||
|
) -> models.Archive:
|
||||||
|
db_archive = models.Archive(
|
||||||
|
id=archive.id,
|
||||||
|
url=archive.url,
|
||||||
|
result=archive.result,
|
||||||
|
public=archive.public,
|
||||||
|
author_id=archive.author_id,
|
||||||
|
group_id=archive.group_id,
|
||||||
|
sheet_id=archive.sheet_id,
|
||||||
|
store_until=archive.store_until,
|
||||||
|
)
|
||||||
db_archive.tags = tags
|
db_archive.tags = tags
|
||||||
db_archive.urls = urls
|
db_archive.urls = urls
|
||||||
db.add(db_archive)
|
db.add(db_archive)
|
||||||
@@ -53,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
|
|||||||
return db_archive
|
return db_archive
|
||||||
|
|
||||||
|
|
||||||
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
|
def store_archived_url(
|
||||||
|
db: Session, archive: schemas.ArchiveCreate
|
||||||
|
) -> models.Archive:
|
||||||
# create and load user, tags, if needed
|
# create and load user, tags, if needed
|
||||||
create_or_get_user(db, archive.author_id)
|
create_or_get_user(db, archive.author_id)
|
||||||
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
|
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
|
||||||
# insert everything
|
# insert everything
|
||||||
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
|
db_archive = create_archive(
|
||||||
|
db, archive=archive, tags=db_tags, urls=archive.urls
|
||||||
|
)
|
||||||
return db_archive
|
return db_archive
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ logger.add("logs/api_logs.log", retention="30 days")
|
|||||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||||
|
|
||||||
|
|
||||||
def log_error(e: Exception, traceback_str: str = None, extra:str = ""):
|
def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
|
||||||
if not traceback_str: traceback_str = traceback.format_exc()
|
if not traceback_str:
|
||||||
if extra: extra = f"{extra}\n"
|
traceback_str = traceback.format_exc()
|
||||||
|
if extra:
|
||||||
|
extra = f"{extra}\n"
|
||||||
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ class SubmitSheet(BaseModel):
|
|||||||
group_id: str = "default"
|
group_id: str = "default"
|
||||||
tags: set[str] | None = set()
|
tags: set[str] | None = set()
|
||||||
|
|
||||||
|
|
||||||
class ArchiveUrl(BaseModel):
|
class ArchiveUrl(BaseModel):
|
||||||
url: str
|
url: str
|
||||||
public: bool = False
|
public: bool = False
|
||||||
@@ -18,6 +19,7 @@ class ArchiveUrl(BaseModel):
|
|||||||
group_id: str | None
|
group_id: str | None
|
||||||
tags: set[str] | None = set()
|
tags: set[str] | None = set()
|
||||||
|
|
||||||
|
|
||||||
class ArchiveResult(BaseModel):
|
class ArchiveResult(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
url: str
|
url: str
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Annotated, Set
|
from typing import Annotated, Set
|
||||||
@@ -9,32 +8,40 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=os.environ.get("ENVIRONMENT_FILE"),
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
str_strip_whitespace=True,
|
||||||
|
)
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
|
# general
|
||||||
|
|
||||||
# general
|
|
||||||
SERVE_LOCAL_ARCHIVE: str | None = None
|
SERVE_LOCAL_ARCHIVE: str | None = None
|
||||||
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
|
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
|
||||||
|
|
||||||
# database
|
# database
|
||||||
DATABASE_PATH: str
|
DATABASE_PATH: str
|
||||||
DATABASE_QUERY_LIMIT: int = 100
|
DATABASE_QUERY_LIMIT: int = 100
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ASYNC_DATABASE_PATH(self) -> str:
|
def async_database_path(self) -> str:
|
||||||
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
|
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
|
||||||
|
|
||||||
# security
|
# security
|
||||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||||
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
|
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
|
||||||
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
CHROME_APP_IDS: Annotated[
|
||||||
|
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
|
||||||
|
]
|
||||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||||
|
|
||||||
# redis
|
# redis
|
||||||
REDIS_PASSWORD: str = ""
|
REDIS_PASSWORD: str = ""
|
||||||
REDIS_HOSTNAME: str = "localhost"
|
REDIS_HOSTNAME: str = "localhost"
|
||||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def CELERY_BROKER_URL(self)-> str:
|
def celery_broker_url(self) -> str:
|
||||||
if self.REDIS_PASSWORD:
|
if self.REDIS_PASSWORD:
|
||||||
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
|
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
|
||||||
return f"redis://{self.REDIS_HOSTNAME}:6379"
|
return f"redis://{self.REDIS_HOSTNAME}:6379"
|
||||||
@@ -46,7 +53,7 @@ class Settings(BaseSettings):
|
|||||||
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
|
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
|
||||||
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
|
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
|
||||||
|
|
||||||
# observability
|
# observability
|
||||||
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
||||||
|
|
||||||
# email configuration, if needed
|
# email configuration, if needed
|
||||||
@@ -58,8 +65,9 @@ class Settings(BaseSettings):
|
|||||||
MAIL_PORT: int = 587
|
MAIL_PORT: int = 587
|
||||||
MAIL_STARTTLS: bool = False
|
MAIL_STARTTLS: bool = False
|
||||||
MAIL_SSL_TLS: bool = True
|
MAIL_SSL_TLS: bool = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def MAIL_CONFIG(self) -> str:
|
def mail_config(self) -> ConnectionConfig:
|
||||||
return ConnectionConfig(
|
return ConnectionConfig(
|
||||||
MAIL_FROM=self.MAIL_FROM,
|
MAIL_FROM=self.MAIL_FROM,
|
||||||
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
|
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
@@ -11,14 +10,14 @@ from app.shared.settings import get_settings
|
|||||||
def get_celery(name: str = "") -> Celery:
|
def get_celery(name: str = "") -> Celery:
|
||||||
return Celery(
|
return Celery(
|
||||||
name,
|
name,
|
||||||
broker_url=get_settings().CELERY_BROKER_URL,
|
broker_url=get_settings().celery_broker_url,
|
||||||
result_backend=get_settings().CELERY_BROKER_URL,
|
result_backend=get_settings().celery_broker_url,
|
||||||
broker_connection_retry_on_startup=False,
|
broker_connection_retry_on_startup=False,
|
||||||
broker_transport_options={
|
broker_transport_options={
|
||||||
'queue_order_strategy': 'priority',
|
"queue_order_strategy": "priority",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_redis() -> redis.Redis:
|
def get_redis() -> redis.Redis:
|
||||||
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
|
return redis.Redis.from_url(get_settings().celery_broker_url)
|
||||||
|
|||||||
@@ -19,13 +19,16 @@ class UserGroups:
|
|||||||
user_groups = self.read_yaml(filename)
|
user_groups = self.read_yaml(filename)
|
||||||
self.validate_and_load(user_groups)
|
self.validate_and_load(user_groups)
|
||||||
|
|
||||||
def read_yaml(self, user_groups_filename):
|
@staticmethod
|
||||||
|
def read_yaml(user_groups_filename):
|
||||||
# read yaml safely
|
# read yaml safely
|
||||||
with open(user_groups_filename) as inf:
|
with open(user_groups_filename) as inf:
|
||||||
try:
|
try:
|
||||||
return yaml.safe_load(inf)
|
return yaml.safe_load(inf)
|
||||||
except yaml.YAMLError as e:
|
except yaml.YAMLError as e:
|
||||||
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
|
logger.error(
|
||||||
|
f"could not open user groups filename {user_groups_filename}: {e}"
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def validate_and_load(self, user_groups):
|
def validate_and_load(self, user_groups):
|
||||||
@@ -52,22 +55,36 @@ class GroupPermissions(BaseModel):
|
|||||||
max_monthly_mbs: int = 0
|
max_monthly_mbs: int = 0
|
||||||
priority: str = "low"
|
priority: str = "low"
|
||||||
|
|
||||||
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
|
@classmethod
|
||||||
|
@field_validator(
|
||||||
|
"max_sheets",
|
||||||
|
"max_archive_lifespan_months",
|
||||||
|
"max_monthly_urls",
|
||||||
|
"max_monthly_mbs",
|
||||||
|
mode="before",
|
||||||
|
)
|
||||||
def validate_max_values(cls, v):
|
def validate_max_values(cls, v):
|
||||||
if v < -1:
|
if v < -1:
|
||||||
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
|
raise ValueError(
|
||||||
|
"max_* values should be positive integers or -1 (for no limit)."
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('sheet_frequency', mode='before')
|
@classmethod
|
||||||
|
@field_validator("sheet_frequency", mode="before")
|
||||||
def validate_sheet_frequency(cls, v):
|
def validate_sheet_frequency(cls, v):
|
||||||
if not v: return []
|
if not v:
|
||||||
|
return []
|
||||||
allowed = ["daily", "hourly"]
|
allowed = ["daily", "hourly"]
|
||||||
for k in v:
|
for k in v:
|
||||||
if k not in allowed:
|
if k not in allowed:
|
||||||
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
|
raise ValueError(
|
||||||
|
f"Invalid sheet frequency: '{k}', expected one of {allowed}"
|
||||||
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('priority', mode='before')
|
@classmethod
|
||||||
|
@field_validator("priority", mode="before")
|
||||||
def validate_priority(cls, v):
|
def validate_priority(cls, v):
|
||||||
v = v.lower()
|
v = v.lower()
|
||||||
if v not in ["low", "high"]:
|
if v not in ["low", "high"]:
|
||||||
@@ -81,7 +98,8 @@ class GroupModel(BaseModel):
|
|||||||
orchestrator_sheet: str
|
orchestrator_sheet: str
|
||||||
permissions: GroupPermissions
|
permissions: GroupPermissions
|
||||||
|
|
||||||
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
|
@classmethod
|
||||||
|
@field_validator("orchestrator", "orchestrator_sheet", mode="before")
|
||||||
def validate_orchestrator(cls, v):
|
def validate_orchestrator(cls, v):
|
||||||
if not os.path.exists(v):
|
if not os.path.exists(v):
|
||||||
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
||||||
@@ -105,13 +123,17 @@ class GroupModel(BaseModel):
|
|||||||
|
|
||||||
service_account_json = find_service_account_email(orch)
|
service_account_json = find_service_account_email(orch)
|
||||||
if not service_account_json:
|
if not service_account_json:
|
||||||
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
|
raise ValueError(
|
||||||
|
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
|
||||||
|
)
|
||||||
|
|
||||||
with open(service_account_json) as f:
|
with open(service_account_json) as f:
|
||||||
self._service_account_email = json.load(f).get("client_email")
|
self._service_account_email = json.load(f).get("client_email")
|
||||||
|
|
||||||
if not self._service_account_email:
|
if not self._service_account_email:
|
||||||
raise ValueError(f"Service account email not found in {service_account_json}.")
|
raise ValueError(
|
||||||
|
f"Service account email not found in {service_account_json}."
|
||||||
|
)
|
||||||
|
|
||||||
return self._service_account_email
|
return self._service_account_email
|
||||||
|
|
||||||
@@ -121,29 +143,45 @@ class UserGroupModel(BaseModel):
|
|||||||
domains: Dict[str, List[str]] = Field(default_factory=dict)
|
domains: Dict[str, List[str]] = Field(default_factory=dict)
|
||||||
groups: Dict[str, GroupModel] = Field(default_factory=dict)
|
groups: Dict[str, GroupModel] = Field(default_factory=dict)
|
||||||
|
|
||||||
@field_validator('users', mode='before')
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@field_validator("users", mode="before")
|
||||||
def validate_emails(cls, v):
|
def validate_emails(cls, v):
|
||||||
for email in v.keys():
|
for email in v.keys():
|
||||||
if '@' not in email:
|
if "@" not in email:
|
||||||
raise ValueError(f"Invalid user, it should be an address: {email}")
|
raise ValueError(
|
||||||
|
f"Invalid user, it should be an address: {email}"
|
||||||
|
)
|
||||||
if not v[email]:
|
if not v[email]:
|
||||||
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.")
|
raise ValueError(
|
||||||
|
f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
|
||||||
|
)
|
||||||
# all users belong to the default group
|
# all users belong to the default group
|
||||||
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()}
|
return {
|
||||||
|
k.lower().strip(): list(
|
||||||
|
set(["default"] + [g.lower().strip() for g in v])
|
||||||
|
)
|
||||||
|
for k, v in v.items()
|
||||||
|
}
|
||||||
|
|
||||||
@field_validator('domains', mode='before')
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@field_validator("domains", mode="before")
|
||||||
def validate_domains(cls, v):
|
def validate_domains(cls, v):
|
||||||
for domain, members in v.items():
|
for domain, members in v.items():
|
||||||
if '.' not in domain:
|
if "." not in domain:
|
||||||
raise ValueError(f"Invalid domain, it should contain a dot: {domain}")
|
raise ValueError(
|
||||||
|
f"Invalid domain, it should contain a dot: {domain}"
|
||||||
|
)
|
||||||
if not members:
|
if not members:
|
||||||
raise ValueError(f"Domain {domain} should have at least one member.")
|
raise ValueError(
|
||||||
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()}
|
f"Domain {domain} should have at least one member."
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
k.lower().strip(): list({[g.lower().strip() for g in v]})
|
||||||
|
for k, v in v.items()
|
||||||
|
}
|
||||||
|
|
||||||
@field_validator('groups', mode='before')
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@field_validator("groups", mode="before")
|
||||||
def validate_groups(cls, v):
|
def validate_groups(cls, v):
|
||||||
if "default" not in v.keys():
|
if "default" not in v.keys():
|
||||||
raise ValueError("Please include a 'default' group.")
|
raise ValueError("Please include a 'default' group.")
|
||||||
@@ -154,20 +192,28 @@ class UserGroupModel(BaseModel):
|
|||||||
raise ValueError(f"Group names should be lowercase: {group}")
|
raise ValueError(f"Group names should be lowercase: {group}")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@model_validator(mode='after')
|
@model_validator(mode="after")
|
||||||
def check_groups_consistency(self) -> Self:
|
def check_groups_consistency(self) -> Self:
|
||||||
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
|
groups_in_domains = {
|
||||||
groups_in_users = set([g for user in self.users for g in self.users[user]])
|
g for domain in self.domains for g in self.domains[domain]
|
||||||
|
}
|
||||||
|
groups_in_users = {g for user in self.users for g in self.users[user]}
|
||||||
configured_groups = set(self.groups.keys())
|
configured_groups = set(self.groups.keys())
|
||||||
|
|
||||||
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
|
# groups mentioned in domains and users should be defined, but this is
|
||||||
|
# not a ValueError since historical DB data may require it
|
||||||
if groups_in_domains - configured_groups:
|
if groups_in_domains - configured_groups:
|
||||||
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
|
logger.warning(
|
||||||
|
f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
|
||||||
|
)
|
||||||
if groups_in_users - configured_groups:
|
if groups_in_users - configured_groups:
|
||||||
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
|
logger.warning(
|
||||||
|
f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# for the API return values
|
# for the API return values
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
|
def fnv1a_hash_mod(s: str, modulo: int) -> int:
|
||||||
def fnv1a_hash_mod(s: str, modulo:int) -> int:
|
# receives a string and returns a number in [0:modulo-1], ensures an even
|
||||||
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range
|
# distribution over the modulo range
|
||||||
hash = 0x811c9dc5 # FNV offset basis
|
offset_basis_hash = 0x811C9DC5 # FNV offset basis
|
||||||
fnv_prime = 0x01000193 # FNV prime
|
fnv_prime = 0x01000193 # FNV prime
|
||||||
for char in s:
|
for char in s:
|
||||||
hash ^= ord(char)
|
offset_basis_hash ^= ord(char)
|
||||||
hash *= fnv_prime
|
offset_basis_hash *= fnv_prime
|
||||||
hash &= 0xFFFFFFFF # Keep it 32-bit
|
offset_basis_hash &= 0xFFFFFFFF # Keep it 32-bit
|
||||||
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo
|
return (
|
||||||
|
offset_basis_hash
|
||||||
|
if offset_basis_hash < 0x80000000
|
||||||
|
else offset_basis_hash - 0x100000000
|
||||||
|
) % modulo
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ def db_session(test_db):
|
|||||||
@pytest_asyncio.fixture()
|
@pytest_asyncio.fixture()
|
||||||
async def async_test_db(get_settings: Settings):
|
async def async_test_db(get_settings: Settings):
|
||||||
get_user_group_names.cache_clear()
|
get_user_group_names.cache_clear()
|
||||||
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
|
engine = await make_async_engine(get_settings.async_database_path)
|
||||||
|
|
||||||
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
|
fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
|
||||||
if not os.path.exists(fs):
|
if not os.path.exists(fs):
|
||||||
open(fs, "w").close()
|
open(fs, "w").close()
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ async def notify_about_expired_archives():
|
|||||||
user_archives[archive.author_id].append(archive)
|
user_archives[archive.author_id].append(archive)
|
||||||
|
|
||||||
if user_archives:
|
if user_archives:
|
||||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
fastmail = FastMail(get_settings().mail_config)
|
||||||
# notify users
|
# notify users
|
||||||
for email in user_archives:
|
for email in user_archives:
|
||||||
list_of_archives = "\n".join(
|
list_of_archives = "\n".join(
|
||||||
@@ -224,7 +224,7 @@ async def delete_stale_sheets():
|
|||||||
if not user_sheets:
|
if not user_sheets:
|
||||||
return
|
return
|
||||||
|
|
||||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
fastmail = FastMail(get_settings().mail_config)
|
||||||
# notify users
|
# notify users
|
||||||
for email in user_sheets:
|
for email in user_sheets:
|
||||||
list_of_sheets = "\n".join(
|
list_of_sheets = "\n".join(
|
||||||
|
|||||||
Reference in New Issue
Block a user