feat: add gpu support

This commit is contained in:
Felix Spöttel
2023-03-01 16:04:05 +01:00
parent f27fe02958
commit 7ece7944bf
6 changed files with 125 additions and 55 deletions

View File

@@ -1,4 +1,5 @@
from asyncio.log import logger
from typing import Any, Optional
from uuid import UUID
from celery import Task
@@ -12,11 +13,30 @@ from app.worker.strategies.local import LocalStrategy
celery = get_celery_binding()
class TranscribeTask(Task):
    """Celery base task that keeps one ``LocalStrategy`` per task instance.

    The strategy is created on the first call instead of in ``__init__`` and
    is then reused by every later call on the same task instance.
    """

    abstract = True

    def __init__(self) -> None:
        super().__init__()
        # currently only `LocalStrategy` is implemented.
        # TODO: implement remote processing strategy.
        self.strategy: Optional[LocalStrategy] = None

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Make sure the strategy exists, then delegate to ``run``."""
        # load model into memory once when the first task is processed.
        if self.strategy is None:
            self.strategy = LocalStrategy()
        return self.run(*args, **kwargs)
# Celery task: looks up the job row for `job_id`, runs the strategy method that
# matches `job.type`, and builds a `models.Artifact` from the result. On any
# exception the error text is written to `job.meta`, the job status is set to
# error, the session is committed and the exception is re-raised.
# NOTE(review): this span is a rendered diff -- removed and added lines are
# interleaved without markers and the `@@` lines stand for code that is not
# shown, so it is not runnable as displayed.
@celery.task(
base=TranscribeTask, bind=True, soft_time_limit=2 * 60 * 60
)
# NOTE(review): `self` is annotated as `Task`, but the body reads
# `self.strategy`, which is defined by the `TranscribeTask` base passed above.
def transcribe(self: Task, job_id: UUID) -> None:
try:
# runs in a separate thread => requires sqlite's WAL mode to be enabled.
db: Session = SessionLocal()
job = db.query(models.Job).filter(models.Job.id == job_id).one()
if (
@@ -34,23 +54,23 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status processing.")
# pick a transcription strategy.
# currently only `local` is supported.
job_record = schemas.Job.from_orm(job)
strategy = LocalStrategy(
db=db, job_id=job.id, url=job_record.url, config=job_record.config
)
# process selected task.
# currently only `transcribe` is supported.
if job.type == schemas.JobType.transcript:
result = strategy.transcribe()
logger.info(f"[{job.id}]: successfully transcribed audio.")
result = self.strategy.transcribe(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
elif job.type == schemas.JobType.translation:
result = strategy.translate()
logger.info(f"[{job.id}]: successfully translated audio.")
result = self.strategy.translate(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
else:
result = strategy.detect_language()
result = self.strategy.detect_language(
url=job_record.url, job_id=job_record.id, config=job_record.config
)
logger.info(f"[{job.id}]: successfully processed audio.")
artifact = models.Artifact(
job_id=str(job.id), data=result, type=schemas.ArtifactType.raw_transcript
@@ -66,7 +86,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
logger.info(f"[{job.id}]: set task to status success.")
except Exception as e:
# NOTE(review): `db` and `job` are first assigned inside the `try`; if
# `SessionLocal()` or `.one()` raises, the check below presumably fails on an
# unbound name and hides the original exception -- confirm and pre-initialise.
if job and db:
job.meta = {**job.meta.__dict__, "error": str(e)}
job.meta = {**job.meta, "error": str(e)} # type: ignore
job.status = schemas.JobStatus.error
db.commit()
# NOTE(review): re-raises the caught exception; a bare `raise` is the idiomatic form.
raise (e)

View File

@@ -6,9 +6,8 @@ from typing import Any, List, Literal, Optional
from uuid import UUID
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
import torch
from pydantic import BaseModel
from whisper import load_model
import app.shared.db.schemas as schemas
@@ -20,43 +19,49 @@ class DecodeOptions(BaseModel):
class LocalStrategy:
def __init__(self) -> None:
    """Load the Whisper model, moving it to the GPU when CUDA is available.

    The model name is read from the ``WHISPER_MODEL`` environment variable and
    ``/models`` is passed to ``load_model`` as ``download_root``. The strategy
    holds no per-job state; job details are passed to each method call.

    Raises:
        KeyError: if ``WHISPER_MODEL`` is not set in the environment.
    """
    use_gpu = torch.cuda.is_available()
    logger.info("initializing GPU model." if use_gpu else "initializing CPU model.")

    # load once; only the final device placement differs between the two paths.
    model = load_model(os.environ["WHISPER_MODEL"], download_root="/models")
    self.model = model.cuda() if use_gpu else model

    logger.info("initialized local strategy.")
def transcribe(
    self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
    """Download the media at ``url`` and run ``run_whisper`` with task ``"transcribe"``.

    Args:
        url: location of the media; handed to ``_download``.
        job_id: job identifier; handed to ``_download`` and ``run_whisper``.
        config: optional job configuration; handed to ``run_whisper``.

    Returns:
        Whatever ``run_whisper`` returns for the downloaded file.
    """
    media_path = self._download(url, job_id)
    return self.run_whisper(media_path, "transcribe", config, job_id)
def translate(
    self, url: str, job_id: UUID, config: Optional[schemas.JobConfig]
) -> List[Any]:
    """Download the media at ``url`` and run ``run_whisper`` with task ``"translate"``.

    Args:
        url: location of the media; handed to ``_download``.
        job_id: job identifier; handed to ``_download`` and ``run_whisper``.
        config: optional job configuration; handed to ``run_whisper``.

    Returns:
        Whatever ``run_whisper`` returns for the downloaded file.
    """
    media_path = self._download(url, job_id)
    return self.run_whisper(media_path, "translate", config, job_id)
def detect_language(
    self,
    url: str,
    config: Optional[schemas.JobConfig],
    job_id: Optional[UUID] = None,
) -> List[Any]:
    """Placeholder for language detection; always raises.

    ``job_id`` is accepted (and ignored) so that the keyword call shape used
    for ``transcribe`` and ``translate`` -- ``url=..., job_id=..., config=...``
    -- reaches the ``NotImplementedError`` below instead of failing earlier
    with a ``TypeError`` for an unexpected keyword argument.

    Args:
        url: location of the media (currently unused).
        config: optional job configuration (currently unused).
        job_id: job identifier (currently unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("detect_language has not been implemented yet.")
# Streams the media at `url` into the job's temp file and returns that file's path.
# NOTE(review): rendered diff -- pre- and post-commit lines are interleaved here,
# and the `@@` line hides code (the body of the chunk loop is not visible).
def _download(self) -> str:
def _download(self, url: str, job_id: UUID) -> str:
# start clean: resolve the job's temp-file path and remove any leftover file there.
filename = self._get_tmp_file()
self._cleanup()
filename = self._get_tmp_file(job_id)
self._cleanup(job_id)
# stream media to disk.
# NOTE(review): no `timeout=` is passed to `requests.get`, so a stalled server
# can block this call indefinitely -- consider adding one.
with requests.get(self.url, stream=True) as r:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
@@ -64,11 +69,17 @@ class LocalStrategy:
return filename
# Runs the model's `transcribe` on `filepath` with the given `task` and returns
# `result["segments"]`; the job's temp file is cleaned up afterwards in all cases.
# NOTE(review): rendered diff -- pre- and post-commit lines are interleaved here,
# and the `@@` line hides the end of the `transcribe(...)` call.
def run_whisper(self, filepath: str, task: str) -> List[Any]:
def run_whisper(
self,
filepath: str,
task: str,
config: Optional[schemas.JobConfig],
job_id: UUID,
) -> List[Any]:
try:
# without a config there is no language setting: `language` is passed on as None.
language = self.config.language if self.config else None
language = config.language if config else None
result = model.transcribe(
result = self.model.transcribe(
filepath,
condition_on_previous_text=False,
**DecodeOptions(task=task, language=language).dict(),
@@ -76,20 +87,14 @@ class LocalStrategy:
return result["segments"]
# runs on success and on error alike, so the downloaded file never outlives the call.
finally:
self._cleanup()
self._cleanup(job_id)
def _get_tmp_file(self) -> str:
def _get_tmp_file(self, job_id: UUID) -> str:
tmp = tempfile.gettempdir()
return path.join(tmp, str(self.job_id))
return path.join(tmp, str(job_id))
def _cleanup(self) -> None:
def _cleanup(self, job_id: UUID) -> None:
try:
os.remove(self._get_tmp_file())
os.remove(self._get_tmp_file(job_id))
except OSError:
pass
def _convert(self) -> None:
pass
def _transcribe(self) -> None:
pass