1
0
Fork 0
tinygrab/examples/conversation.py

593 lines
18 KiB
Python

import argparse
import multiprocessing as mp
import os
import re
import sys
import time
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pyaudio
import yaml
from llama import LLaMa
from vits import MODELS as VITS_MODELS
from vits import (
Y_LENGTH_ESTIMATE_SCALARS,
HParams,
Synthesizer,
TextMapper,
get_hparams_from_file,
load_model,
)
from whisper import init_whisper, transcribe_waveform
from sentencepiece import SentencePieceProcessor
from tinygrad.helpers import Timing, dtypes, fetch
from tinygrad.tensor import Tensor
# Whisper (microphone capture) constants
RATE = 16000  # sample rate in Hz used for the pyaudio input stream
CHUNK = RATE // 10  # frames per read: 1600 samples, i.e. 100 ms of audio
# LLaMa constants: ids of the ChatML tokens that create_fixed_tokenizer appends
# after the base vocabulary ([PAD] comes first, then these two)
IM_START = 32001  # <|im_start|>
IM_END = IM_START + 1  # <|im_end|>
# Helpers that encode conversation turns in the ChatML layout
def encode_prompt(spp, k, v):
    """Encode one complete ChatML turn.

    Produces the token ids for ``<|im_start|>{k}\\n{v}<|im_end|>\\n``, where ``k`` is
    the role name and ``v`` the message text. ``spp`` is the tokenizer; its
    ``encode`` must return a list of token ids.
    """
    turn_text = spp.encode(f"{k}\n{v}")
    trailing_newline = spp.encode("\n")
    return [IM_START, *turn_text, IM_END, *trailing_newline]
def start_prompt(spp, k):
    """Encode the opening ``<|im_start|>{k}\\n`` of a ChatML turn, left open so the
    model can write the rest of it."""
    header = spp.encode(f"{k}\n")
    return [IM_START, *header]
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding ``n`` items each.

    The final slice may be shorter; an empty ``lst`` yields nothing.
    """
    yield from (lst[start : start + n] for start in range(0, len(lst), n))
def create_fixed_tokenizer():
    """Fetch the TinyLlama chat tokenizer and add the ChatML tokens if missing.

    The SentencePiece model file is obtained via ``fetch``. If its vocabulary
    does not already have 32003 entries, the pieces listed in the model's
    ``added_tokens.json`` (``[PAD]``, ``<|im_start|>``, ``<|im_end|>``, in that
    order) are appended and the file is rewritten in place.

    Returns:
        Path of the (possibly patched) tokenizer model file.
    """
    import extra.junk.sentencepiece_model_pb2 as spb2

    tokenizer_path = fetch(
        "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model"
    )
    if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
        print("creating fixed tokenizer")
        # Named model_proto rather than `mp` so it does not shadow the
        # module-level `multiprocessing as mp` import.
        model_proto = spb2.ModelProto()
        model_proto.ParseFromString(tokenizer_path.read_bytes())
        # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
        for piece in ("[PAD]", "<|im_start|>", "<|im_end|>"):
            model_proto.pieces.append(spb2.ModelProto.SentencePiece(piece=piece, score=0))
        tokenizer_path.write_bytes(model_proto.SerializeToString())
    return tokenizer_path
def llama_prepare(
    llama: LLaMa, temperature: float, pre_prompt_path: Path
) -> tuple[list[int], str, str, int, str]:
    """Prime a llama model with the system prompt and examples from a YAML file.

    The YAML file provides ``pre_prompt`` (the system message; newlines are
    flattened to spaces), ``user_delim`` and ``resp_delim`` (the ChatML role
    names), and optionally ``examples``: a list of mappings with ``user_prompt``
    and ``resp_prompt`` keys. The resulting token sequence is run through the
    model once with position 0; the outputs are discarded (presumably this is
    done so the prompt is already processed when the conversation starts).

    Returns:
        ``(toks, user_delim, resp_delim, start_pos, outputted)`` where ``toks`` is
        the prompt token list, ``start_pos`` equals ``len(toks)`` and
        ``outputted`` is the decoded prompt text.
    """
    # YAML is UTF-8; don't depend on the platform's locale encoding.
    config = yaml.safe_load(Path(pre_prompt_path).read_text(encoding="utf-8"))
    tokenizer = llama.tokenizer
    toks = [tokenizer.bos_id()] + encode_prompt(
        tokenizer, "system", config["pre_prompt"].replace("\n", " ")
    )
    # `examples` may be absent, or present with no items (YAML loads that as None).
    for example in config.get("examples") or []:
        toks += encode_prompt(tokenizer, config["user_delim"], example["user_prompt"])
        toks += encode_prompt(tokenizer, config["resp_delim"], example["resp_prompt"])
    llama.model(Tensor([toks]), 0, temperature).realize()  # NOTE: outputs are not used
    return (
        toks,
        config["user_delim"],
        config["resp_delim"],
        len(toks),
        tokenizer.decode(toks),
    )
def llama_generate(
    llama: LLaMa,
    toks: list[int],
    outputted: str,
    prompt: str,
    start_pos: int,
    user_delim: str,
    resp_delim: str,
    temperature=0.7,
    max_tokens=1000,
):
    """Answer ``prompt`` by sampling tokens from ``llama`` until ``<|im_end|>``.

    ``toks`` is extended IN PLACE. It first receives the user turn and the opening
    of the response turn, then every sampled token. If ``max_tokens`` tokens are
    sampled without producing ``IM_END``, one is appended to close the turn.
    The incoming ``outputted`` value is ignored because it is recomputed from
    ``toks``. The growing decoded text is echoed to stdout for debugging.

    Each step feeds ``toks[start_pos:]`` together with ``start_pos`` to the model.
    So ``start_pos`` should point at the first token the model has not been fed
    yet; presumably the earlier tokens live in the model's cache.

    Returns:
        ``(outputted, start_pos, response)``:
        - the full decoded token sequence;
        - the position to resume from on the next call;
        - the newly generated reply with ``<|im_end|>`` removed.
    """
    tokenizer = llama.tokenizer
    toks += encode_prompt(tokenizer, user_delim, prompt)
    toks += start_prompt(tokenizer, resp_delim)
    outputted = tokenizer.decode(toks)
    reply_offset = len(outputted)
    finished = False
    for _ in range(max_tokens):
        pending = Tensor([toks[start_pos:]])
        probs = llama.model(pending, start_pos, temperature).numpy()
        next_token = int(np.random.choice(len(probs), p=probs))
        # The sampled token has not been fed to the model yet: resume from it.
        start_pos = len(toks)
        toks.append(next_token)
        decoded = tokenizer.decode(toks)
        # Echo only the newly decoded suffix (debug output).
        sys.stdout.write(decoded[len(outputted) :])
        sys.stdout.flush()
        outputted = decoded
        if next_token == IM_END:
            finished = True
            break
    if not finished:
        toks.append(IM_END)
    print()  # terminate the line streamed above
    reply = outputted[reply_offset:].replace("<|im_end|>", "")
    return outputted, start_pos, reply
def tts(
    text_to_synthesize: str,
    synth: Synthesizer,
    hps: HParams,
    emotion_embedding: "Tensor | None",
    speaker_id: int,
    model_to_use: str,
    noise_scale: float,
    noise_scale_w: float,
    length_scale: float,
    estimate_max_y_length: bool,
    text_mapper: TextMapper,
    model_has_multiple_speakers: bool,
    batch_size=600,
    vits_batch_size=1000,
):
    """Synthesize ``text_to_synthesize`` with a VITS model and return 16-bit PCM.

    Args:
        text_to_synthesize: text to speak. For "mmts-tts" it is lower-cased and
            filtered through ``text_mapper.filter_oov`` first.
        synth, hps, text_mapper: model, hyperparameters and text mapper, as
            returned by ``init_vits``.
        emotion_embedding: optional emotion Tensor forwarded to ``synth.infer``.
        speaker_id: speaker index; only used when ``model_has_multiple_speakers``.
        model_to_use: model key, also used to look up the y-length estimate scale.
        noise_scale, noise_scale_w, length_scale: forwarded to ``synth.infer``.
        estimate_max_y_length: if true, pass the model's
            ``Y_LENGTH_ESTIMATE_SCALARS`` entry to ``synth.infer``.
        batch_size: fixed length the symbol sequence is padded to; the mapped
            text must be strictly shorter than this.
        vits_batch_size: forwarded to ``synth.infer`` as ``batch_size``.

    Returns:
        A numpy ``int16`` array: the ``[0, 0]`` slice of the model output, clipped
        to [-1, 1] and scaled by 32767.

    Raises:
        ValueError: if the mapped text has ``batch_size`` or more symbols.
    """
    if model_to_use == "mmts-tts":
        text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
    # Convert the input text to a tensor of symbol ids.
    stn_tst = text_mapper.get_text(
        text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners
    )
    text_len = stn_tst.shape[0]
    # A real exception instead of `assert`: asserts vanish under `python -O`,
    # after which pad() below would receive a negative amount.
    if text_len >= batch_size:
        raise ValueError(
            f"text is too long: {text_len} symbols, must be fewer than {batch_size}"
        )
    # Pad on the right up to the fixed `batch_size` length (the second pad()
    # argument is presumably the fill value — confirm against tinygrad).
    x_tst = stn_tst.pad(((0, batch_size - text_len),), 1).unsqueeze(0)
    x_tst_lengths = Tensor([text_len], dtype=dtypes.int64)
    sid = (
        Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
    )
    max_y_length_estimate_scale = (
        Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None
    )
    # Perform inference; [0, 0] presumably selects batch item 0 / channel 0.
    audio_tensor = synth.infer(
        x_tst,
        x_tst_lengths,
        sid,
        noise_scale,
        length_scale,
        noise_scale_w,
        emotion_embedding=emotion_embedding,
        max_y_length_estimate_scale=max_y_length_estimate_scale,
        batch_size=vits_batch_size,
    )[0, 0]
    # Convert the float waveform to 16-bit PCM samples.
    return (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
def init_vits(
    model_to_use: str,
    emotion_path: Path,
    speaker_id: int,
    seed: int,
):
    """Load a VITS text-to-speech model and its supporting objects.

    Args:
        model_to_use: key into ``VITS_MODELS`` (e.g. "vctk", "mmts-tts").
        emotion_path: optional path (``str`` or ``Path``) to a ``.npy`` emotion
            embedding; ``None`` for no emotion.
        speaker_id: speaker index, validated only for multi-speaker models.
        seed: if not ``None``, seeds both the tinygrad and numpy RNGs.

    Returns:
        ``(net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers)``.

    Raises:
        ValueError: for an out-of-range ``speaker_id`` or a non-``.npy`` emotion path.
    """
    model_config = VITS_MODELS[model_to_use]
    # Load the hyperparameters from the config file.
    hps = get_hparams_from_file(fetch(model_config[0]))
    # If the model has multiple speakers, validate the speaker id.
    model_has_multiple_speakers = hps.data.n_speakers > 0
    if model_has_multiple_speakers and speaker_id >= hps.data.n_speakers:
        raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
    # Load emotions if any. TODO: find an english model with emotions, this is untested atm.
    emotion_embedding = None
    if emotion_path is not None:
        # str() so both str and pathlib.Path inputs work (Path has no .endswith).
        if not str(emotion_path).endswith(".npy"):
            raise ValueError("Emotion path must be a .npy file.")
        emotion_embedding = Tensor(
            np.load(emotion_path), dtype=dtypes.int64
        ).unsqueeze(0)
    # Load symbols and instantiate the TextMapper (which cleans the text).
    if "symbols" in hps:
        symbols = hps.symbols
    elif model_to_use == "mmts-tts":
        vocab_path = fetch(
            "https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"
        )
        # `with` guarantees the file handle is closed once the vocabulary is read.
        with vocab_path.open(encoding="utf-8") as vocab_file:
            symbols = [line.replace("\n", "") for line in vocab_file]
    else:
        symbols = (
            ["_"]
            + list(';:,.!?¡¿—…"«»“” ')
            + list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
            + list(
                "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
            )
        )
    text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
    # Load the model.
    Tensor.no_grad = True
    if seed is not None:
        Tensor.manual_seed(seed)
        np.random.seed(seed)
    net_g = load_model(text_mapper.symbols, hps, model_config)
    return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
@contextmanager
def output_stream(num_channels: int, sample_rate: int):
    """Open a 16-bit PCM pyaudio output stream for the duration of a ``with`` block.

    A KeyboardInterrupt raised inside the ``with`` body is swallowed, so the
    caller simply leaves the block. The stream and the PyAudio instance are
    always released.
    """
    p = pyaudio.PyAudio()
    try:
        stream = p.open(
            format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True
        )
        # Nested try so cleanup only touches resources that were actually
        # created; a failing p.open() is not masked by an UnboundLocalError.
        try:
            yield stream
        except KeyboardInterrupt:
            pass
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        p.terminate()
@contextmanager
def log_writer():
    """Collect chat-log lines and print them when the ``with`` block exits.

    Yields a plain list that callers append strings to. On exit, whether normal
    or via an exception, the lines are printed under a "CHAT LOG" header between
    separator rules as wide as the terminal. Exceptions from the body still
    propagate.
    """
    import shutil  # only needed for the exit banner

    logs = []
    try:
        yield logs
    finally:
        # Use the terminal *width*. shutil falls back to $COLUMNS or 80 columns
        # when stdout is not a TTY, instead of raising like os.get_terminal_size().
        sep = "=" * shutil.get_terminal_size().columns
        print(f"{sep[:-1]}\nCHAT LOG")
        print(*logs, sep="\n")
        print(sep)
def listener(q: mp.Queue, event: mp.Event):
    """Read microphone audio forever, forwarding it to ``q`` while ``event`` is set.

    Meant as a child-process target. Audio is read as mono 16-bit samples at
    ``RATE`` Hz in ``CHUNK``-frame blocks. Every block is read, even while not
    listening, so the input buffer keeps draining. While ``event`` is set, each
    block is scaled to float32 in [-1, 1), multiplied by a fixed gain of 3
    (no clipping), and put on ``q``. "listening" is printed once per listening
    period.
    """
    p = pyaudio.PyAudio()
    try:
        stream = p.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=RATE,
            input=True,
            frames_per_buffer=CHUNK,
        )
        # Nested try so cleanup only touches what was actually opened; a failing
        # p.open() is no longer masked by an UnboundLocalError in `finally`.
        try:
            did_print = False
            while True:
                data = stream.read(CHUNK)  # read data to avoid overflow
                if not event.is_set():
                    did_print = False
                    continue
                if not did_print:
                    print("listening")
                    did_print = True
                q.put((np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3)
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        p.terminate()
def mp_output_stream(
    q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int
):
    """Play audio buffers taken from ``q`` forever, counting each one played.

    Meant as a child-process target. Every item from ``q`` is written to an
    ``output_stream``, and then ``counter.value`` is incremented; the producer can
    compare this counter with the number of buffers it queued. Returns quietly
    on KeyboardInterrupt.
    """
    with output_stream(num_channels, sample_rate) as speaker:
        try:
            while True:
                speaker.write(q.get())
                counter.value += 1
        except KeyboardInterrupt:
            pass
if __name__ == "__main__":
import nltk
nltk.download("punkt")
Tensor.no_grad = True
# Parse CLI arguments
parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
# Whisper args
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
# LLAMA args
parser.add_argument(
"--llama_pre_prompt_path",
type=Path,
default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml",
help="Path to yaml file which contains all pre-prompt data needed. ",
)
parser.add_argument(
"--llama_count", type=int, default=1000, help="Max number of tokens to generate"
)
parser.add_argument(
"--llama_temperature",
type=float,
default=0.7,
help="Temperature in the softmax",
)
parser.add_argument(
"--llama_quantize",
action="store_true",
help="Quantize the weights to int8 in memory",
)
parser.add_argument(
"--llama_model",
type=Path,
default=None,
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
)
parser.add_argument(
"--llama_gen",
type=str,
default="tiny",
required=False,
help="Generation of the model to use",
)
parser.add_argument(
"--llama_size",
type=str,
default="1B-Chat",
required=False,
help="Size of model to use",
)
parser.add_argument(
"--llama_tokenizer",
type=Path,
default=None,
required=False,
help="Path to llama tokenizer.model",
)
# vits args
parser.add_argument(
"--vits_model_to_use",
default="vctk",
help="Specify the model to use. Default is 'vctk'.",
)
parser.add_argument(
"--vits_speaker_id",
type=int,
default=12,
help="Specify the speaker ID. Default is 6.",
)
parser.add_argument(
"--vits_noise_scale",
type=float,
default=0.667,
help="Specify the noise scale. Default is 0.667.",
)
parser.add_argument(
"--vits_noise_scale_w",
type=float,
default=0.8,
help="Specify the noise scale w. Default is 0.8.",
)
parser.add_argument(
"--vits_length_scale",
type=float,
default=1,
help="Specify the length scale. Default is 1.",
)
parser.add_argument(
"--vits_seed",
type=int,
default=None,
help="Specify the seed (set to None if no seed). Default is 1337.",
)
parser.add_argument(
"--vits_num_channels",
type=int,
default=1,
help="Specify the number of audio output channels. Default is 1.",
)
parser.add_argument(
"--vits_sample_width",
type=int,
default=2,
help="Specify the number of bytes per sample, adjust if necessary. Default is 2.",
)
parser.add_argument(
"--vits_emotion_path",
type=Path,
default=None,
help="Specify the path to emotion reference.",
)
parser.add_argument(
"--vits_estimate_max_y_length",
type=str,
default=False,
help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.",
)
parser.add_argument(
"--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary."
)
# conversation args
parser.add_argument(
"--max_sentence_length",
type=int,
default=20,
help="Max words in one sentence to pass to vits",
)
args = parser.parse_args()
# Init models
model, enc = init_whisper(args.whisper_model_name)
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(
args.vits_model_to_use,
args.vits_emotion_path,
args.vits_speaker_id,
args.vits_seed,
)
# Download tinyllama chat as a default model
if args.llama_model is None:
args.llama_model = fetch(
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors",
"tinyllamachat.safetensors",
)
args.llama_gen = "tiny"
args.llama_size = "1B-Chat"
# Add 3 more tokens to the tokenizer
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"):
args.llama_tokenizer = create_fixed_tokenizer()
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
llama = LLaMa.build(
args.llama_model,
tokenizer_path,
args.llama_gen,
args.llama_size,
args.llama_quantize,
)
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(
llama, args.llama_temperature, args.llama_pre_prompt_path
)
# Start child process for mic input
q = mp.Queue()
is_listening_event = mp.Event()
p = mp.Process(
target=listener,
args=(
q,
is_listening_event,
),
)
p.daemon = True
p.start()
# Start child process for speaker output
out_q = mp.Queue()
out_counter = mp.Value("i", 0)
out_p = mp.Process(
target=mp_output_stream,
args=(
out_q,
out_counter,
args.vits_num_channels,
hps.data.sampling_rate,
),
)
out_p.daemon = True
out_p.start()
# JIT tts
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
tts(
i,
synth,
hps,
emotion_embedding,
args.vits_speaker_id,
args.vits_model_to_use,
args.vits_noise_scale,
args.vits_noise_scale_w,
args.vits_length_scale,
args.vits_estimate_max_y_length,
text_mapper,
model_has_multiple_speakers,
)
# Start the pipeline
with log_writer() as log:
while True:
tokens = [
enc._special_tokens["<|startoftranscript|>"],
enc._special_tokens["<|notimestamps|>"],
]
total = np.array([])
out_counter.value = 0
s = time.perf_counter()
is_listening_event.set()
prev_text = None
while True:
for _ in range(RATE // CHUNK):
total = np.concatenate([total, q.get()])
txt = transcribe_waveform(model, enc, [total], truncate=True)
print(txt, end="\r")
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()):
continue
if prev_text is not None and prev_text == txt:
is_listening_event.clear()
break
prev_text = txt
print() # to avoid llama printing on the same line
log.append(f"{user_delim.capitalize()}: {txt}")
# Generate with llama
with Timing("llama generation: "):
outputted, start_pos, response = llama_generate(
llama,
toks,
outputted,
txt,
start_pos,
user_delim=user_delim,
resp_delim=resp_delim,
temperature=args.llama_temperature,
max_tokens=args.llama_count,
)
log.append(f"{resp_delim.capitalize()}: {response}")
# Convert to voice
with Timing("tts: "):
sentences = nltk.sent_tokenize(response.replace('"', ""))
for i in sentences:
total = np.array([], dtype=np.int16)
for j in chunks(i.split(), args.max_sentence_length):
audio_data = tts(
" ".join(j),
synth,
hps,
emotion_embedding,
args.vits_speaker_id,
args.vits_model_to_use,
args.vits_noise_scale,
args.vits_noise_scale_w,
args.vits_length_scale,
args.vits_estimate_max_y_length,
text_mapper,
model_has_multiple_speakers,
)
total = np.concatenate([total, audio_data])
out_q.put(total.tobytes())
while out_counter.value < len(sentences):
continue
log.append(f"Total: {time.perf_counter() - s}")