# ASRManager.py
import os

# Configure the environment before importing libraries that read these
# variables at import time.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'   # tolerate duplicate OpenMP runtimes
os.environ['HF_HOME'] = 'medium_model'        # point the Hugging Face cache at the bundled model
os.environ['TRANSFORMERS_OFFLINE'] = '1'      # never download model files at runtime

import io

import librosa
from faster_whisper import WhisperModel

MODEL_PATH = "medium_model"
# medium epoch5 0.99569
class ASRManager:
    def __init__(self):
        # Initialize the model here.
        self.frequency = 16000
        self.model = WhisperModel(MODEL_PATH, device="cuda", compute_type="float16", local_files_only=True)
        # Warm up the model on a sample file so the first real request
        # does not pay the one-off CUDA/initialization cost.
        w, _ = librosa.load('tester.wav', sr=self.frequency)
        for _ in range(3):
            self.batch_transcribe_vad([w])
    @staticmethod
    def clean(annotation):
        # The tokenizer emits apostrophes but the TIL dataset does not:
        # drop the apostrophe and the character that follows it
        # (e.g. the S following ').
        if "'" in annotation:
            head, tail = annotation.split("'", 1)
            annotation = head + tail[1:]
        return annotation
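    # Illustrative example (input is hypothetical): clean("IT'S OVER")
    # returns "IT OVER", i.e. "'S" is removed entirely.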
    def batch_transcribe(self, batch) -> list[str]:
        # batch is a list of audio waveforms.
        # This is extremely rudimentary: each waveform is transcribed
        # sequentially while batched inference with faster-whisper is
        # still being investigated.
        batch_response = []
        for wf in batch:
            # segments, _ = self.model.transcribe(wf, beam_size=5, vad_filter=True)
            segments, _ = self.model.transcribe(wf, beam_size=5)
            output = "".join(s.text for s in segments)
            batch_response.append(output)
        return batch_response
    def batch_transcribe_vad(self, batch, vad_ms=500) -> list[str]:
        # Same as batch_transcribe, but with Silero VAD filtering enabled
        # so long stretches of silence are skipped before decoding.
        batch_response = []
        for wf in batch:
            segments, _ = self.model.transcribe(
                wf,
                beam_size=5,
                vad_filter=True,
                vad_parameters=dict(
                    min_silence_duration_ms=vad_ms,  # silence needed before a segment split
                    threshold=0.1,                   # low threshold keeps quiet speech
                ),
            )
            batch_response.append("".join(s.text for s in segments))
        return batch_response
    def transcribe(self, audio_bytes: bytes) -> str:
        # Perform ASR transcription on a single utterance passed as raw
        # audio bytes (e.g. the contents of a WAV file).
        w, _ = librosa.load(io.BytesIO(audio_bytes), sr=self.frequency)
        segments, _ = self.model.transcribe(w, beam_size=5)
        return "".join(segment.text for segment in segments)
#_ = ASRManager()
#print("Done")
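
# A minimal smoke test, assuming a CUDA device, the bundled "medium_model"
# directory, and the 'tester.wav' used by the warm-up; "sample.wav" and the
# __main__ guard below are illustrative additions, not part of the
# deployed service.
if __name__ == "__main__":
    manager = ASRManager()
    with open("sample.wav", "rb") as f:  # hypothetical sample file
        print(manager.transcribe(f.read()))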