mirror of
https://github.com/bellingcat/whisperbox-transcribe.git
synced 2026-06-13 05:58:35 +03:00
refactor: add strategy.process() to BaseStrategy
This commit is contained in:
@@ -17,6 +17,20 @@ class TaskProtocol(Protocol):
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
def process(self, job: models.Job) -> TaskReturnValue:
|
||||
if job.type == models.JobType.transcript:
|
||||
return self.transcribe(job)
|
||||
elif job.type == models.JobType.translation:
|
||||
return self.translate(job)
|
||||
else:
|
||||
return self.detect_language(job)
|
||||
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
...
|
||||
|
||||
def transcribe(self, job: models.Job) -> TaskReturnValue:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -26,12 +40,6 @@ class BaseStrategy(ABC):
|
||||
def detect_language(self, job: models.Job) -> TaskReturnValue:
|
||||
raise NotImplementedError()
|
||||
|
||||
def cleanup(self, job_id: UUID) -> None:
|
||||
try:
|
||||
os.remove(self._get_tmp_file(job_id))
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _get_tmp_file(self, job_id: UUID) -> str:
|
||||
tmp = tempfile.gettempdir()
|
||||
return os.path.join(tmp, str(job_id))
|
||||
|
||||
Reference in New Issue
Block a user