Format and lint shared directory (#64)

This commit is contained in:
Michael Plunkett
2025-03-03 13:20:50 -06:00
committed by GitHub
parent a9ca410d08
commit 1ca0ae2fb2
13 changed files with 255 additions and 102 deletions

View File

@@ -11,23 +11,39 @@ from app.shared.db import models
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]: def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = [] db_urls = []
for m in result.media: for m in result.media:
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))) for i, url in enumerate(m.urls):
db_urls.append(
models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
)
for k, prop in m.properties.items(): for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop): if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}"))) for i, url in enumerate(prop_converted.urls):
db_urls.append(
models.ArchiveUrl(
url=url, key=prop_converted.get("id", f"{k}_{i}")
)
)
if isinstance(prop, list): if isinstance(prop, list):
for i, prop_media in enumerate(prop): for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media): if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls): for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}"))) db_urls.append(
models.ArchiveUrl(
url=url,
key=prop_media.get(
"id", f"{k}{prop_media.key}_{i}.{j}"
),
)
)
return db_urls return db_urls
def convert_if_media(media): def convert_if_media(media):
if isinstance(media, Media): return media if isinstance(media, Media):
return media
elif isinstance(media, dict): elif isinstance(media, dict):
try: return Media.from_dict(media) try:
return Media.from_dict(media)
except Exception as e: except Exception as e:
logger.debug(f"error parsing {media} : {e}") logger.debug(f"error parsing {media} : {e}")
return False return False

View File

@@ -1,26 +1,35 @@
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
import datetime import datetime
from typing import Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.shared.db import worker_crud from app.shared.db import worker_crud
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime: # TODO: temporary file for this code, maybe other code belongs here, maybe not. do
# decide
def get_store_archive_until(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
group = worker_crud.get_group(db, group_id) group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found." assert group, f"Group {group_id} not found."
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions." assert group.permissions and isinstance(group.permissions, dict), (
f"Group {group_id} has no permissions."
)
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1) max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
if max_lifespan == -1: return None if max_lifespan == -1:
return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan) return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime: def get_store_archive_until_or_never(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
try: try:
return get_store_archive_until(db, group_id) return get_store_archive_until(db, group_id)
except AssertionError as e: except AssertionError:
return None return None

View File

@@ -18,9 +18,9 @@ def make_engine(database_url: str):
engine = create_engine( engine = create_engine(
database_url, database_url,
connect_args={"check_same_thread": False}, connect_args={"check_same_thread": False},
pool_size=15, # Increase pool size pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections max_overflow=20, # Allow more temporary connections
pool_recycle=1800 # Recycle connections every 30 minutes pool_recycle=1800, # Recycle connections every 30 minutes
) )
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
@@ -40,8 +40,10 @@ def make_session_local(engine: Engine):
@contextmanager @contextmanager
def get_db(): def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))() session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session try:
finally: session.close() yield session
finally:
session.close()
def get_db_dependency(): def get_db_dependency():
@@ -59,22 +61,32 @@ def wal_checkpoint():
# ASYNC connections # ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine: async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False}) engine = create_async_engine(
database_url, connect_args={"check_same_thread": False}
)
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;"))) await conn.run_sync(
lambda sync_conn: sync_conn.execute(
text("PRAGMA journal_mode=WAL;")
)
)
return engine return engine
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession: async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False) return async_sessionmaker(
engine, expire_on_commit=False, autoflush=False, autocommit=False
)
@asynccontextmanager @asynccontextmanager
async def get_db_async(): async def get_db_async():
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH) engine = await make_async_engine(get_settings().async_database_path)
async_session = await make_async_session_local(engine) async_session = await make_async_session_local(engine)
async with async_session() as session: async with async_session() as session:
try: yield session try:
finally: await engine.dispose() yield session
finally:
await engine.dispose()

View File

@@ -42,7 +42,9 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
url = Column(String, index=True) url = Column(String, index=True)
result = Column(JSON, default=None) result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access by group and author public = Column(
Boolean, default=True
) # if public=false, access by group and author
deleted = Column(Boolean, default=False) deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -52,7 +54,11 @@ class Archive(Base):
author_id = Column(String, ForeignKey("users.email")) author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None) sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags) tags = relationship(
"Tag",
back_populates="archives",
secondary=association_table_archive_tags,
)
group = relationship("Group", back_populates="archives") group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives") author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive") urls = relationship("ArchiveUrl", back_populates="archive")
@@ -75,7 +81,11 @@ class Tag(Base):
id = Column(String, primary_key=True, index=True) id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags) archives = relationship(
"Archive",
back_populates="tags",
secondary=association_table_archive_tags,
)
class User(Base): class User(Base):
@@ -85,7 +95,9 @@ class User(Base):
archives = relationship("Archive", back_populates="author") archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author") sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups) groups = relationship(
"Group", back_populates="users", secondary=association_table_user_groups
)
class Group(Base): class Group(Base):
@@ -101,7 +113,9 @@ class Group(Base):
archives = relationship("Archive", back_populates="group") archives = relationship("Archive", back_populates="group")
sheets = relationship("Sheet", back_populates="group") sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups) users = relationship(
"User", back_populates="groups", secondary=association_table_user_groups
)
class Sheet(Base): class Sheet(Base):
@@ -110,11 +124,27 @@ class Sheet(Base):
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID") id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None) name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email")) author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.") group_id = Column(
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.") String,
ForeignKey("groups.id"),
doc="Group ID, user must be in a group to create a sheet.",
)
frequency = Column(
String,
default="daily",
doc="Frequency of archiving: hourly, daily, weekly.",
)
# TODO: stats is not being used, consider removing # TODO: stats is not being used, consider removing
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...") stats = Column(
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.") JSON,
default={},
doc="Sheet statistics like total links, total rows, ...",
)
last_url_archived_at = Column(
DateTime(timezone=True),
server_default=func.now(),
doc="Last time a new link was archived.",
)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -9,7 +9,9 @@ from app.shared.db import models
# TODO: isolate database operations away from worker and into WEB # TODO: isolate database operations away from worker and into WEB
# ONLY WORKER # ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str): def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first() db_sheet = (
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
)
if db_sheet: if db_sheet:
db_sheet.last_url_archived_at = datetime.now() db_sheet.last_url_archived_at = datetime.now()
db.commit() db.commit()
@@ -19,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
# ONLY WORKER and INTEROP # ONLY WORKER and INTEROP
def get_group(db: Session, group_name: str) -> models.Group: def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first() return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str) -> models.User: def create_or_get_user(db: Session, author_id: str) -> models.User:
if type(author_id) == str: author_id = author_id.lower() if isinstance(author_id, str):
db_user = db.query(models.User).filter(models.User.email == author_id).first() author_id = author_id.lower()
db_user = (
db.query(models.User).filter(models.User.email == author_id).first()
)
if not db_user: if not db_user:
db_user = models.User(email=author_id) db_user = models.User(email=author_id)
db.add(db_user) db.add(db_user)
@@ -43,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
return db_tag return db_tag
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive: def create_archive(
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until) db: Session,
archive: schemas.ArchiveCreate,
tags: list[models.Tag],
urls: list[models.ArchiveUrl],
) -> models.Archive:
db_archive = models.Archive(
id=archive.id,
url=archive.url,
result=archive.result,
public=archive.public,
author_id=archive.author_id,
group_id=archive.group_id,
sheet_id=archive.sheet_id,
store_until=archive.store_until,
)
db_archive.tags = tags db_archive.tags = tags
db_archive.urls = urls db_archive.urls = urls
db.add(db_archive) db.add(db_archive)
@@ -53,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
return db_archive return db_archive
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive: def store_archived_url(
db: Session, archive: schemas.ArchiveCreate
) -> models.Archive:
# create and load user, tags, if needed # create and load user, tags, if needed
create_or_get_user(db, archive.author_id) create_or_get_user(db, archive.author_id)
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])] db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
# insert everything # insert everything
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls) db_archive = create_archive(
db, archive=archive, tags=db_tags, urls=archive.urls
)
return db_archive return db_archive

View File

@@ -8,7 +8,9 @@ logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR") logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str = None, extra:str = ""): def log_error(e: Exception, traceback_str: str = None, extra: str = ""):
if not traceback_str: traceback_str = traceback.format_exc() if not traceback_str:
if extra: extra = f"{extra}\n" traceback_str = traceback.format_exc()
if extra:
extra = f"{extra}\n"
logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}") logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")

View File

@@ -11,6 +11,7 @@ class SubmitSheet(BaseModel):
group_id: str = "default" group_id: str = "default"
tags: set[str] | None = set() tags: set[str] | None = set()
class ArchiveUrl(BaseModel): class ArchiveUrl(BaseModel):
url: str url: str
public: bool = False public: bool = False
@@ -18,6 +19,7 @@ class ArchiveUrl(BaseModel):
group_id: str | None group_id: str | None
tags: set[str] | None = set() tags: set[str] | None = set()
class ArchiveResult(BaseModel): class ArchiveResult(BaseModel):
id: str id: str
url: str url: str

View File

@@ -1,4 +1,3 @@
import os import os
from functools import lru_cache from functools import lru_cache
from typing import Annotated, Set from typing import Annotated, Set
@@ -9,32 +8,40 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=os.environ.get("ENVIRONMENT_FILE"),
env_file_encoding="utf-8",
extra="ignore",
str_strip_whitespace=True,
)
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True) # general
# general
SERVE_LOCAL_ARCHIVE: str | None = None SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml" USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
# database # database
DATABASE_PATH: str DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100 DATABASE_QUERY_LIMIT: int = 100
@property @property
def ASYNC_DATABASE_PATH(self) -> str: def async_database_path(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://") return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
# security # security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)] API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)] ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)] CHROME_APP_IDS: Annotated[
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set() BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# redis # redis
REDIS_PASSWORD: str = "" REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost" REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel" REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
@property @property
def CELERY_BROKER_URL(self)-> str: def celery_broker_url(self) -> str:
if self.REDIS_PASSWORD: if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379" return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379" return f"redis://{self.REDIS_HOSTNAME}:6379"
@@ -46,7 +53,7 @@ class Settings(BaseSettings):
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7 DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
# observability # observability
REPEAT_COUNT_METRICS_SECONDS: int = 30 REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed # email configuration, if needed
@@ -58,8 +65,9 @@ class Settings(BaseSettings):
MAIL_PORT: int = 587 MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True MAIL_SSL_TLS: bool = True
@property @property
def MAIL_CONFIG(self) -> str: def mail_config(self) -> ConnectionConfig:
return ConnectionConfig( return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM, MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME, MAIL_FROM_NAME=self.MAIL_FROM_NAME,

View File

@@ -1,4 +1,3 @@
from functools import lru_cache from functools import lru_cache
from celery import Celery from celery import Celery
@@ -11,14 +10,14 @@ from app.shared.settings import get_settings
def get_celery(name: str = "") -> Celery: def get_celery(name: str = "") -> Celery:
return Celery( return Celery(
name, name,
broker_url=get_settings().CELERY_BROKER_URL, broker_url=get_settings().celery_broker_url,
result_backend=get_settings().CELERY_BROKER_URL, result_backend=get_settings().celery_broker_url,
broker_connection_retry_on_startup=False, broker_connection_retry_on_startup=False,
broker_transport_options={ broker_transport_options={
'queue_order_strategy': 'priority', "queue_order_strategy": "priority",
} },
) )
def get_redis() -> redis.Redis: def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL) return redis.Redis.from_url(get_settings().celery_broker_url)

View File

@@ -19,13 +19,16 @@ class UserGroups:
user_groups = self.read_yaml(filename) user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups) self.validate_and_load(user_groups)
def read_yaml(self, user_groups_filename): @staticmethod
def read_yaml(user_groups_filename):
# read yaml safely # read yaml safely
with open(user_groups_filename) as inf: with open(user_groups_filename) as inf:
try: try:
return yaml.safe_load(inf) return yaml.safe_load(inf)
except yaml.YAMLError as e: except yaml.YAMLError as e:
logger.error(f"could not open user groups filename {user_groups_filename}: {e}") logger.error(
f"could not open user groups filename {user_groups_filename}: {e}"
)
raise e raise e
def validate_and_load(self, user_groups): def validate_and_load(self, user_groups):
@@ -52,22 +55,36 @@ class GroupPermissions(BaseModel):
max_monthly_mbs: int = 0 max_monthly_mbs: int = 0
priority: str = "low" priority: str = "low"
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before') @classmethod
@field_validator(
"max_sheets",
"max_archive_lifespan_months",
"max_monthly_urls",
"max_monthly_mbs",
mode="before",
)
def validate_max_values(cls, v): def validate_max_values(cls, v):
if v < -1: if v < -1:
raise ValueError("max_* values should be positive integers or -1 (for no limit).") raise ValueError(
"max_* values should be positive integers or -1 (for no limit)."
)
return v return v
@field_validator('sheet_frequency', mode='before') @classmethod
@field_validator("sheet_frequency", mode="before")
def validate_sheet_frequency(cls, v): def validate_sheet_frequency(cls, v):
if not v: return [] if not v:
return []
allowed = ["daily", "hourly"] allowed = ["daily", "hourly"]
for k in v: for k in v:
if k not in allowed: if k not in allowed:
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}") raise ValueError(
f"Invalid sheet frequency: '{k}', expected one of {allowed}"
)
return v return v
@field_validator('priority', mode='before') @classmethod
@field_validator("priority", mode="before")
def validate_priority(cls, v): def validate_priority(cls, v):
v = v.lower() v = v.lower()
if v not in ["low", "high"]: if v not in ["low", "high"]:
@@ -81,7 +98,8 @@ class GroupModel(BaseModel):
orchestrator_sheet: str orchestrator_sheet: str
permissions: GroupPermissions permissions: GroupPermissions
@field_validator('orchestrator', 'orchestrator_sheet', mode='before') @classmethod
@field_validator("orchestrator", "orchestrator_sheet", mode="before")
def validate_orchestrator(cls, v): def validate_orchestrator(cls, v):
if not os.path.exists(v): if not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}") raise ValueError(f"Orchestrator file not found with this path: {v}")
@@ -105,13 +123,17 @@ class GroupModel(BaseModel):
service_account_json = find_service_account_email(orch) service_account_json = find_service_account_email(orch)
if not service_account_json: if not service_account_json:
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.") raise ValueError(
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
)
with open(service_account_json) as f: with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email") self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email: if not self._service_account_email:
raise ValueError(f"Service account email not found in {service_account_json}.") raise ValueError(
f"Service account email not found in {service_account_json}."
)
return self._service_account_email return self._service_account_email
@@ -121,29 +143,45 @@ class UserGroupModel(BaseModel):
domains: Dict[str, List[str]] = Field(default_factory=dict) domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict) groups: Dict[str, GroupModel] = Field(default_factory=dict)
@field_validator('users', mode='before')
@classmethod @classmethod
@field_validator("users", mode="before")
def validate_emails(cls, v): def validate_emails(cls, v):
for email in v.keys(): for email in v.keys():
if '@' not in email: if "@" not in email:
raise ValueError(f"Invalid user, it should be an address: {email}") raise ValueError(
f"Invalid user, it should be an address: {email}"
)
if not v[email]: if not v[email]:
raise ValueError(f"User {email} has no explicitly listed groups, only include them here if they should be in a group.") raise ValueError(
f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
)
# all users belong to the default group # all users belong to the default group
return {k.lower().strip(): list(set(["default"] + [g.lower().strip() for g in v])) for k, v in v.items()} return {
k.lower().strip(): list(
set(["default"] + [g.lower().strip() for g in v])
)
for k, v in v.items()
}
@field_validator('domains', mode='before')
@classmethod @classmethod
@field_validator("domains", mode="before")
def validate_domains(cls, v): def validate_domains(cls, v):
for domain, members in v.items(): for domain, members in v.items():
if '.' not in domain: if "." not in domain:
raise ValueError(f"Invalid domain, it should contain a dot: {domain}") raise ValueError(
f"Invalid domain, it should contain a dot: {domain}"
)
if not members: if not members:
raise ValueError(f"Domain {domain} should have at least one member.") raise ValueError(
return {k.lower().strip(): list(set([g.lower().strip() for g in v])) for k, v in v.items()} f"Domain {domain} should have at least one member."
)
return {
k.lower().strip(): list({[g.lower().strip() for g in v]})
for k, v in v.items()
}
@field_validator('groups', mode='before')
@classmethod @classmethod
@field_validator("groups", mode="before")
def validate_groups(cls, v): def validate_groups(cls, v):
if "default" not in v.keys(): if "default" not in v.keys():
raise ValueError("Please include a 'default' group.") raise ValueError("Please include a 'default' group.")
@@ -154,20 +192,28 @@ class UserGroupModel(BaseModel):
raise ValueError(f"Group names should be lowercase: {group}") raise ValueError(f"Group names should be lowercase: {group}")
return v return v
@model_validator(mode='after') @model_validator(mode="after")
def check_groups_consistency(self) -> Self: def check_groups_consistency(self) -> Self:
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]]) groups_in_domains = {
groups_in_users = set([g for user in self.users for g in self.users[user]]) g for domain in self.domains for g in self.domains[domain]
}
groups_in_users = {g for user in self.users for g in self.users[user]}
configured_groups = set(self.groups.keys()) configured_groups = set(self.groups.keys())
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it # groups mentioned in domains and users should be defined, but this is
# not a ValueError since historical DB data may require it
if groups_in_domains - configured_groups: if groups_in_domains - configured_groups:
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}") logger.warning(
f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
)
if groups_in_users - configured_groups: if groups_in_users - configured_groups:
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}") logger.warning(
f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
)
return self return self
# for the API return values # for the API return values

View File

@@ -1,10 +1,14 @@
def fnv1a_hash_mod(s: str, modulo: int) -> int:
def fnv1a_hash_mod(s: str, modulo:int) -> int: # receives a string and returns a number in [0:modulo-1], ensures an even
# receives a string and returns a number in [0:modulo-1], ensures an even distribution over the modulo range # distribution over the modulo range
hash = 0x811c9dc5 # FNV offset basis offset_basis_hash = 0x811C9DC5 # FNV offset basis
fnv_prime = 0x01000193 # FNV prime fnv_prime = 0x01000193 # FNV prime
for char in s: for char in s:
hash ^= ord(char) offset_basis_hash ^= ord(char)
hash *= fnv_prime offset_basis_hash *= fnv_prime
hash &= 0xFFFFFFFF # Keep it 32-bit offset_basis_hash &= 0xFFFFFFFF # Keep it 32-bit
return (hash if hash < 0x80000000 else hash - 0x100000000) % modulo return (
offset_basis_hash
if offset_basis_hash < 0x80000000
else offset_basis_hash - 0x100000000
) % modulo

View File

@@ -84,9 +84,9 @@ def db_session(test_db):
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings): async def async_test_db(get_settings: Settings):
get_user_group_names.cache_clear() get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH) engine = await make_async_engine(get_settings.async_database_path)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "") fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs): if not os.path.exists(fs):
open(fs, "w").close() open(fs, "w").close()

View File

@@ -160,7 +160,7 @@ async def notify_about_expired_archives():
user_archives[archive.author_id].append(archive) user_archives[archive.author_id].append(archive)
if user_archives: if user_archives:
fastmail = FastMail(get_settings().MAIL_CONFIG) fastmail = FastMail(get_settings().mail_config)
# notify users # notify users
for email in user_archives: for email in user_archives:
list_of_archives = "\n".join( list_of_archives = "\n".join(
@@ -224,7 +224,7 @@ async def delete_stale_sheets():
if not user_sheets: if not user_sheets:
return return
fastmail = FastMail(get_settings().MAIL_CONFIG) fastmail = FastMail(get_settings().mail_config)
# notify users # notify users
for email in user_sheets: for email in user_sheets:
list_of_sheets = "\n".join( list_of_sheets = "\n".join(