# thanks to https://github.com/openai/whisper for a good chunk of MIT licensed code

import sys
import pathlib
import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal, List
from tinygrad.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, CI, fetch
import tinygrad.nn as nn
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
import itertools
import librosa


class MultiHeadAttention:
    """Multi-head attention with optional key/value caching.

    kv_caching:
      * None    -- keys/values are computed from ``x`` on every call.
      * "self"  -- keys/values of already-processed positions are kept in zero-initialised
                   buffers of length ``max_self_attn_cache_len`` and prepended to the new ones.
      * "cross" -- keys/values are computed from ``xa`` when it is given and stored; when ``xa``
                   is None the stored keys/values are reused.
    """

    def __init__(
        self,
        n_state,
        n_head,
        kv_caching: Optional[Literal["cross", "self"]] = None,
        max_self_attn_cache_len=None,
    ):
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.kv_caching = kv_caching
        self.max_self_attn_cache_len = max_self_attn_cache_len

    def __call__(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        len: Union[Variable, int] = None,
    ):
        """Attend from ``x`` to itself (or to ``xa`` for cross attention).

        :param x: query input; assumed (batch, seq, n_state) given the reshapes below -- TODO confirm
        :param xa: cross-attention source; only read when kv_caching == "cross"
        :param mask: additive mask, sliced to (seq, seq) before use
        :param len: number of positions already held in the self-attention cache (kv_caching ==
                    "self" only). Shadows the builtin, but callers pass it by keyword so the name
                    is kept; the builtin ``len`` is not needed in this method.
        :return: output projection of the concatenated heads
        """
        if self.kv_caching == "cross":
            if xa is not None:
                k, v = self.key(xa), self.value(xa)
                if not hasattr(self, "cache_k"):
                    self.cache_k, self.cache_v = k, v
                else:
                    # see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
                    self.cache_k.assign(k + 1 - 1).realize()
                    self.cache_v.assign(v + 1 - 1).realize()
            else:
                k, v = self.cache_k, self.cache_v
        else:
            k, v = self.key(x), self.value(x)
            if self.kv_caching == "self":
                if not hasattr(self, "cache_k"):
                    self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
                    self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
                # prepend the first `len` cached positions to the new keys/values
                k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
                v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
                # pad back to the fixed cache length so the cache buffer shape never changes
                padding = self.max_self_attn_cache_len - len - x.shape[1]
                self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
                self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()

        q = self.query(x)
        n_ctx = q.shape[1]
        assert q.shape[-1] == k.shape[-1] == v.shape[-1]
        head_dim = q.shape[-1] // self.n_head
        # split the feature dim into heads and move the head axis ahead of the sequence axis
        q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
        attn = Tensor.scaled_dot_product_attention(
            q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
        )
        wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
        return self.out(wv)


class ResidualAttentionBlock:
    """Pre-LayerNorm transformer block: self attention, optional cross attention, and a 4x MLP."""

    def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
        self.attn = MultiHeadAttention(
            n_state,
            n_head,
            kv_caching="self" if is_decoder_block else None,
            max_self_attn_cache_len=max_self_attn_cache_len,
        )
        self.attn_ln = nn.LayerNorm(n_state)

        # only decoder blocks attend to the encoded audio
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, kv_caching="cross") if is_decoder_block else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None

        self.mlp = [
            nn.Linear(n_state, n_state * 4),
            Tensor.gelu,
            nn.Linear(n_state * 4, n_state),
        ]
        self.mlp_ln = nn.LayerNorm(n_state)

    def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
        """Apply the block; ``xa`` and ``len`` are forwarded to the attention layers."""
        x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
        if self.cross_attn is not None:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa)
        x = x + self.mlp_ln(x).sequential(self.mlp)
        return x.realize()


class AudioEncoder:
    """Two Conv1d + GELU layers (the second with stride 2), positional embedding,
    a stack of attention blocks and a final LayerNorm."""

    def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
        self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
        self.blocks = [
            ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)
        ]
        self.ln_post = nn.LayerNorm(n_audio_state)
        # uninitialised here; presumably filled from the checkpoint by load_state_dict
        self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
        self.encode = TinyJit(self.__call__)

    def __call__(self, x):
        """Encode a mel spectrogram batch, assumed (batch, n_mels, frames) -- TODO confirm."""
        x = self.conv1(x).gelu()
        x = self.conv2(x).gelu()
        x = x.permute(0, 2, 1)  # channels last for the attention blocks
        x = x + self.positional_embedding[: x.shape[1]]
        x = x.sequential(self.blocks)
        x = self.ln_post(x)
        return x.realize()


class TextDecoder:
    """Autoregressive text decoder with self-attention and cross-attention KV caches.

    Separate TinyJit wrappers are kept for the first pass (the whole start/prompt sequence) and for
    the following one-token passes, because the two calls have different input shapes.
    """

    def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
        self.max_tokens_to_sample = n_text_ctx // 2
        # roughly prompt + start toks + max_tokens_to_sample
        self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5

        self.token_embedding = nn.Embedding(n_vocab, n_text_state)
        self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
        self.blocks = [
            ResidualAttentionBlock(
                n_text_state,
                n_text_head,
                is_decoder_block=True,
                max_self_attn_cache_len=self.max_self_attn_cache_len,
            )
            for _ in range(n_text_layer)
        ]
        self.ln = nn.LayerNorm(n_text_state)
        # causal mask: -inf strictly above the diagonal, 0 elsewhere
        self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
        self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
        self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
        self.start_output_tok = TinyJit(self.output_tok)
        self.after_start_output_tok = TinyJit(self.output_tok)

    # if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
    def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
        """Run the decoder on tokens ``x`` starting at position ``pos`` and return vocabulary logits.

        :param x: token ids, last axis is the sequence
        :param pos: index of the first token of ``x``; 0 means ``x`` is the full start sequence
        :param encoded_audio: encoder output, used for cross attention when ``pos == 0``
        :param streaming: when True, the ``pos == 0`` pass uses the un-jitted blocks because the
                          start sequence length varies between calls
        """
        seqlen = x.shape[-1]
        x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
        if pos == 0:
            for block in self.blocks if streaming else self.blocks_start_tok:
                # pass xa for cross attn kv caching
                x = block(x, xa=encoded_audio, mask=self.mask, len=0)
            return self.output_tok(x) if streaming else self.start_output_tok(x)

        # symbolic cache length so the jitted blocks can be reused for every later position;
        # it is the same for every block, so build it once
        len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
        for block in self.blocks_after_start_tok:
            x = block(x, mask=self.mask, len=len_v)
        return self.after_start_output_tok(x)

    def output_tok(self, x):
        """Project hidden states to vocabulary logits using the tied token embedding."""
        return (self.ln(x) @ self.token_embedding.weight.T).realize()


class Whisper:
    """Encoder/decoder pair built from the checkpoint's ``dims`` dict."""

    def __init__(self, dims, batch_size=1):
        self.encoder = AudioEncoder(**dims)
        self.decoder = TextDecoder(**dims)
        self.is_multilingual = dims["n_vocab"] == 51865
        self.batch_size = batch_size


RATE = 16000
SEGMENT_SECONDS = 30
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS  # 480000
N_FFT = 400
HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH  # 3000


def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
    """
    :param waveforms: A list of possibly variable length 16000Hz audio samples
    :param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
                       Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
    :param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
    :return: log-mel spectrogram of shape (batch_size, N_MELS, frames); each waveform is
             normalised independently of the others in the batch
    """

    def pad_or_trim(arr, target_len):
        curr_len = len(arr)
        if curr_len == target_len:
            return arr
        if curr_len < target_len:
            return np.pad(arr, (0, target_len - curr_len), "constant")
        return arr[:target_len]

    # round the common length up to a whole number of 30 s segments
    max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
    if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
        max_len += SAMPLES_PER_SEGMENT - r
    waveforms = np.array([pad_or_trim(w, max_len) for w in waveforms])
    assert waveforms.shape[0] <= batch_size
    if waveforms.shape[0] < batch_size:
        # we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
        waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))

    stft = librosa.stft(
        waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
    )
    magnitudes = np.absolute(stft[..., :-1]) ** 2
    mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes

    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    # limit the dynamic range to 8 (log10 units) below each item's own peak; a batch-wide max
    # would let one waveform (or the zero padding rows) change another waveform's features
    log_spec = np.maximum(log_spec, log_spec.max(axis=(-2, -1), keepdims=True) - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

    return log_spec


LANGUAGES = {
    "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian",
    "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
    "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish",
    "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
    "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech",
    "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
    "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian",
    "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak",
    "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian",
    "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
    "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali",
    "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
    "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer",
    "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan",
    "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati",
    "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese",
    "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese",
    "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog",
    "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala",
    "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
}


def get_encoding(encoding_name):
    """Build the tiktoken tokenizer for ``encoding_name`` ("gpt2" or "multilingual").

    The BPE ranks are downloaded from the openai/whisper repository. Whisper's special tokens are
    appended after the regular vocabulary, in the order listed below.
    """
    with fetch(
        f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
    ).open() as f:
        # each non-blank line is "<base64 token> <rank>"
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line.strip())
        }
    n_vocab = len(ranks)
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]
    special_tokens = dict(zip(specials, itertools.count(n_vocab)))
    n_vocab += len(specials)
    import tiktoken

    return tiktoken.Encoding(
        name=encoding_name,
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


MODEL_URLS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}


def init_whisper(model_name="tiny.en", batch_size=1):
    """Download a Whisper checkpoint and return ``(model, tokenizer)``.

    :raises KeyError: if ``model_name`` is not a key of MODEL_URLS
    """
    if model_name not in MODEL_URLS:
        raise KeyError(f"unknown model {model_name!r}; expected one of {sorted(MODEL_URLS)}")

    filename = fetch(MODEL_URLS[model_name])
    state = torch_load(filename)
    model = Whisper(state["dims"], batch_size)
    load_state_dict(model, state["model_state_dict"], strict=False)
    enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
    return model, enc


def load_file_waveform(filename):
    """Load an audio file with librosa, resampled to RATE (16 kHz)."""
    waveform, _ = librosa.load(filename, sr=RATE)
    return waveform


def transcribe_file(model, enc, filename):
    """Transcribe a single audio file and return its text."""
    return transcribe_waveform(model, enc, [load_file_waveform(filename)])


def transcribe_waveform(model, enc, waveforms, truncate=False):
    """
    Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
    Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided

    Decoding is greedy (argmax). Audio longer than one 30 s segment is transcribed segment by
    segment, feeding the previous text back as a prompt; this is only supported for N == 1.
    """
    N_audio = len(waveforms)
    log_spec = prep_audio(waveforms, model.batch_size, truncate)

    if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
        # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
        # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
        raise Exception("Multi-segment transcription not supported with batch audio input")

    start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
    if model.is_multilingual:
        # TODO detect language
        # language tokens directly follow <|startoftranscript|> in LANGUAGES order (see get_encoding)
        language_token = (
            enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
        )
        start_tokens.append(language_token)
        start_tokens.append(enc._special_tokens["<|transcribe|>"])
    start_tokens.append(enc._special_tokens["<|notimestamps|>"])

    transcription_start_index = len(start_tokens)
    eot = enc._special_tokens["<|endoftext|>"]
    transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]

    for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
        encoded_audio = model.encoder.encode(
            Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
        )
        pos = 0
        curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
        if curr_frame > 0:
            # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
            prompt = np.concatenate(
                (
                    [enc._special_tokens["<|startofprev|>"]],
                    transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
                    start_tokens,
                )
            )
            curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
            transcription_start_index = len(curr_segment_tokens[0])

        for i in range(model.decoder.max_tokens_to_sample):
            # first step feeds the whole start/prompt sequence, later steps only the newest token
            out = model.decoder(
                Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
                pos,
                encoded_audio,
                streaming=curr_frame > 0,
            )
            next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
            # rows that already emitted <|endoftext|> keep emitting it
            next_tokens[curr_segment_tokens[:, -1] == eot] = eot
            curr_segment_tokens = np.concatenate(
                (curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
            )
            pos = curr_segment_tokens.shape[-1] - 1

            if DEBUG >= 1:
                print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))

            if (curr_segment_tokens[:, -1] == eot).all():
                break

        # keep only the newly generated tokens, up to (not including) the first <|endoftext|>
        for i, t in enumerate(curr_segment_tokens):
            eot_index = np.where(t == eot)[0]
            eot_index = None if len(eot_index) == 0 else eot_index[0]
            transcription_tokens[i] = np.concatenate(
                (transcription_tokens[i], t[transcription_start_index:eot_index])
            )

    transcriptions = [enc.decode(tokens).strip() for tokens in transcription_tokens]
    return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]


CHUNK = 1600
RECORD_SECONDS = 10


def listener(q):
    """Record RECORD_SECONDS of mono int16 audio at RATE via PyAudio and put float32 chunks on ``q``.

    The stream and the PyAudio instance are always released, even if reading fails.
    """
    import pyaudio

    p = pyaudio.PyAudio()
    try:
        stream = p.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=RATE,
            input=True,
            frames_per_buffer=CHUNK,
        )
        try:
            print("listening")
            for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
                data = stream.read(CHUNK)
                # int16 -> float in [-1, 1), then a fixed 3x gain (presumably to boost quiet mic input)
                waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
                q.put(waveform)
            print("done listening")
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        p.terminate()


if __name__ == "__main__":
    model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)

    if len(sys.argv) > 1:
        print(transcribe_file(model, enc, sys.argv[1]))
    else:
        # online
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=listener, args=(q,))
        p.daemon = True
        p.start()

        lst = [
            enc._special_tokens["<|startoftranscript|>"],
            enc._special_tokens["<|notimestamps|>"],
        ]
        total = None
        did_read = False
        for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
            # drain everything queued so far (blocking until the very first chunk arrives)
            while not q.empty() or total is None:
                waveform = q.get()
                if total is None:
                    total = waveform
                else:
                    total = np.concatenate([total, waveform])
                did_read = True
            if did_read:
                log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
                encoded_audio = model.encoder.encode(Tensor(log_spec))
                # only re-encode once new audio arrives; otherwise reuse encoded_audio
                did_read = False
            # pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
            out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
            idx = int(out[0, -1].argmax().numpy().item())
            lst.append(idx)
            dec = enc.decode(lst)
            print(dec)  # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
            if dec.endswith("<|endoftext|>"):
                lst.pop()