whisper: make file transcription work, add basic CI test (#2042)
parent
924ecc4d6a
commit
91168a28c4
|
@ -209,6 +209,8 @@ jobs:
|
|||
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
|
||||
- name: Run symbolic shapetracker test
|
||||
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
|
||||
- name: Run whisper test
|
||||
run: METAL=1 python -m pytest test/models/test_whisper.py
|
||||
- name: Check Device.DEFAULT
|
||||
run: WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
|
||||
- name: Run linearizer and tensor core test
|
||||
|
|
|
@ -110,16 +110,24 @@ RATE = 16000
|
|||
CHUNK = 1600
|
||||
RECORD_SECONDS = 10
|
||||
|
||||
def prep_audio(waveform=None, sr=RATE) -> Tensor:
|
||||
def prep_audio(waveform) -> Tensor:
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
N_MELS = 80
|
||||
if waveform is None: waveform = np.zeros(N_FFT, dtype=np.float32)
|
||||
assert waveform is not None
|
||||
waveform = waveform.reshape(1, -1)
|
||||
|
||||
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
|
||||
magnitudes = stft[..., :-1] ** 2
|
||||
mel_spec = librosa.filters.mel(sr=sr, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
|
||||
# https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
|
||||
n_frames = log_spec.shape[2]
|
||||
if n_frames < 3000:
|
||||
log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))
|
||||
|
||||
#print(waveform.shape, log_spec.shape)
|
||||
return log_spec
|
||||
|
||||
|
@ -170,41 +178,63 @@ def img(x):
|
|||
plt.show()
|
||||
|
||||
def listener(q):
|
||||
prep_audio()
|
||||
import pyaudio
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
print("listening")
|
||||
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
data = stream.read(CHUNK)
|
||||
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3).reshape(1, -1)
|
||||
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
|
||||
q.put(waveform)
|
||||
print("done listening")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("SMALL"):
|
||||
fn = BASE / "whisper-small.en.pt"
|
||||
download_file("https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", fn)
|
||||
else:
|
||||
fn = BASE / "whisper-tiny.en.pt"
|
||||
download_file("https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", fn)
|
||||
state = torch_load(fn)
|
||||
|
||||
MODEL_URLS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
def init_whisper(model_name="tiny.en"):
|
||||
assert MODEL_URLS[model_name] is not None
|
||||
|
||||
filename = BASE / "whisper-{}.pt".format(model_name)
|
||||
download_file(MODEL_URLS[model_name], filename)
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state['dims'])
|
||||
load_state_dict(model, state['model_state_dict'])
|
||||
|
||||
enc = get_encoding(state['dims']['n_vocab'])
|
||||
return model, enc
|
||||
|
||||
def transcribe_file(model, enc, filename):
|
||||
waveform, sample_rate = librosa.load(filename, sr=RATE)
|
||||
log_spec = prep_audio(waveform)
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
dat = model.encoder(Tensor(log_spec)).realize()
|
||||
|
||||
for i in range(50):
|
||||
out = model.decoder(Tensor([lst]), dat).realize()
|
||||
idx = int(out[0,-1].argmax().numpy().item())
|
||||
lst.append(idx)
|
||||
transcription = enc.decode(lst)
|
||||
print(transcription)
|
||||
if lst[-1] == enc._special_tokens["<|endoftext|>"]:
|
||||
return transcription
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
# offline
|
||||
waveform, sample_rate = librosa.load(sys.argv[1], normalize=True)
|
||||
log_spec = prep_audio(waveform, sample_rate)
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
dat = model.encoder(Tensor(log_spec)).realize()
|
||||
for i in range(50):
|
||||
out = model.decoder(Tensor([lst]), dat)
|
||||
out.realize()
|
||||
idx = out[0,-1].argmax().numpy()
|
||||
lst.append(idx)
|
||||
print(enc.decode(lst))
|
||||
transcribe_file(model, enc, sys.argv[1])
|
||||
else:
|
||||
# online
|
||||
|
||||
|
@ -213,24 +243,23 @@ if __name__ == "__main__":
|
|||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
total = None
|
||||
did_read = False
|
||||
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
while not q.empty() or total is None:
|
||||
waveform = q.get()
|
||||
if total is None: total = waveform
|
||||
else: total = np.concatenate([total, waveform], axis=1)
|
||||
else: total = np.concatenate([total, waveform])
|
||||
did_read = True
|
||||
if did_read:
|
||||
last_total = total.shape[1]
|
||||
log_spec = prep_audio(waveform=Tensor(total).numpy(), sr=RATE)
|
||||
log_spec = prep_audio(total)
|
||||
encoded_audio = model.encoder(Tensor(log_spec)).realize()
|
||||
out = model.decoder(Tensor([lst]), encoded_audio).realize()
|
||||
idx = out[0,-1].argmax().numpy().astype(dtype=np.int32)
|
||||
idx = int(out[0,-1].argmax().numpy().item())
|
||||
lst.append(idx)
|
||||
dec = enc.decode(lst)
|
||||
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
|
||||
if dec.endswith("<|endoftext|>"):
|
||||
#total = total[:, 320*(len(lst)-1):]
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
lst.pop()
|
||||
|
|
1
setup.py
1
setup.py
|
@ -54,6 +54,7 @@ setup(name='tinygrad',
|
|||
"nevergrad",
|
||||
"sentencepiece",
|
||||
"tiktoken",
|
||||
"librosa"
|
||||
],
|
||||
},
|
||||
include_package_data=True)
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
import unittest
|
||||
import pathlib
|
||||
from tinygrad.ops import Device
|
||||
from examples.whisper import init_whisper, transcribe_file
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model, enc = init_whisper("tiny.en")
|
||||
cls.model = model
|
||||
cls.enc = enc
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
del cls.model
|
||||
del cls.enc
|
||||
|
||||
def test_transcribe_file(self):
|
||||
# Audio generated with the command on MacOS:
|
||||
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
|
||||
# We use the WAVE type because it's easier to decode in CI test environments
|
||||
filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
|
||||
transcription = transcribe_file(self.model, self.enc, filename)
|
||||
self.assertEqual("<|startoftranscript|><|notimestamps|> Could you please let me out of the box?<|endoftext|>", transcription)
|
Binary file not shown.
Loading…
Reference in New Issue