feat: add language detection task

This commit is contained in:
Felix Spöttel
2023-06-29 09:13:11 +02:00
parent d2223206be
commit 908bd48170
15 changed files with 267 additions and 191 deletions

View File

@@ -10,6 +10,7 @@ import app.shared.db.schemas as schemas
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.worker.strategies.base import TaskProtocol
from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
@@ -30,10 +31,10 @@ class TranscribeTask(Task):
return self.run(*args, **kwargs)
def select_strategy(task: Task, job: schemas.Job) -> Any:
if job.type == schemas.JobType.transcript:
def select_task_processor(task: Task, job: schemas.Job) -> TaskProtocol:
if job.type == models.JobType.transcript:
return task.strategy.transcribe
elif job.type == schemas.JobType.translation:
elif job.type == models.JobType.translation:
return task.strategy.translate
else:
return task.strategy.detect_language
@@ -50,49 +51,50 @@ def transcribe(self: Task, job_id: UUID) -> None:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
# unit of work: set task status to processing.
if (
job.status == schemas.JobStatus.error
or job.status == schemas.JobStatus.success
):
logger.warn(
"[{job.id}]: Received job that has already been processed, abort."
)
job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none()
if job is None:
logger.warn("[{job.id}]: Received unknown job, abort.")
return
logger.info(f"[{job.id}]: worker received task.")
if job.status in [models.JobStatus.error, models.JobStatus.success]:
logger.warn("[{job.id}]: job has already been processed, abort.")
return
logger.info(f"[{job.id}]: received eligible job.")
job.meta = {"task_id": self.request.id}
job.status = schemas.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: set task to status processing.")
job.status = models.JobStatus.processing
db.commit()
logger.info(f"[{job.id}]: finished setting task to status processing.")
# unit of work: process job with whisper.
job_record = schemas.Job.from_orm(job)
strategy = select_strategy(self, job_record)
result = strategy(
processor = select_task_processor(self, job_record)
result_type, result = processor(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
)
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
db.add(artifact)
job.status = models.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: successfully stored artifact.")
job.status = schemas.JobStatus.success
db.commit()
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
job.status = models.JobStatus.error
db.commit()
raise
finally: