
whisper: make file transcription work, add basic CI test (#2042)

pull/2066/head
mmmkkaaayy 2023-10-13 17:13:35 -07:00 committed by GitHub
parent 924ecc4d6a
commit 91168a28c4
5 changed files with 87 additions and 30 deletions


@@ -209,6 +209,8 @@ jobs:
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
- name: Run symbolic shapetracker test
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
- name: Run whisper test
run: METAL=1 python -m pytest test/models/test_whisper.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
- name: Run linearizer and tensor core test
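
The added step can also be reproduced outside the workflow runner. A minimal sketch, assuming a Metal-capable macOS machine and the repository root as the working directory:

import os
import pytest

os.environ["METAL"] = "1"  # select the Metal backend before tinygrad is imported, as the CI step does
raise SystemExit(pytest.main(["test/models/test_whisper.py"]))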


@@ -110,16 +110,24 @@ RATE = 16000
CHUNK = 1600
RECORD_SECONDS = 10
def prep_audio(waveform=None, sr=RATE) -> Tensor:
def prep_audio(waveform) -> Tensor:
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
if waveform is None: waveform = np.zeros(N_FFT, dtype=np.float32)
assert waveform is not None
waveform = waveform.reshape(1, -1)
stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.float32)
magnitudes = stft[..., :-1] ** 2
mel_spec = librosa.filters.mel(sr=sr, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
log_spec = (log_spec + 4.0) / 4.0
# https://github.com/openai/whisper/blob/b38a1f20f4b23f3f3099af2c3e0ca95627276ddf/whisper/audio.py#L19
n_frames = log_spec.shape[2]
if n_frames < 3000:
log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - n_frames)))
#print(waveform.shape, log_spec.shape)
return log_spec
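
Put together, the new prep_audio turns raw 16 kHz audio into the fixed-size log-mel spectrogram Whisper's encoder expects: an 80-band mel projection of a 400-point STFT with a 160-sample hop, padded to 3000 frames (30 seconds of audio). A standalone sketch of the same computation, assuming only numpy and librosa; the explicit np.abs before squaring is the one liberty taken so the snippet runs cleanly on librosa's complex STFT output:

import numpy as np
import librosa

RATE = 16000  # Whisper models expect 16 kHz audio

def prep_audio(waveform: np.ndarray) -> np.ndarray:
  N_FFT, HOP_LENGTH, N_MELS = 400, 160, 80
  waveform = waveform.reshape(1, -1)
  stft = librosa.stft(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann')
  magnitudes = np.abs(stft[..., :-1]) ** 2                            # power spectrogram, last frame dropped
  mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
  log_spec = np.log10(np.clip(mel_spec, 1e-10, mel_spec.max() + 1e8))
  log_spec = (log_spec + 4.0) / 4.0                                   # shift/scale as in the lines above
  if log_spec.shape[2] < 3000:                                        # pad to 30 s so the encoder sees a fixed shape
    log_spec = np.pad(log_spec, ((0, 0), (0, 0), (0, 3000 - log_spec.shape[2])))
  return log_spec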
@@ -170,41 +178,63 @@ def img(x):
plt.show()
def listener(q):
prep_audio()
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3).reshape(1, -1)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
q.put(waveform)
print("done listening")
if __name__ == "__main__":
if getenv("SMALL"):
fn = BASE / "whisper-small.en.pt"
download_file("https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", fn)
else:
fn = BASE / "whisper-tiny.en.pt"
download_file("https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", fn)
state = torch_load(fn)
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en"):
assert MODEL_URLS[model_name] is not None
filename = BASE / "whisper-{}.pt".format(model_name)
download_file(MODEL_URLS[model_name], filename)
state = torch_load(filename)
model = Whisper(state['dims'])
load_state_dict(model, state['model_state_dict'])
enc = get_encoding(state['dims']['n_vocab'])
return model, enc
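
init_whisper replaces the old hard-coded tiny/small choice in __main__: any key of MODEL_URLS selects a checkpoint, which is downloaded on first use and loaded into the tinygrad model together with its tokenizer. A usage sketch, assuming it is run from the repository root:

from examples.whisper import init_whisper

model, enc = init_whisper("base.en")  # any MODEL_URLS key; the checkpoint is fetched on the first call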
def transcribe_file(model, enc, filename):
waveform, sample_rate = librosa.load(filename, sr=RATE)
log_spec = prep_audio(waveform)
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
dat = model.encoder(Tensor(log_spec)).realize()
for i in range(50):
out = model.decoder(Tensor([lst]), dat).realize()
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
transcription = enc.decode(lst)
print(transcription)
if lst[-1] == enc._special_tokens["<|endoftext|>"]:
return transcription
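
transcribe_file is a plain greedy decode: the padded spectrogram is encoded once, then the decoder is re-run on the growing token list, appending the argmax token each step until <|endoftext|> appears or 50 steps elapse. End-to-end usage, with a placeholder audio path that is not part of this commit:

from examples.whisper import init_whisper, transcribe_file

model, enc = init_whisper("tiny.en")
# "recording.wav" is a placeholder; any file librosa can load works, since it is resampled to 16 kHz on load
print(transcribe_file(model, enc, "recording.wav"))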
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en")
if len(sys.argv) > 1:
# offline
waveform, sample_rate = librosa.load(sys.argv[1], normalize=True)
log_spec = prep_audio(waveform, sample_rate)
lst = [enc._special_tokens["<|startoftranscript|>"]]
dat = model.encoder(Tensor(log_spec)).realize()
for i in range(50):
out = model.decoder(Tensor([lst]), dat)
out.realize()
idx = out[0,-1].argmax().numpy()
lst.append(idx)
print(enc.decode(lst))
transcribe_file(model, enc, sys.argv[1])
else:
# online
@@ -213,24 +243,23 @@ if __name__ == "__main__":
p.daemon = True
p.start()
lst = [enc._special_tokens["<|startoftranscript|>"]]
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
total = None
did_read = False
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
while not q.empty() or total is None:
waveform = q.get()
if total is None: total = waveform
else: total = np.concatenate([total, waveform], axis=1)
else: total = np.concatenate([total, waveform])
did_read = True
if did_read:
last_total = total.shape[1]
log_spec = prep_audio(waveform=Tensor(total).numpy(), sr=RATE)
log_spec = prep_audio(total)
encoded_audio = model.encoder(Tensor(log_spec)).realize()
out = model.decoder(Tensor([lst]), encoded_audio).realize()
idx = out[0,-1].argmax().numpy().astype(dtype=np.int32)
idx = int(out[0,-1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
if dec.endswith("<|endoftext|>"):
#total = total[:, 320*(len(lst)-1):]
lst = [enc._special_tokens["<|startoftranscript|>"]]
lst.pop()
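
In the online path, each pass re-encodes everything heard so far and now pops the <|endoftext|> token when it appears, so the prompt keeps growing across chunks instead of resetting to <|startoftranscript|>. With the constants above, the outer loop runs RATE / CHUNK * RECORD_SECONDS = 16000 / 1600 * 10 = 100 times, and the inner while drains however many 100 ms chunks have arrived since the previous pass.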


@@ -54,6 +54,7 @@ setup(name='tinygrad',
"nevergrad",
"sentencepiece",
"tiktoken",
"librosa"
],
},
include_package_data=True)
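
librosa backs the new example and test, and this hunk adds it to an optional dependency list alongside tiktoken and sentencepiece rather than to the core requirements. A sketch of pulling it in by installing the repo with extras; the extras group name "testing" is an assumption, since the hunk only shows the list contents:

import subprocess, sys

# editable install with the optional extras that now include librosa
# (".[testing]" assumes the extras group name; adjust to whatever setup.py actually calls it)
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", ".[testing]"])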


@@ -0,0 +1,25 @@
import unittest
import pathlib
from tinygrad.ops import Device
from examples.whisper import init_whisper, transcribe_file
@unittest.skipUnless(Device.DEFAULT == "METAL", "Some non-metal backends spend too long trying to allocate a 20GB array")
class TestWhisper(unittest.TestCase):
@classmethod
def setUpClass(cls):
model, enc = init_whisper("tiny.en")
cls.model = model
cls.enc = enc
@classmethod
def tearDownClass(cls):
del cls.model
del cls.enc
def test_transcribe_file(self):
# Audio generated with the command on MacOS:
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test
# We use the WAVE type because it's easier to decode in CI test environments
filename = str(pathlib.Path(__file__).parent / "whisper/test.wav")
transcription = transcribe_file(self.model, self.enc, filename)
self.assertEqual("<|startoftranscript|><|notimestamps|> Could you please let me out of the box?<|endoftext|>", transcription)
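
The fixture itself is checked in as a binary file (below); on macOS it can be regenerated with the say command from the comment above. A sketch of doing that from Python; the output path mirrors the location the test reads and is otherwise an assumption:

import subprocess

# macOS only: synthesize the fixture the same way the comment in the test describes
subprocess.run(["say", "Could you please let me out of the box?",
                "--file-format=WAVE", "--data-format=LEUI8@16000",
                "-o", "test/models/whisper/test.wav"], check=True)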

Binary file not shown.