Format and lint shared directory (#64)

This commit is contained in:
Michael Plunkett
2025-03-03 13:20:50 -06:00
committed by GitHub
parent a9ca410d08
commit 1ca0ae2fb2
13 changed files with 255 additions and 102 deletions

View File

@@ -11,23 +11,39 @@ from app.shared.db import models
def get_all_urls(result: Metadata) -> List[models.ArchiveUrl]:
db_urls = []
for m in result.media:
for i, url in enumerate(m.urls): db_urls.append(models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}")))
for i, url in enumerate(m.urls):
db_urls.append(
models.ArchiveUrl(url=url, key=m.get("id", f"media_{i}"))
)
for k, prop in m.properties.items():
if prop_converted := convert_if_media(prop):
for i, url in enumerate(prop_converted.urls): db_urls.append(models.ArchiveUrl(url=url, key=prop_converted.get("id", f"{k}_{i}")))
for i, url in enumerate(prop_converted.urls):
db_urls.append(
models.ArchiveUrl(
url=url, key=prop_converted.get("id", f"{k}_{i}")
)
)
if isinstance(prop, list):
for i, prop_media in enumerate(prop):
if prop_media := convert_if_media(prop_media):
for j, url in enumerate(prop_media.urls):
db_urls.append(models.ArchiveUrl(url=url, key=prop_media.get("id", f"{k}{prop_media.key}_{i}.{j}")))
db_urls.append(
models.ArchiveUrl(
url=url,
key=prop_media.get(
"id", f"{k}{prop_media.key}_{i}.{j}"
),
)
)
return db_urls
def convert_if_media(media):
if isinstance(media, Media): return media
if isinstance(media, Media):
return media
elif isinstance(media, dict):
try: return Media.from_dict(media)
try:
return Media.from_dict(media)
except Exception as e:
logger.debug(f"error parsing {media} : {e}")
return False

View File

@@ -1,26 +1,35 @@
# TODO: temporary file for this code, maybe other code belongs here, maybe not. do decide
import datetime
from typing import Union
from sqlalchemy.orm import Session
from app.shared.db import worker_crud
def get_store_archive_until(db: Session, group_id: str) -> datetime.datetime:
# TODO: temporary file for this code; decide whether it (and possibly other
# code) belongs here.
def get_store_archive_until(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
group = worker_crud.get_group(db, group_id)
assert group, f"Group {group_id} not found."
assert group.permissions and type(group.permissions) == dict, f"Group {group_id} has no permissions."
assert group.permissions and isinstance(group.permissions, dict), (
f"Group {group_id} has no permissions."
)
max_lifespan = group.permissions.get("max_archive_lifespan_months", -1)
if max_lifespan == -1: return None
if max_lifespan == -1:
return None
return datetime.datetime.now() + datetime.timedelta(days=30 * max_lifespan)
def get_store_archive_until_or_never(db: Session, group_id: str) -> datetime.datetime:
def get_store_archive_until_or_never(
db: Session, group_id: str
) -> Union[datetime.datetime, None]:
try:
return get_store_archive_until(db, group_id)
except AssertionError as e:
except AssertionError:
return None

View File

@@ -18,9 +18,9 @@ def make_engine(database_url: str):
engine = create_engine(
database_url,
connect_args={"check_same_thread": False},
pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections
pool_recycle=1800 # Recycle connections every 30 minutes
pool_size=15, # Increase pool size
max_overflow=20, # Allow more temporary connections
pool_recycle=1800, # Recycle connections every 30 minutes
)
@event.listens_for(engine, "connect")
@@ -40,8 +40,10 @@ def make_session_local(engine: Engine):
@contextmanager
def get_db():
session = make_session_local(make_engine(get_settings().DATABASE_PATH))()
try: yield session
finally: session.close()
try:
yield session
finally:
session.close()
def get_db_dependency():
@@ -59,22 +61,32 @@ def wal_checkpoint():
# ASYNC connections
async def make_async_engine(database_url: str) -> AsyncEngine:
engine = create_async_engine(database_url, connect_args={"check_same_thread": False})
engine = create_async_engine(
database_url, connect_args={"check_same_thread": False}
)
async with engine.begin() as conn:
await conn.run_sync(lambda sync_conn: sync_conn.execute(text("PRAGMA journal_mode=WAL;")))
await conn.run_sync(
lambda sync_conn: sync_conn.execute(
text("PRAGMA journal_mode=WAL;")
)
)
return engine
async def make_async_session_local(engine: AsyncEngine) -> AsyncSession:
return async_sessionmaker(engine, expire_on_commit=False, autoflush=False, autocommit=False)
return async_sessionmaker(
engine, expire_on_commit=False, autoflush=False, autocommit=False
)
@asynccontextmanager
async def get_db_async():
engine = await make_async_engine(get_settings().ASYNC_DATABASE_PATH)
engine = await make_async_engine(get_settings().async_database_path)
async_session = await make_async_session_local(engine)
async with async_session() as session:
try: yield session
finally: await engine.dispose()
try:
yield session
finally:
await engine.dispose()

View File

@@ -42,7 +42,9 @@ class Archive(Base):
id = Column(String, primary_key=True, index=True)
url = Column(String, index=True)
result = Column(JSON, default=None)
public = Column(Boolean, default=True) # if public=false, access by group and author
public = Column(
Boolean, default=True
) # if public=false, access by group and author
deleted = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
@@ -52,7 +54,11 @@ class Archive(Base):
author_id = Column(String, ForeignKey("users.email"))
sheet_id = Column(String, ForeignKey("sheets.id"), default=None)
tags = relationship("Tag", back_populates="archives", secondary=association_table_archive_tags)
tags = relationship(
"Tag",
back_populates="archives",
secondary=association_table_archive_tags,
)
group = relationship("Group", back_populates="archives")
author = relationship("User", back_populates="archives")
urls = relationship("ArchiveUrl", back_populates="archive")
@@ -75,7 +81,11 @@ class Tag(Base):
id = Column(String, primary_key=True, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
archives = relationship("Archive", back_populates="tags", secondary=association_table_archive_tags)
archives = relationship(
"Archive",
back_populates="tags",
secondary=association_table_archive_tags,
)
class User(Base):
@@ -85,7 +95,9 @@ class User(Base):
archives = relationship("Archive", back_populates="author")
sheets = relationship("Sheet", back_populates="author")
groups = relationship("Group", back_populates="users", secondary=association_table_user_groups)
groups = relationship(
"Group", back_populates="users", secondary=association_table_user_groups
)
class Group(Base):
@@ -101,7 +113,9 @@ class Group(Base):
archives = relationship("Archive", back_populates="group")
sheets = relationship("Sheet", back_populates="group")
users = relationship("User", back_populates="groups", secondary=association_table_user_groups)
users = relationship(
"User", back_populates="groups", secondary=association_table_user_groups
)
class Sheet(Base):
@@ -110,11 +124,27 @@ class Sheet(Base):
id = Column(String, primary_key=True, index=True, doc="Google Sheet ID")
name = Column(String, default=None)
author_id = Column(String, ForeignKey("users.email"))
group_id = Column(String, ForeignKey("groups.id"), doc="Group ID, user must be in a group to create a sheet.")
frequency = Column(String, default="daily", doc="Frequency of archiving: hourly, daily, weekly.")
group_id = Column(
String,
ForeignKey("groups.id"),
doc="Group ID, user must be in a group to create a sheet.",
)
frequency = Column(
String,
default="daily",
doc="Frequency of archiving: hourly, daily, weekly.",
)
# TODO: stats is not being used, consider removing
stats = Column(JSON, default={}, doc="Sheet statistics like total links, total rows, ...")
last_url_archived_at = Column(DateTime(timezone=True), server_default=func.now(), doc="Last time a new link was archived.")
stats = Column(
JSON,
default={},
doc="Sheet statistics like total links, total rows, ...",
)
last_url_archived_at = Column(
DateTime(timezone=True),
server_default=func.now(),
doc="Last time a new link was archived.",
)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())

View File

@@ -9,7 +9,9 @@ from app.shared.db import models
# TODO: isolate database operations away from worker and into WEB
# ONLY WORKER
def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
db_sheet = db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
db_sheet = (
db.query(models.Sheet).filter(models.Sheet.id == sheet_id).first()
)
if db_sheet:
db_sheet.last_url_archived_at = datetime.now()
db.commit()
@@ -19,12 +21,17 @@ def update_sheet_last_url_archived_at(db: Session, sheet_id: str):
# ONLY WORKER and INTEROP
def get_group(db: Session, group_name: str) -> models.Group:
return db.query(models.Group).filter(models.Group.id == group_name).first()
def create_or_get_user(db: Session, author_id: str) -> models.User:
if type(author_id) == str: author_id = author_id.lower()
db_user = db.query(models.User).filter(models.User.email == author_id).first()
if isinstance(author_id, str):
author_id = author_id.lower()
db_user = (
db.query(models.User).filter(models.User.email == author_id).first()
)
if not db_user:
db_user = models.User(email=author_id)
db.add(db_user)
@@ -43,8 +50,22 @@ def create_tag(db: Session, tag: str) -> models.Tag:
return db_tag
def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[models.Tag], urls: list[models.ArchiveUrl]) -> models.Archive:
db_archive = models.Archive(id=archive.id, url=archive.url, result=archive.result, public=archive.public, author_id=archive.author_id, group_id=archive.group_id, sheet_id=archive.sheet_id, store_until=archive.store_until)
def create_archive(
db: Session,
archive: schemas.ArchiveCreate,
tags: list[models.Tag],
urls: list[models.ArchiveUrl],
) -> models.Archive:
db_archive = models.Archive(
id=archive.id,
url=archive.url,
result=archive.result,
public=archive.public,
author_id=archive.author_id,
group_id=archive.group_id,
sheet_id=archive.sheet_id,
store_until=archive.store_until,
)
db_archive.tags = tags
db_archive.urls = urls
db.add(db_archive)
@@ -53,10 +74,14 @@ def create_archive(db: Session, archive: schemas.ArchiveCreate, tags: list[model
return db_archive
def store_archived_url(db: Session, archive: schemas.ArchiveCreate) -> models.Archive:
def store_archived_url(
db: Session, archive: schemas.ArchiveCreate
) -> models.Archive:
# create and load user, tags, if needed
create_or_get_user(db, archive.author_id)
db_tags = [create_tag(db, tag) for tag in (archive.tags or [])]
# insert everything
db_archive = create_archive(db, archive=archive, tags=db_tags, urls=archive.urls)
db_archive = create_archive(
db, archive=archive, tags=db_tags, urls=archive.urls
)
return db_archive

View File

@@ -8,7 +8,9 @@ logger.add("logs/api_logs.log", retention="30 days")
logger.add("logs/error_logs.log", retention="30 days", level="ERROR")
def log_error(e: Exception, traceback_str: str | None = None, extra: str = ""):
    """Write an exception, with its traceback, to the error log.

    Args:
        e: The exception being reported; its class name and message are
            included in the log line.
        traceback_str: Pre-formatted traceback text. When empty or None, the
            traceback of the exception currently being handled is used.
        extra: Optional context emitted on its own line before the exception.
    """
    if not traceback_str:
        # Falls back to whatever exception is active in the calling frame.
        traceback_str = traceback.format_exc()
    if extra:
        extra = f"{extra}\n"
    logger.error(f"{extra}{e.__class__.__name__}: {e}\n{traceback_str}")

View File

@@ -11,6 +11,7 @@ class SubmitSheet(BaseModel):
group_id: str = "default"
tags: set[str] | None = set()
class ArchiveUrl(BaseModel):
url: str
public: bool = False
@@ -18,6 +19,7 @@ class ArchiveUrl(BaseModel):
group_id: str | None
tags: set[str] | None = set()
class ArchiveResult(BaseModel):
id: str
url: str

View File

@@ -1,4 +1,3 @@
import os
from functools import lru_cache
from typing import Annotated, Set
@@ -9,32 +8,40 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=os.environ.get("ENVIRONMENT_FILE"),
env_file_encoding="utf-8",
extra="ignore",
str_strip_whitespace=True,
)
model_config = SettingsConfigDict(env_file=os.environ.get("ENVIRONMENT_FILE") , env_file_encoding='utf-8', extra='ignore', str_strip_whitespace=True)
# general
# general
SERVE_LOCAL_ARCHIVE: str | None = None
USER_GROUPS_FILENAME: str = "app/user-groups.yaml"
# database
# database
DATABASE_PATH: str
DATABASE_QUERY_LIMIT: int = 100
@property
def ASYNC_DATABASE_PATH(self) -> str:
def async_database_path(self) -> str:
return self.DATABASE_PATH.replace("sqlite://", "sqlite+aiosqlite://")
# security
# security
API_BEARER_TOKEN: Annotated[str, Len(min_length=20)]
ALLOWED_ORIGINS: Annotated[Set[str], Len(min_length=1)]
CHROME_APP_IDS: Annotated[Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)]
CHROME_APP_IDS: Annotated[
Set[Annotated[str, Len(min_length=10)]], Len(min_length=1)
]
BLOCKED_EMAILS: Annotated[Set[str], Len(min_length=0)] = set()
# redis
REDIS_PASSWORD: str = ""
REDIS_HOSTNAME: str = "localhost"
REDIS_EXCEPTIONS_CHANNEL: str = "exceptions-channel"
@property
def CELERY_BROKER_URL(self)-> str:
def celery_broker_url(self) -> str:
if self.REDIS_PASSWORD:
return f"redis://:{self.REDIS_PASSWORD}@{self.REDIS_HOSTNAME}:6379"
return f"redis://{self.REDIS_HOSTNAME}:6379"
@@ -46,7 +53,7 @@ class Settings(BaseSettings):
CRON_DELETE_SCHEDULED_ARCHIVES: bool = False
DELETE_SCHEDULED_ARCHIVES_CHECK_EVERY_N_DAYS: int = 7
# observability
# observability
REPEAT_COUNT_METRICS_SECONDS: int = 30
# email configuration, if needed
@@ -58,8 +65,9 @@ class Settings(BaseSettings):
MAIL_PORT: int = 587
MAIL_STARTTLS: bool = False
MAIL_SSL_TLS: bool = True
@property
def MAIL_CONFIG(self) -> str:
def mail_config(self) -> ConnectionConfig:
return ConnectionConfig(
MAIL_FROM=self.MAIL_FROM,
MAIL_FROM_NAME=self.MAIL_FROM_NAME,

View File

@@ -1,4 +1,3 @@
from functools import lru_cache
from celery import Celery
@@ -11,14 +10,14 @@ from app.shared.settings import get_settings
def get_celery(name: str = "") -> Celery:
return Celery(
name,
broker_url=get_settings().CELERY_BROKER_URL,
result_backend=get_settings().CELERY_BROKER_URL,
broker_url=get_settings().celery_broker_url,
result_backend=get_settings().celery_broker_url,
broker_connection_retry_on_startup=False,
broker_transport_options={
'queue_order_strategy': 'priority',
}
"queue_order_strategy": "priority",
},
)
def get_redis() -> redis.Redis:
return redis.Redis.from_url(get_settings().CELERY_BROKER_URL)
return redis.Redis.from_url(get_settings().celery_broker_url)

View File

@@ -19,13 +19,16 @@ class UserGroups:
user_groups = self.read_yaml(filename)
self.validate_and_load(user_groups)
def read_yaml(self, user_groups_filename):
@staticmethod
def read_yaml(user_groups_filename):
# read yaml safely
with open(user_groups_filename) as inf:
try:
return yaml.safe_load(inf)
except yaml.YAMLError as e:
logger.error(f"could not open user groups filename {user_groups_filename}: {e}")
logger.error(
f"could not open user groups filename {user_groups_filename}: {e}"
)
raise e
def validate_and_load(self, user_groups):
@@ -52,22 +55,36 @@ class GroupPermissions(BaseModel):
max_monthly_mbs: int = 0
priority: str = "low"
@field_validator('max_sheets', 'max_archive_lifespan_months', 'max_monthly_urls', 'max_monthly_mbs', mode='before')
@classmethod
@field_validator(
"max_sheets",
"max_archive_lifespan_months",
"max_monthly_urls",
"max_monthly_mbs",
mode="before",
)
def validate_max_values(cls, v):
if v < -1:
raise ValueError("max_* values should be positive integers or -1 (for no limit).")
raise ValueError(
"max_* values should be positive integers or -1 (for no limit)."
)
return v
@field_validator('sheet_frequency', mode='before')
@classmethod
@field_validator("sheet_frequency", mode="before")
def validate_sheet_frequency(cls, v):
if not v: return []
if not v:
return []
allowed = ["daily", "hourly"]
for k in v:
if k not in allowed:
raise ValueError(f"Invalid sheet frequency: '{k}', expected one of {allowed}")
raise ValueError(
f"Invalid sheet frequency: '{k}', expected one of {allowed}"
)
return v
@field_validator('priority', mode='before')
@classmethod
@field_validator("priority", mode="before")
def validate_priority(cls, v):
v = v.lower()
if v not in ["low", "high"]:
@@ -81,7 +98,8 @@ class GroupModel(BaseModel):
orchestrator_sheet: str
permissions: GroupPermissions
@field_validator('orchestrator', 'orchestrator_sheet', mode='before')
@classmethod
@field_validator("orchestrator", "orchestrator_sheet", mode="before")
def validate_orchestrator(cls, v):
if not os.path.exists(v):
raise ValueError(f"Orchestrator file not found with this path: {v}")
@@ -105,13 +123,17 @@ class GroupModel(BaseModel):
service_account_json = find_service_account_email(orch)
if not service_account_json:
raise ValueError(f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}.")
raise ValueError(
f"service_account key not found in orchestrator sheet file: {self.orchestrator_sheet}."
)
with open(service_account_json) as f:
self._service_account_email = json.load(f).get("client_email")
if not self._service_account_email:
raise ValueError(f"Service account email not found in {service_account_json}.")
raise ValueError(
f"Service account email not found in {service_account_json}."
)
return self._service_account_email
@@ -121,29 +143,45 @@ class UserGroupModel(BaseModel):
domains: Dict[str, List[str]] = Field(default_factory=dict)
groups: Dict[str, GroupModel] = Field(default_factory=dict)
# Decorator order follows the pydantic documentation: @field_validator
# outermost, @classmethod directly above the function.
@field_validator("users", mode="before")
@classmethod
def validate_emails(cls, v):
    """Validate and normalise the raw ``users`` mapping.

    Every key must contain an ``@`` and map to a non-empty collection of
    group names. Keys and group names are lower-cased and stripped, and the
    ``default`` group is added to every user.

    Raises:
        ValueError: if a key has no ``@`` or a user lists no groups.
    """
    for email, groups in v.items():
        if "@" not in email:
            raise ValueError(
                f"Invalid user, it should be an address: {email}"
            )
        if not groups:
            raise ValueError(
                f"User {email} has no explicitly listed groups, only include them here if they should be in a group."
            )
    # all users belong to the default group
    return {
        email.lower().strip(): list(
            {"default", *(g.lower().strip() for g in groups)}
        )
        for email, groups in v.items()
    }
# Decorator order follows the pydantic documentation: @field_validator
# outermost, @classmethod directly above the function.
@field_validator("domains", mode="before")
@classmethod
def validate_domains(cls, v):
    """Validate and normalise the raw ``domains`` mapping.

    Every key must contain a dot and map to a non-empty collection. Keys and
    entries are lower-cased and stripped, and duplicate entries are removed.

    Raises:
        ValueError: if a key has no dot or a domain has no entries.
    """
    for domain, members in v.items():
        if "." not in domain:
            raise ValueError(
                f"Invalid domain, it should contain a dot: {domain}"
            )
        if not members:
            raise ValueError(
                f"Domain {domain} should have at least one member."
            )
    # A set *comprehension* de-duplicates the entries. Wrapping a list in a
    # set literal ({[...]}) raises TypeError because lists are unhashable.
    return {
        domain.lower().strip(): list({g.lower().strip() for g in members})
        for domain, members in v.items()
    }
@field_validator('groups', mode='before')
@classmethod
@field_validator("groups", mode="before")
def validate_groups(cls, v):
if "default" not in v.keys():
raise ValueError("Please include a 'default' group.")
@@ -154,20 +192,28 @@ class UserGroupModel(BaseModel):
raise ValueError(f"Group names should be lowercase: {group}")
return v
@model_validator(mode='after')
@model_validator(mode="after")
def check_groups_consistency(self) -> Self:
groups_in_domains = set([g for domain in self.domains for g in self.domains[domain]])
groups_in_users = set([g for user in self.users for g in self.users[user]])
groups_in_domains = {
g for domain in self.domains for g in self.domains[domain]
}
groups_in_users = {g for user in self.users for g in self.users[user]}
configured_groups = set(self.groups.keys())
# groups mentioned in domains and users should be defined, but this is not a ValueError since historical DB data may require it
# groups mentioned in domains and users should be defined, but this is
# not a ValueError since historical DB data may require it
if groups_in_domains - configured_groups:
logger.warning(f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}")
logger.warning(
f"These groups are associated to DOMAINS but not defined in the GROUPS section, the domains settings may not work as expected: {groups_in_domains - configured_groups}"
)
if groups_in_users - configured_groups:
logger.warning(f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}")
logger.warning(
f"These groups are associated to USERS but not defined in the GROUPS section, the users settings may not work as expected: {groups_in_users - configured_groups}"
)
return self
# for the API return values

View File

@@ -1,10 +1,14 @@
def fnv1a_hash_mod(s: str, modulo: int) -> int:
    """Map a string to a bucket index using a 32-bit FNV-1a hash.

    Args:
        s: Text to hash. Each character contributes its Unicode code point
            (``ord``), so the result equals byte-wise FNV-1a only for
            strings whose code points are all below 256.
        modulo: Number of buckets.

    Returns:
        An integer in ``[0, modulo - 1]`` for a positive ``modulo``.

    Raises:
        ZeroDivisionError: if ``modulo`` is 0.
    """
    fnv_offset_basis = 0x811C9DC5
    fnv_prime = 0x01000193

    digest = fnv_offset_basis
    for char in s:
        digest ^= ord(char)
        # Multiply, then mask to keep the running value 32-bit.
        digest = (digest * fnv_prime) & 0xFFFFFFFF

    # Reinterpret as a signed 32-bit integer before reducing, so the bucket
    # assignment is unchanged from the previous implementation.
    signed = digest if digest < 0x80000000 else digest - 0x100000000
    return signed % modulo

View File

@@ -84,9 +84,9 @@ def db_session(test_db):
@pytest_asyncio.fixture()
async def async_test_db(get_settings: Settings):
get_user_group_names.cache_clear()
engine = await make_async_engine(get_settings.ASYNC_DATABASE_PATH)
engine = await make_async_engine(get_settings.async_database_path)
fs = get_settings.ASYNC_DATABASE_PATH.replace("sqlite+aiosqlite:///", "")
fs = get_settings.async_database_path.replace("sqlite+aiosqlite:///", "")
if not os.path.exists(fs):
open(fs, "w").close()

View File

@@ -160,7 +160,7 @@ async def notify_about_expired_archives():
user_archives[archive.author_id].append(archive)
if user_archives:
fastmail = FastMail(get_settings().MAIL_CONFIG)
fastmail = FastMail(get_settings().mail_config)
# notify users
for email in user_archives:
list_of_archives = "\n".join(
@@ -224,7 +224,7 @@ async def delete_stale_sheets():
if not user_sheets:
return
fastmail = FastMail(get_settings().MAIL_CONFIG)
fastmail = FastMail(get_settings().mail_config)
# notify users
for email in user_sheets:
list_of_sheets = "\n".join(