mirror of
https://github.com/bellingcat/auto-archiver-api.git
synced 2026-06-08 03:28:35 +03:00
Format and lint shared directory (#64)
This commit is contained in:
@@ -11,23 +11,39 @@ from app.shared.db import models
|
||||
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
    """Flatten every URL referenced by ``result`` into ``ArchiveUrl`` rows.

    Three sources are walked for each entry of ``result.media``:

    1. the media's own ``urls``;
    2. any property value that ``convert_if_media`` accepts;
    3. any property value that is a list, whose items are each passed
       through ``convert_if_media``.

    The ``key`` of each row is the media's ``"id"`` (looked up with
    ``.get``) when present, otherwise a positional fallback built from the
    property name and the enumeration indices.

    Args:
        result: archiving result whose ``media`` entries are inspected.

    Returns:
        The list of ``models.ArchiveUrl`` objects, in traversal order.
    """
    db_urls = []
    for m in result.media:
        # 1. URLs attached directly to the media item.
        for i, url in enumerate(m.urls):
            db_urls.append(
                models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
            )

        for k, prop in m.properties.items():
            # 2. A property that is itself media (or a dict describing one).
            if prop_converted := convert_if_media(prop):
                for i, url in enumerate(prop_converted.urls):
                    db_urls.append(
                        models.ArchiveUrl(
                            url=url, key=prop_converted.get("id", f"{k}_{i}")
                        )
                    )

            # 3. A property holding a list of media-like items.
            if isinstance(prop, list):
                for i, item in enumerate(prop):
                    # Bind the converted value to its own name instead of
                    # rebinding the loop variable.
                    if nested := convert_if_media(item):
                        for j, url in enumerate(nested.urls):
                            db_urls.append(
                                models.ArchiveUrl(
                                    url=url,
                                    key=nested.get(
                                        "id", f"{k}{nested.key}_{i}.{j}"
                                    ),
                                )
                            )
    return db_urls
|
||||
|
||||
|
||||
|
||||
def convert_if_media(media):
    """Return ``media`` as a ``Media`` instance, or ``False`` if it is not one.

    * A ``Media`` instance is returned unchanged.
    * A ``dict`` is passed to ``Media.from_dict``; a failure there is logged
      at debug level and treated as "not media".
    * Anything else is "not media".

    Callers use the result in a truthiness test, so every failure path
    returns ``False``.
    """
    if isinstance(media, Media):
        return media
    if isinstance(media, dict):
        try:
            return Media.from_dict(media)
        except Exception as e:
            # Best-effort conversion: arbitrary property dicts are expected
            # here, so a parse failure is not an error condition.
            logger.debug(f"error parsing {media} : {e}")
    return False
|
||||
|
||||
@@ -1,26 +1,35 @@
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
|
||||
|
||||
|
||||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.shared.db import worker_crud
|
||||
|
||||
|
||||
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do
# decide


def get_store_archive_until(
    db: Session, group_id: str
) -> Union[datetime.datetime, None]:
    """Compute the expiry timestamp for an archive created now in a group.

    The group is loaded with ``worker_crud.get_group`` and its
    ``permissions["max_archive_lifespan_months"]`` is read (default ``-1``).

    Args:
        db: database session used to look up the group.
        group_id: identifier of the group owning the archive.

    Returns:
        ``None`` when the lifespan is ``-1`` (no limit); otherwise
        ``datetime.datetime.now()`` plus ``30 * max_lifespan`` days.

    Raises:
        AssertionError: if the group does not exist, or its ``permissions``
            is empty or not a ``dict``.
    """
    group = worker_crud.get_group(db, group_id)

    # Raised explicitly rather than via ``assert`` so the checks are not
    # stripped when Python runs with -O; the exception type and messages
    # are unchanged for callers that catch AssertionError.
    if not group:
        raise AssertionError(f"Group {group_id} not found.")
    if not (group.permissions and isinstance(group.permissions, dict)):
        raise AssertionError(f"Group {group_id} has no permissions.")

    max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
    if max_lifespan == -1:
        return None

    # A month is approximated as 30 days.
    return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
|
||||
|
||||
|
||||
def get_store_archive_until_or_never(
    db: Session, group_id: str
) -> Union[datetime.datetime, None]:
    """Like ``get_store_archive_until`` but never raises on a bad group.

    Returns:
        The expiry timestamp, or ``None`` both when the group has no
        lifespan limit and when ``get_store_archive_until`` raises
        ``AssertionError`` (missing group or missing permissions).
    """
    try:
        return get_store_archive_until(db, group_id)
    except AssertionError:
        # An unknown or unconfigured group is treated as "store forever".
        return None
|
||||
|
||||
@@ -18,9 +18,9 @@ def make_engine(database_url: str):
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_size=15, # Increase pool size
|
||||
max_overflow=20, # Allow more temporary connections
|
||||
pool_recycle=1800 # Recycle connections every 30 minutes
|
||||
pool_size=15, # Increase pool size
|
||||
max_overflow=20, # Allow more temporary connections
|
||||
pool_recycle=1800, # Recycle connections every 30 minutes
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
@@ -40,8 +40,10 @@ def make_session_local(engine: Engine):
|
||||
@contextmanager
def get_db():
    """Context manager yielding a synchronous database session.

    On every entry it calls ``make_engine`` with
    ``get_settings().DATABASE_PATH``, wraps the engine with
    ``make_session_local`` and instantiates a session. The session is
    closed on exit, whether or not the managed block raised.
    """
    session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
def get_db_dependency():
|
||||
@@ -59,22 +61,32 @@ def wal_checkpoint():
|
||||
|
||||
# ASYNC connections
|
||||
async def make_async_engine(database_url: str) -> AsyncEngine:
    """Create an async engine for ``database_url`` and switch it to WAL mode.

    The engine is created with ``check_same_thread=False`` and, before it
    is returned, ``PRAGMA journal_mode=WAL;`` is executed once inside a
    transaction.

    Args:
        database_url: SQLAlchemy URL for the async driver.

    Returns:
        The configured ``AsyncEngine``.
    """
    engine = create_async_engine(
        database_url, connect_args={"check_same_thread": False}
    )

    async with engine.begin() as conn:
        # run_sync hands the underlying synchronous connection to the
        # callable, which issues the PRAGMA.
        await conn.run_sync(
            lambda sync_conn: sync_conn.execute(
                text("PRAGMA journal_mode=WAL;")
            )
        )

    return engine
|
||||
|
||||
|
||||
async def make_async_session_local(
    engine: AsyncEngine,
) -> "async_sessionmaker[AsyncSession]":
    """Build a session factory bound to ``engine``.

    The return value is the ``async_sessionmaker`` itself, not a session;
    call it to obtain an ``AsyncSession``. Sessions it produces have
    ``expire_on_commit``, ``autoflush`` and ``autocommit`` all disabled.

    The function performs no awaiting; it stays ``async`` so existing
    callers that ``await`` it keep working.
    """
    return async_sessionmaker(
        engine, expire_on_commit=False, autoflush=False, autocommit=False
    )
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_async():
    """Async context manager yielding an ``AsyncSession``.

    A fresh engine is created from ``get_settings().async_database_path``
    on every entry, and a session is opened from a factory bound to it.
    The engine is disposed when the managed block finishes, whether or not
    it raised.
    """
    engine = await make_async_engine(get_settings().async_database_path)
    async_session = await make_async_session_local(engine)
    async with async_session() as session:
        try:
            yield session
        finally:
            # The engine was created for this context only, so release its
            # connections here.
            await engine.dispose()
|
||||
|
||||
@@ -42,7 +42,9 @@ class Archive(Base):
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
url = Column(String, index=True)
|
||||
result = Column(JSON, default=None)
|
||||
public = Column(Boolean, default=True) # if public=false, access by group and author
|
||||
public = Column(
|
||||
Boolean, default=True
|
||||
) # if public=false, access by group and author
|
||||
deleted = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
@@ -52,7 +54,11 @@ class Archive(Base):
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
|
||||
|
||||
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
back_populates="archives",
|
||||
secondary=association_table_archive_tags,
|
||||
)
|
||||
group = relationship("Group", back_populates="archives")
|
||||
author = relationship("User", back_populates="archives")
|
||||
urls = relationship("ArchiveUrl", back_populates="archive")
|
||||
@@ -75,7 +81,11 @@ class Tag(Base):
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
|
||||
archives = relationship(
|
||||
"Archive",
|
||||
back_populates="tags",
|
||||
secondary=association_table_archive_tags,
|
||||
)
|
||||
|
||||
|
||||
class User(Base):
|
||||
@@ -85,7 +95,9 @@ class User(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="author")
|
||||
sheets = relationship("Sheet", back_populates="author")
|
||||
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
|
||||
groups = relationship(
|
||||
"Group", back_populates="users", secondary=association_table_user_groups
|
||||
)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
@@ -101,7 +113,9 @@ class Group(Base):
|
||||
|
||||
archives = relationship("Archive", back_populates="group")
|
||||
sheets = relationship("Sheet", back_populates="group")
|
||||
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
|
||||
users = relationship(
|
||||
"User", back_populates="groups", secondary=association_table_user_groups
|
||||
)
|
||||
|
||||
|
||||
class Sheet(Base):
|
||||
@@ -110,11 +124,27 @@ class Sheet(Base):
|
||||
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
|
||||
name = Column(String, default=None)
|
||||
author_id = Column(String, ForeignKey("users.email"))
|
||||
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
|
||||
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
|
||||
group_id = Column(
|
||||
String,
|
||||
ForeignKey("groups.id"),
|
||||
doc="Group ID, user must be in a group to create a sheet.",
|
||||
)
|
||||
frequency = Column(
|
||||
String,
|
||||
default="daily",
|
||||
doc="Frequency of archiving: hourly, daily, weekly.",
|
||||
)
|
||||
# TODO: stats is not being used, consider removing
|
||||
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
|
||||
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
|
||||
stats = Column(
|
||||
JSON,
|
||||
default={},
|
||||
doc="Sheet statistics like total links, total rows, ...",
|
||||
)
|
||||
last_url_archived_at = Column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
doc="Last time a new link was archived.",
|
||||
)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ from app.shared.db import models
|
||||
# TODO: isolate database operations away from worker and into WEB
|
||||
# ONLY WORKER
|
||||
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
db_sheet = (
|
||||
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
|
||||
)
|
||||
if db_sheet:
|
||||
db_sheet.last_url_archived_at = datetime.now()
|
||||
db.commit()
|
||||
@@ -19,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
|
||||
|
||||
# ONLY WORKER and INTEROP
|
||||
|
||||
|
||||
def get_group(db: Session, group_name: str) -> "models.Group | None":
    """Return the group whose ``id`` equals ``group_name``.

    Returns ``None`` when no such row exists, because the query ends in
    ``.first()``.
    """
    return db.query(models.Group).filter(models.Group.id == group_name).first()
|
||||
|
||||
|
||||
def create_or_get_user(db: Session, author_id: str) -> models.User:
|
||||
if type(author_id) == str: author_id = author_id.lower()
|
||||
db_user = db.query(models.User).filter(models.User.email == author_id).first()
|
||||
if isinstance(author_id, str):
|
||||
author_id = author_id.lower()
|
||||
db_user = (
|
||||
db.query(models.User).filter(models.User.email == author_id).first()
|
||||
)
|
||||
if not db_user:
|
||||
db_user = models.User(email=author_id)
|
||||
db.add(db_user)
|
||||
@@ -43,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
|
||||
return db_tag
|
||||
|
||||
|
||||
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
|
||||
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
|
||||
def create_archive(
|
||||
db: Session,
|
||||
archive: schemas.ArchiveCreate,
|
||||
tags: list[models.Tag],
|
||||
urls: list[models.ArchiveUrl],
|
||||
) -> models.Archive:
|
||||
db_archive = models.Archive(
|
||||
id=archive.id,
|
||||
url=archive.url,
|
||||
result=archive.result,
|
||||
public=archive.public,
|
||||
author_id=archive.author_id,
|
||||
group_id=archive.group_id,
|
||||
sheet_id=archive.sheet_id,
|
||||
store_until=archive.store_until,
|
||||
)
|
||||
db_archive.tags = tags
|
||||
db_archive.urls = urls
|
||||
db.add(db_archive)
|
||||
@@ -53,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
|
||||
return db_archive
|
||||
|
||||
|
||||
def store_archived_url(
    db: Session, archive: schemas.ArchiveCreate
) -> models.Archive:
    """Persist an archive together with its author, tags and URLs.

    Args:
        db: database session used for all operations.
        archive: the archive payload; ``archive.tags`` may be ``None``.

    Returns:
        The ``models.Archive`` returned by ``create_archive``.
    """
    # Create and load the user and tags, if needed.
    create_or_get_user(db, archive.author_id)
    db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]

    # Insert everything.
    return create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
|
||||
|
||||
@@ -8,7 +8,9 @@ logger.add("logs/api_logs.log", retention="30 days")
|
||||
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
|
||||
|
||||
|
||||
def log_error(e: Exception, traceback_str: str | None = None, extra: str = ""):
    """Log an exception, with its traceback, at ERROR level.

    Args:
        e: the exception to report; its class name and ``str(e)`` are logged.
        traceback_str: pre-formatted traceback. When empty or ``None``,
            ``traceback.format_exc()`` is used, which describes the
            exception currently being handled.
        extra: optional context, written on its own line before the
            exception.
    """
    if not traceback_str:
        traceback_str = traceback.format_exc()
    if extra:
        extra = f"{extra}\n"
    logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")
|
||||
|
||||
@@ -11,6 +11,7 @@ class SubmitSheet(BaseModel):
|
||||
group_id: str = "default"
|
||||
tags: set[str] | None = set()
|
||||
|
||||
|
||||
class ArchiveUrl(BaseModel):
|
||||
url: str
|
||||
public: bool = False
|
||||
@@ -18,6 +19,7 @@ class ArchiveUrl(BaseModel):
|
||||
group_id: str | None
|
||||
tags: set[str] | None = set()
|
||||
|
||||
|
||||
class ArchiveResult(BaseModel):
|
||||
id: str
|
||||
url: str
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Annotated, Set
|
||||
@@ -9,32 +8,40 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.environ.get("ENVIRONMENT_FILE"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
str_strip_whitespace=True,
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
|
||||
|
||||
# general
|
||||
# general
|
||||
SERVE_LOCAL_ARCHIVE: str | None = None
|
||||
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
|
||||
|
||||
# database
|
||||
# database
|
||||
DATABASE_PATH: str
|
||||
DATABASE_QUERY_LIMIT: int = 100
|
||||
|
||||
@property
|
||||
def ASYNC_DATABASE_PATH(self) -> str:
|
||||
def async_database_path(self) -> str:
|
||||
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
|
||||
# security
|
||||
# security
|
||||
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
|
||||
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
|
||||
CHROME_APP_IDS: Annotated[
|
||||
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
|
||||
]
|
||||
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
|
||||
|
||||
# redis
|
||||
REDIS_PASSWORD: str = ""
|
||||
REDIS_HOSTNAME: str = "localhost"
|
||||
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
|
||||
|
||||
@property
|
||||
def CELERY_BROKER_URL(self)-> str:
|
||||
def celery_broker_url(self) -> str:
|
||||
if self.REDIS_PASSWORD:
|
||||
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
|
||||
return f"redis://{self.REDIS_HOSTNAME}:6379"
|
||||
@@ -46,7 +53,7 @@ class Settings(BaseSettings):
|
||||
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
|
||||
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
|
||||
|
||||
# observability
|
||||
# observability
|
||||
REPEAT_COUNT_METRICS_SECONDS: int = 30
|
||||
|
||||
# email configuration, if needed
|
||||
@@ -58,8 +65,9 @@ class Settings(BaseSettings):
|
||||
MAIL_PORT: int = 587
|
||||
MAIL_STARTTLS: bool = False
|
||||
MAIL_SSL_TLS: bool = True
|
||||
|
||||
@property
|
||||
def MAIL_CONFIG(self) -> str:
|
||||
def mail_config(self) -> ConnectionConfig:
|
||||
return ConnectionConfig(
|
||||
MAIL_FROM=self.MAIL_FROM,
|
||||
MAIL_FROM_NAME=self.MAIL_FROM_NAME,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from celery import Celery
|
||||
@@ -11,14 +10,14 @@ from app.shared.settings import get_settings
|
||||
def get_celery(name: str = "") -> Celery:
    """Build a Celery application configured from the project settings.

    The same URL (``get_settings().celery_broker_url``) is used as both
    the broker and the result backend. Retrying the broker connection on
    startup is disabled, and the queue order strategy is ``"priority"``.

    Args:
        name: name passed to the ``Celery`` constructor.
    """
    # Read the URL once; broker and result backend share it.
    broker_url = get_settings().celery_broker_url
    return Celery(
        name,
        broker_url=broker_url,
        result_backend=broker_url,
        broker_connection_retry_on_startup=False,
        broker_transport_options={
            "queue_order_strategy": "priority",
        },
    )
|
||||
|
||||
|
||||
def get_redis() -> redis.Redis:
    """Return a Redis client for the URL in ``get_settings().celery_broker_url``."""
    return redis.Redis.from_url(get_settings().celery_broker_url)
|
||||
|
||||
@@ -19,13 +19,16 @@ class UserGroups:
|
||||
user_groups = self.read_yaml(filename)
|
||||
self.validate_and_load(user_groups)
|
||||
|
||||
@staticmethod
def read_yaml(user_groups_filename):
    """Load the user-groups YAML file with ``yaml.safe_load``.

    Args:
        user_groups_filename: path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.

    Raises:
        yaml.YAMLError: if the file is not valid YAML; the error is logged
            first and then re-raised.
        OSError: if the file cannot be opened (not caught here).
    """
    # read yaml safely
    with open(user_groups_filename) as inf:
        try:
            return yaml.safe_load(inf)
        except yaml.YAMLError as e:
            logger.error(
                f"could not open user groups filename {user_groups_filename}: {e}"
            )
            # Bare raise keeps the original traceback.
            raise
|
||||
|
||||
def validate_and_load(self, user_groups):
|
||||
@@ -52,22 +55,36 @@ class GroupPermissions(BaseModel):
|
||||
max_monthly_mbs: int = 0
|
||||
priority: str = "low"
|
||||
|
||||
# NOTE: @field_validator is outermost and @classmethod sits directly above
# the function, the order shown in the pydantic v2 documentation.
@field_validator(
    "max_sheets",
    "max_archive_lifespan_months",
    "max_monthly_urls",
    "max_monthly_mbs",
    mode="before",
)
@classmethod
def validate_max_values(cls, v):
    """Reject ``max_*`` limits below ``-1``.

    ``-1`` means "no limit". Any value ``>= -1`` is returned unchanged.

    Raises:
        ValueError: if ``v < -1``.
    """
    if v < -1:
        raise ValueError(
            "max_* values should be positive integers or -1 (for no limit)."
        )
    return v
|
||||
|
||||
# NOTE: @field_validator is outermost and @classmethod sits directly above
# the function, the order shown in the pydantic v2 documentation.
@field_validator("sheet_frequency", mode="before")
@classmethod
def validate_sheet_frequency(cls, v):
    """Validate the list of sheet frequencies a group may use.

    A falsy value (``None``, empty) is normalised to ``[]``. Otherwise
    every entry must be ``"daily"`` or ``"hourly"``.

    Raises:
        ValueError: on the first entry that is not an allowed frequency.
    """
    if not v:
        return []
    allowed = ["daily", "hourly"]
    for k in v:
        if k not in allowed:
            raise ValueError(
                f"Invalid sheet frequency: '{k}', expected one of {allowed}"
            )
    return v
|
||||
|
||||
@field_validator('priority', mode='before')
|
||||
@classmethod
|
||||
@field_validator("priority", mode="before")
|
||||
def validate_priority(cls, v):
|
||||
v = v.lower()
|
||||
if v not in ["low", "high"]:
|
||||
@@ -81,7 +98,8 @@ class GroupModel(BaseModel):
|
||||
orchestrator_sheet: str
|
||||
permissions: GroupPermissions
|
||||
|
||||
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
|
||||
@classmethod
|
||||
@field_validator("orchestrator", "orchestrator_sheet", mode="before")
|
||||
def validate_orchestrator(cls, v):
|
||||
if not os.path.exists(v):
|
||||
raise ValueError(f"Orchestrator file not found with this path: {v}")
|
||||
@@ -105,13 +123,17 @@ class GroupModel(BaseModel):
|
||||
|
||||
service_account_json = find_service_account_email(orch)
|
||||
if not service_account_json:
|
||||
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
|
||||
raise ValueError(
|
||||
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
|
||||
)
|
||||
|
||||
with open(service_account_json) as f:
|
||||
self._service_account_email = json.load(f).get("client_email")
|
||||
|
||||
if not self._service_account_email:
|
||||
raise ValueError(f"Service account email not found in {service_account_json}.")
|
||||
raise ValueError(
|
||||
f"Service account email not found in {service_account_json}."
|
||||
)
|
||||
|
||||
return self._service_account_email
|
||||
|
||||
@@ -121,29 +143,45 @@ class UserGroupModel(BaseModel):
|
||||
domains: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
groups: Dict[str, GroupModel] = Field(default_factory=dict)
|
||||
|
||||
# NOTE: @field_validator is outermost and @classmethod sits directly above
# the function, the order shown in the pydantic v2 documentation.
@field_validator("users", mode="before")
@classmethod
def validate_emails(cls, v):
    """Validate and normalise the ``users`` mapping (email -> groups).

    Every key must contain ``"@"`` and map to a non-empty collection of
    groups. The returned mapping has lower-cased, stripped emails and
    group names, each group list de-duplicated, and always includes
    ``"default"``.

    Raises:
        ValueError: if a key has no ``"@"`` or a user lists no groups.
    """
    for email, user_groups in v.items():
        if "@" not in email:
            raise ValueError(
                f"Invalid user, it should be an address: {email}"
            )
        if not user_groups:
            raise ValueError(
                f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
            )
    # all users belong to the default group
    return {
        email.lower().strip(): list(
            {"default", *(g.lower().strip() for g in user_groups)}
        )
        for email, user_groups in v.items()
    }
|
||||
|
||||
# NOTE: @field_validator is outermost and @classmethod sits directly above
# the function, the order shown in the pydantic v2 documentation.
@field_validator("domains", mode="before")
@classmethod
def validate_domains(cls, v):
    """Validate and normalise the ``domains`` mapping (domain -> groups).

    Every key must contain ``"."`` and map to a non-empty collection of
    groups. The returned mapping has lower-cased, stripped domains and
    group names, with each group list de-duplicated.

    Raises:
        ValueError: if a domain has no dot or lists no groups.
    """
    for domain, members in v.items():
        if "." not in domain:
            raise ValueError(
                f"Invalid domain, it should contain a dot: {domain}"
            )
        if not members:
            raise ValueError(
                f"Domain {domain} should have at least one member."
            )
    # De-duplicate with a set comprehension. A set display wrapped around
    # a list comprehension, ``{[...]}``, would try to hash the list itself
    # and raise ``TypeError: unhashable type: 'list'``.
    return {
        domain.lower().strip(): list({g.lower().strip() for g in members})
        for domain, members in v.items()
    }
|
||||
|
||||
@field_validator('groups', mode='before')
|
||||
@classmethod
|
||||
@field_validator("groups", mode="before")
|
||||
def validate_groups(cls, v):
|
||||
if "default" not in v.keys():
|
||||
raise ValueError("Please include a 'default' group.")
|
||||
@@ -154,20 +192,28 @@ class UserGroupModel(BaseModel):
|
||||
raise ValueError(f"Group names should be lowercase: {group}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
def check_groups_consistency(self) -> Self:
    """Warn about groups referenced by domains/users but never defined.

    Collects every group name mentioned under ``domains`` and ``users``
    and compares it with the keys of ``groups``. Undefined groups are
    logged as warnings; validation never fails here.

    Returns:
        ``self``, unchanged.
    """
    groups_in_domains = {
        g for domain in self.domains for g in self.domains[domain]
    }
    groups_in_users = {g for user in self.users for g in self.users[user]}
    configured_groups = set(self.groups.keys())

    # groups mentioned in domains and users should be defined, but this is
    # not a ValueError since historical DB data may require it
    undefined_in_domains = groups_in_domains - configured_groups
    if undefined_in_domains:
        logger.warning(
            f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {undefined_in_domains}"
        )
    undefined_in_users = groups_in_users - configured_groups
    if undefined_in_users:
        logger.warning(
            f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {undefined_in_users}"
        )

    return self
|
||||
|
||||
|
||||
# for the API return values
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
|
||||
def fnv1a_hash_mod(s: str, modulo: int) -> int:
    """Map a string to an integer bucket using 32-bit FNV-1a.

    The hash is computed over the code point of each character, kept to
    32 bits, reinterpreted as a signed 32-bit integer and reduced with
    ``% modulo``. For a positive ``modulo`` the result lies in
    ``[0, modulo - 1]``; the hash is intended to spread inputs evenly
    over that range.

    Args:
        s: string to hash; the empty string hashes to the offset basis.
        modulo: number of buckets.

    Returns:
        The bucket index for ``s``.

    Raises:
        ZeroDivisionError: if ``modulo`` is ``0``.
    """
    acc = 0x811C9DC5  # FNV offset basis
    fnv_prime = 0x01000193  # FNV prime
    for char in s:
        acc ^= ord(char)
        acc *= fnv_prime
        acc &= 0xFFFFFFFF  # Keep it 32-bit
    # Reinterpret as signed 32-bit before the modulo; this step changes
    # which bucket is chosen, so it must stay.
    signed = acc if acc < 0x80000000 else acc - 0x100000000
    return signed % modulo
|
||||
|
||||
@@ -84,9 +84,9 @@ def db_session(test_db):
|
||||
@pytest_asyncio.fixture()
|
||||
async def async_test_db(get_settings: Settings):
|
||||
get_user_group_names.cache_clear()
|
||||
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
|
||||
engine = await make_async_engine(get_settings.async_database_path)
|
||||
|
||||
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
|
||||
fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
|
||||
if not os.path.exists(fs):
|
||||
open(fs, "w").close()
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ async def notify_about_expired_archives():
|
||||
user_archives[archive.author_id].append(archive)
|
||||
|
||||
if user_archives:
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
fastmail = FastMail(get_settings().mail_config)
|
||||
# notify users
|
||||
for email in user_archives:
|
||||
list_of_archives = "\n".join(
|
||||
@@ -224,7 +224,7 @@ async def delete_stale_sheets():
|
||||
if not user_sheets:
|
||||
return
|
||||
|
||||
fastmail = FastMail(get_settings().MAIL_CONFIG)
|
||||
fastmail = FastMail(get_settings().mail_config)
|
||||
# notify users
|
||||
for email in user_sheets:
|
||||
list_of_sheets = "\n".join(
|
||||
|
||||
Reference in New Issue
Block a user