mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-08 03:28:35 +03:00
@@ -30,6 +30,15 @@ class TranscribeTask(Task):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
|
||||
def select_strategy(task: Task, job: schemas.Job) -> Any:
|
||||
if job.type == schemas.JobType.transcript:
|
||||
return task.strategy.transcribe
|
||||
elif job.type == schemas.JobType.translation:
|
||||
return task.strategy.translate
|
||||
else:
|
||||
return task.strategy.detect_language
|
||||
|
||||
|
||||
@celery.task(
|
||||
base=TranscribeTask,
|
||||
bind=True,
|
||||
@@ -52,27 +61,19 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(f"[{job.id}]: worker received task.")
|
||||
|
||||
job.meta = {"task_id": self.request.id}
|
||||
job.status = schemas.JobStatus.processing
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[{job.id}]: set task to status processing.")
|
||||
|
||||
job_record = schemas.Job.from_orm(job)
|
||||
|
||||
# process selected task.
|
||||
if job.type == schemas.JobType.transcript:
|
||||
result = self.strategy.transcribe(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
elif job.type == schemas.JobType.translation:
|
||||
result = self.strategy.translate(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
else:
|
||||
result = self.strategy.detect_language(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
strategy = select_strategy(self, job_record)
|
||||
result = strategy(
|
||||
url=job_record.url, job_id=job_record.id, config=job_record.config
|
||||
)
|
||||
|
||||
logger.info(f"[{job.id}]: successfully processed audio.")
|
||||
|
||||
@@ -93,6 +94,7 @@ def transcribe(self: Task, job_id: UUID) -> None:
|
||||
job.meta = {**job.meta, "error": str(e)} # type: ignore
|
||||
job.status = schemas.JobStatus.error
|
||||
db.commit()
|
||||
raise (e)
|
||||
raise (e)
|
||||
finally:
|
||||
self.strategy.cleanup(job_id=job_id)
|
||||
db.close()
|
||||
|
||||
@@ -58,7 +58,7 @@ class LocalStrategy:
|
||||
def _download(self, url: str, job_id: UUID) -> str:
|
||||
# re-create folder.
|
||||
filename = self._get_tmp_file(job_id)
|
||||
self._cleanup(job_id)
|
||||
self.cleanup(job_id)
|
||||
|
||||
# stream media to disk.
|
||||
with requests.get(url, stream=True) as r:
|
||||
@@ -87,13 +87,13 @@ class LocalStrategy:
|
||||
|
||||
return result["segments"]
|
||||
finally:
|
||||
self._cleanup(job_id)
|
||||
self.cleanup(job_id)
|
||||
|
||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return path.join(tmp, str(job_id))
|
||||
|
||||
def _cleanup(self, job_id: UUID) -> None:
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
|
||||
Reference in New Issue
Block a user