refactor: add strategy.process() to BaseStrategy

This commit is contained in:
Felix Spöttel
2023-06-29 16:31:57 +02:00
parent 05ebc17215
commit 85c3f1fc44
3 changed files with 25 additions and 25 deletions

View File

@@ -20,6 +20,7 @@ class TaskQueue:
allow for full separation of worker processes and dependencies.
"""
transcribe = self.celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors?
transcribe.delay(job.id)
def rehydrate(self, session: Session):

View File

@@ -9,13 +9,17 @@ import app.shared.db.models as models
from app.shared.celery import get_celery_binding
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.worker.strategies.base import TaskProtocol
from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
class TranscribeTask(Task):
"""
Decorate the transcribe task with an instance of the transcription strategy.
This is important for the local strategy, where loading the model is expensive.
"""
abstract = True
def __init__(self) -> None:
@@ -30,15 +34,6 @@ class TranscribeTask(Task):
return self.run(*args, **kwargs)
def select_task_processor(task: Task, job: models.Job) -> TaskProtocol:
if job.type == models.JobType.transcript:
return task.strategy.transcribe
elif job.type == models.JobType.translation:
return task.strategy.translate
else:
return task.strategy.detect_language
@celery.task(
base=TranscribeTask,
bind=True,
@@ -73,29 +68,25 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
# unit of work: process job with whisper.
processor = select_task_processor(self, job)
result_type, result = processor(job)
result_type, result = self.strategy.process(job)
logger.debug(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
db.add(artifact)
job.status = models.JobStatus.success
db.commit()
logger.debug(f"[{job.id}]: successfully stored artifact.")
except Exception as e:
if job and db:
db.rollback()
if db.in_transaction():
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = models.JobStatus.error
db.commit()
raise
finally:
self.strategy.cleanup(job_id=job_id)
self.strategy.cleanup(job_id)
db.close()

View File

@@ -17,6 +17,20 @@ class TaskProtocol(Protocol):
class BaseStrategy(ABC):
def process(self, job: models.Job) -> TaskReturnValue:
if job.type == models.JobType.transcript:
return self.transcribe(job)
elif job.type == models.JobType.translation:
return self.translate(job)
else:
return self.detect_language(job)
def cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file(job_id))
except OSError:
...
def transcribe(self, job: models.Job) -> TaskReturnValue:
raise NotImplementedError()
@@ -26,12 +40,6 @@ class BaseStrategy(ABC):
def detect_language(self, job: models.Job) -> TaskReturnValue:
raise NotImplementedError()
def cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file(job_id))
except OSError:
pass
def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir()
return os.path.join(tmp, str(job_id))