feat: add gpu support

This commit is contained in:
Felix Spöttel
2023-03-01 16:04:05 +01:00
parent f27fe02958
commit 7ece7944bf
6 changed files with 125 additions and 55 deletions

View File

@@ -1,4 +1,5 @@
from asyncio.log import logger
from typing import Any, Optional
from uuid import UUID
from celery import Task
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
class TranscribeTask(Task):
abstract = True
@celery.task(bind=True, soft_time_limit=2 * 60 * 60) # TODO: make configurable
def __init__(self) -> None:
super().__init__()
# currently only `LocalStrategy` is implemented.
# TODO: implement remote processing strategy.
self.strategy: Optional[LocalStrategy] = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# load model into memory once when the first task is processed.
if not self.strategy:
self.strategy = LocalStrategy()
return self.run(*args, **kwargs)
@celery.task(
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if (
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
job_record = schemas.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task.
# currently only `transcribe` is supported.
if job.type == schemas.JobType.transcript:
result = strategy.transcribe()
logger.info(f"[{job.id}]: successfully transcribed audio.")
result = self.strategy.transcribe(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
elif job.type == schemas.JobType.translation:
result = strategy.translate()
logger.info(f"[{job.id}]: successfully translated audio.")
result = self.strategy.translate(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
else:
result = strategy.detect_language()
result = self.strategy.detect_language(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
if job and db:
job.meta = {**job.meta.__dict__, "error": str(e)}
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
db.commit()
raise (e)