# tinygrab/examples/whisper.py
# thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code
import sys
import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal, List
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, CI, fetch
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import itertools
import librosa
class MultiHeadAttention:
def __init__(
self,
n_state,
n_head,
        kv_caching: Optional[Literal["cross", "self"]] = None,
max_self_attn_cache_len=None,
):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_caching = kv_caching
self.max_self_attn_cache_len = max_self_attn_cache_len
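    # kv_caching="cross": K/V come from the encoder output and are computed once per
    #   audio segment, then reused for every decode step. kv_caching="self": each
    #   step's K/V are appended into a cache padded to max_self_attn_cache_len so the
    #   JIT always sees fixed buffer shapes.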
def __call__(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
        len: Optional[Union[Variable, int]] = None,
):
if self.kv_caching == "cross":
if xa is not None:
k, v = self.key(xa), self.value(xa)
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = k, v
else:
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
self.cache_k.assign(k + 1 - 1).realize()
self.cache_v.assign(v + 1 - 1).realize()
else:
k, v = self.cache_k, self.cache_v
else:
k, v = self.key(x), self.value(x)
if self.kv_caching == "self":
if not hasattr(self, "cache_k"):
self.cache_k = Tensor.zeros(
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
)
self.cache_v = Tensor.zeros(
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
)
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
padding = self.max_self_attn_cache_len - len - x.shape[1]
self.cache_k.assign(
k.pad((None, (0, padding), None)).contiguous()
).realize()
self.cache_v.assign(
v.pad((None, (0, padding), None)).contiguous()
).realize()
q = self.query(x)
n_ctx = q.shape[1]
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
head_dim = q.shape[-1] // self.n_head
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
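        # split heads: (batch, seq, n_state) -> (batch, n_head, seq, head_dim);
        # e.g. for tiny.en (n_state=384, n_head=6) each head gets head_dim=64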
attn = Tensor.scaled_dot_product_attention(
q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
)
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
return self.out(wv)
class ResidualAttentionBlock:
def __init__(
self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None
):
self.attn = MultiHeadAttention(
n_state,
n_head,
kv_caching="self" if is_decoder_block else None,
max_self_attn_cache_len=max_self_attn_cache_len,
)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head, kv_caching="cross")
if is_decoder_block
else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
self.mlp = [
nn.Linear(n_state, n_state * 4),
Tensor.gelu,
nn.Linear(n_state * 4, n_state),
]
self.mlp_ln = nn.LayerNorm(n_state)
    def __call__(self, x, xa=None, mask=None, len: Optional[Union[Variable, int]] = None):
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x.realize()
class AudioEncoder:
def __init__(
self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_
):
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(
n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1
)
self.blocks = [
ResidualAttentionBlock(n_audio_state, n_audio_head)
for _ in range(n_audio_layer)
]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
self.encode = TinyJit(self.__call__)
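    # conv2's stride of 2 halves the 3000 mel frames of a 30s segment to 1500
    # encoder positions, matching the rows of positional_embedding (n_audio_ctx)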
def __call__(self, x):
x = self.conv1(x).gelu()
x = self.conv2(x).gelu()
x = x.permute(0, 2, 1)
x = x + self.positional_embedding[: x.shape[1]]
x = x.sequential(self.blocks)
x = self.ln_post(x)
return x.realize()
class TextDecoder:
def __init__(
self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_
):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = (
self.max_tokens_to_sample * 2 + 5
) # roughly prompt + start toks + max_tokens_to_sample
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [
ResidualAttentionBlock(
n_text_state,
n_text_head,
is_decoder_block=True,
max_self_attn_cache_len=self.max_self_attn_cache_len,
)
for _ in range(n_text_layer)
]
self.ln = nn.LayerNorm(n_text_state)
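    # upper-triangular -inf mask keeps decoder self-attention causal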
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
self.start_output_tok = TinyJit(self.output_tok)
self.after_start_output_tok = TinyJit(self.output_tok)
# if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
if pos == 0:
for block in self.blocks if streaming else self.blocks_start_tok:
x = block(
x, xa=encoded_audio, mask=self.mask, len=0
) # pass xa for cross attn kv caching
return self.output_tok(x) if streaming else self.start_output_tok(x)
else:
for block in self.blocks_after_start_tok:
len_v = Variable(
"self_attn_cache_len", 1, self.max_self_attn_cache_len
).bind(pos)
x = block(x, mask=self.mask, len=len_v)
return self.after_start_output_tok(x)
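    # weight tying: logits are computed against the transposed token embedding, so
    # there is no separate output projection matrix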
def output_tok(self, x):
return (self.ln(x) @ self.token_embedding.weight.T).realize()
class Whisper:
def __init__(self, dims, batch_size=1):
self.encoder = AudioEncoder(**dims)
self.decoder = TextDecoder(**dims)
self.is_multilingual = dims["n_vocab"] == 51865
self.batch_size = batch_size
RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
def prep_audio(
waveforms: List[np.ndarray], batch_size: int, truncate=False
) -> np.ndarray:
"""
:param waveforms: A list of possibly variable length 16000Hz audio samples
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
:return: mel spectrogram of the given waveforms
"""
def pad_or_trim(arr, target_len):
curr_len = len(arr)
if curr_len == target_len:
return arr
elif curr_len < target_len:
return np.pad(arr, (0, target_len - curr_len), "constant")
else:
return arr[:target_len]
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
max_len += SAMPLES_PER_SEGMENT - r
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
assert waveforms.shape[0] <= batch_size
if waveforms.shape[0] < batch_size:
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
waveforms = np.pad(
waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))
)
stft = librosa.stft(
waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
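    # whisper-style log-mel normalization: floor at 1e-10, clamp to within 8
    # log10-units of the peak, then shift/scale into roughly [-1, 1]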
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
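# Shape sketch (assumes the constants above): a 10s clip is right-padded to one
# full 30s segment, so prep_audio([np.zeros(10 * RATE, dtype=np.float32)], 1)
# returns an array of shape (1, N_MELS, FRAMES_PER_SEGMENT) == (1, 80, 3000).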
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
def get_encoding(encoding_name):
with fetch(
f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
).open() as f:
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in f if line)
}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
n_vocab += len(specials)
import tiktoken
return tiktoken.Encoding(
name=encoding_name,
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
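# Note: language tokens are registered immediately after <|startoftranscript|> in
# LANGUAGES order; transcribe_waveform relies on this when it derives the language
# token as <|startoftranscript|> + 1 + index of the language code.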
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en", batch_size=1):
    assert model_name in MODEL_URLS, f"unknown model name: {model_name}"
filename = fetch(MODEL_URLS[model_name])
state = torch_load(filename)
model = Whisper(state["dims"], batch_size)
load_state_dict(model, state["model_state_dict"], strict=False)
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
return model, enc
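# Usage (a minimal sketch; "audio.wav" is a hypothetical path):
#   model, enc = init_whisper("tiny.en", batch_size=1)
#   print(transcribe_file(model, enc, "audio.wav"))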
def load_file_waveform(filename):
waveform, _ = librosa.load(filename, sr=RATE)
return waveform
def transcribe_file(model, enc, filename):
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms, truncate=False):
"""
    Expects an array of shape (N,S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000Hz samples
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
"""
N_audio = len(waveforms)
log_spec = prep_audio(waveforms, model.batch_size, truncate)
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
raise Exception(
"Multi-segment transcription not supported with batch audio input"
)
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
# TODO detect language
language_token = (
enc._special_tokens["<|startoftranscript|>"]
+ 1
+ tuple(LANGUAGES.keys()).index("en")
)
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
transcription_start_index = len(start_tokens)
eot = enc._special_tokens["<|endoftext|>"]
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
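    # greedy decoding: each 30s window is encoded once, then tokens are picked by
    # argmax one step at a time until every row of the batch has emitted <|endoftext|>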
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
encoded_audio = model.encoder.encode(
Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
)
pos = 0
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
if curr_frame > 0:
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt = np.concatenate(
(
[enc._special_tokens["<|startofprev|>"]],
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
start_tokens,
)
)
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
transcription_start_index = len(curr_segment_tokens[0])
for i in range(model.decoder.max_tokens_to_sample):
out = model.decoder(
Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
pos,
encoded_audio,
streaming=curr_frame > 0,
)
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
curr_segment_tokens = np.concatenate(
(curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
)
pos = curr_segment_tokens.shape[-1] - 1
if DEBUG >= 1:
print(
i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))
)
if (curr_segment_tokens[:, -1] == eot).all():
break
for i, t in enumerate(curr_segment_tokens):
eot_index = np.where(t == eot)[0]
eot_index = None if len(eot_index) == 0 else eot_index[0]
transcription_tokens[i] = np.concatenate(
(transcription_tokens[i], t[transcription_start_index:eot_index])
)
transcriptions = list(
map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)
)
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
CHUNK = 1600
RECORD_SECONDS = 10
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=RATE,
input=True,
frames_per_buffer=CHUNK,
)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
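        # int16 PCM scaled into [-1, 1); the 3x factor is an empirical gain boost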
waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
q.put(waveform)
print("done listening")
if __name__ == "__main__":
model, enc = init_whisper(
"small.en" if getenv("SMALL") else "tiny.en", batch_size=1
)
if len(sys.argv) > 1:
print(transcribe_file(model, enc, sys.argv[1]))
else:
# online
q = multiprocessing.Queue()
p = multiprocessing.Process(target=listener, args=(q,))
p.daemon = True
p.start()
lst = [
enc._special_tokens["<|startoftranscript|>"],
enc._special_tokens["<|notimestamps|>"],
]
total = None
did_read = False
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
while not q.empty() or total is None:
waveform = q.get()
if total is None:
total = waveform
else:
total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(
total.reshape(1, -1), model.batch_size, truncate=True
)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
out = model.decoder(
Tensor([lst]), 0, encoded_audio, streaming=True
).realize()
idx = int(out[0, -1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
if dec.endswith("<|endoftext|>"):
lst.pop()