mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
refactor: add strategy.process() to BaseStrategy
This commit is contained in:
@@ -20,6 +20,7 @@ class TaskQueue:
|
|||||||
allow for full separation of worker processes and dependencies.
|
allow for full separation of worker processes and dependencies.
|
||||||
"""
|
"""
|
||||||
transcribe = self.celery.signature("app.worker.main.transcribe")
|
transcribe = self.celery.signature("app.worker.main.transcribe")
|
||||||
|
# TODO: catch delivery errors?
|
||||||
transcribe.delay(job.id)
|
transcribe.delay(job.id)
|
||||||
|
|
||||||
def rehydrate(self, session: Session):
|
def rehydrate(self, session: Session):
|
||||||
|
|||||||
@@ -9,13 +9,17 @@ import app.shared.db.models as models
|
|||||||
from app.shared.celery import get_celery_binding
|
from app.shared.celery import get_celery_binding
|
||||||
from app.shared.db.base import SessionLocal
|
from app.shared.db.base import SessionLocal
|
||||||
from app.shared.settings import settings
|
from app.shared.settings import settings
|
||||||
from app.worker.strategies.base import TaskProtocol
|
|
||||||
from app.worker.strategies.local import LocalStrategy
|
from app.worker.strategies.local import LocalStrategy
|
||||||
|
|
||||||
celery = get_celery_binding()
|
celery = get_celery_binding()
|
||||||
|
|
||||||
|
|
||||||
class TranscribeTask(Task):
|
class TranscribeTask(Task):
|
||||||
|
"""
|
||||||
|
Decorate the transcribe task with an instance of the transcription strategy.
|
||||||
|
This is important for the local strategy, where loading the model is expensive.
|
||||||
|
"""
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -30,15 +34,6 @@ class TranscribeTask(Task):
|
|||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def select_task_processor(task: Task, job: models.Job) -> TaskProtocol:
|
|
||||||
if job.type == models.JobType.transcript:
|
|
||||||
return task.strategy.transcribe
|
|
||||||
elif job.type == models.JobType.translation:
|
|
||||||
return task.strategy.translate
|
|
||||||
else:
|
|
||||||
return task.strategy.detect_language
|
|
||||||
|
|
||||||
|
|
||||||
@celery.task(
|
@celery.task(
|
||||||
base=TranscribeTask,
|
base=TranscribeTask,
|
||||||
bind=True,
|
bind=True,
|
||||||
@@ -73,29 +68,25 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
|||||||
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
logger.debug(f"[{job.id}]: finished setting task to {job.status}.")
|
||||||
|
|
||||||
# unit of work: process job with whisper.
|
# unit of work: process job with whisper.
|
||||||
|
result_type, result = self.strategy.process(job)
|
||||||
processor = select_task_processor(self, job)
|
|
||||||
|
|
||||||
result_type, result = processor(job)
|
|
||||||
|
|
||||||
logger.debug(f"[{job.id}]: successfully processed audio.")
|
logger.debug(f"[{job.id}]: successfully processed audio.")
|
||||||
|
|
||||||
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type)
|
||||||
|
|
||||||
db.add(artifact)
|
db.add(artifact)
|
||||||
|
|
||||||
job.status = models.JobStatus.success
|
job.status = models.JobStatus.success
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
logger.debug(f"[{job.id}]: successfully stored artifact.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if job and db:
|
if job and db:
|
||||||
db.rollback()
|
if db.in_transaction():
|
||||||
|
db.rollback()
|
||||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||||
job.status = models.JobStatus.error
|
job.status = models.JobStatus.error
|
||||||
db.commit()
|
db.commit()
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.strategy.cleanup(job_id=job_id)
|
self.strategy.cleanup(job_id)
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -17,6 +17,20 @@ class TaskProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
|
def process(self, job: models.Job) -> TaskReturnValue:
|
||||||
|
if job.type == models.JobType.transcript:
|
||||||
|
return self.transcribe(job)
|
||||||
|
elif job.type == models.JobType.translation:
|
||||||
|
return self.translate(job)
|
||||||
|
else:
|
||||||
|
return self.detect_language(job)
|
||||||
|
|
||||||
|
def cleanup(self, job_id: UUID) -> None:
|
||||||
|
try:
|
||||||
|
os.remove(self._get_tmp_file(job_id))
|
||||||
|
except OSError:
|
||||||
|
...
|
||||||
|
|
||||||
def transcribe(self, job: models.Job) -> TaskReturnValue:
|
def transcribe(self, job: models.Job) -> TaskReturnValue:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@@ -26,12 +40,6 @@ class BaseStrategy(ABC):
|
|||||||
def detect_language(self, job: models.Job) -> TaskReturnValue:
|
def detect_language(self, job: models.Job) -> TaskReturnValue:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def cleanup(self, job_id: UUID) -> None:
|
|
||||||
try:
|
|
||||||
os.remove(self._get_tmp_file(job_id))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||||
tmp = tempfile.gettempdir()
|
tmp = tempfile.gettempdir()
|
||||||
return os.path.join(tmp, str(job_id))
|
return os.path.join(tmp, str(job_id))
|
||||||
|
|||||||
Reference in New Issue
Block a user