mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-12 21:48:35 +03:00
feat: add gpu support
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from asyncio.log import logger
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from celery import Task
|
||||
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
|
||||
|
||||
celery = get_celery_binding()
|
||||
|
||||
class TranscribeTask(Task):
|
||||
abstract = True
|
||||
|
||||
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# currently only `LocalStrategy` is implemented.
|
||||
# TODO: implement remote processing strategy.
|
||||
self.strategy: Optional[LocalStrategy] = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# load model into memory once when the first task is processed.
|
||||
if not self.strategy:
|
||||
self.strategy = LocalStrategy()
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
|
||||
@celery.task(
|
||||
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
|
||||
)
|
||||
def transcribe(self: Task, job_id: UUID) -> None:
|
||||
try:
|
||||
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
|
||||
db: Session = SessionLocal()
|
||||
|
||||
job = db.query(models.Job).filter(models.Job.id == job_id).one()
|
||||
|
||||
if (
|
||||
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status processing.")
|
||||
|
||||
# pick a transcription strategy.
|
||||
# currently only `local` is supported.
|
||||
job_record = schemas.Job.from_orm(job)
|
||||
strategy = LocalStrategy(
|
||||
db=db, job_id=job.id, url=job_record.url, config=job_record.config
|
||||
)
|
||||
|
||||
# process selected task.
|
||||
# currently only `transcribe` is supported.
|
||||
if job.type == schemas.JobType.transcript:
|
||||
result = strategy.transcribe()
|
||||
logger.info(f"[{job.id}]: successfully transcribed audio.")
|
||||
result = self.strategy.transcribe(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
elif job.type == schemas.JobType.translation:
|
||||
result = strategy.translate()
|
||||
logger.info(f"[{job.id}]: successfully translated audio.")
|
||||
result = self.strategy.translate(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
else:
|
||||
result = strategy.detect_language()
|
||||
result = self.strategy.detect_language(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
|
||||
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||
|
||||
artifact = models.Artifact(
|
||||
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
|
||||
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
logger.info(f"[{job.id}]: set task to status success.")
|
||||
except Exception as e:
|
||||
if job and db:
|
||||
job.meta = {**job.meta.__dict__, "error": str(e)}
|
||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||
job.status = schemas.JobStatus.error
|
||||
db.commit()
|
||||
raise (e)
|
||||
|
||||
Reference in New Issue
Block a user