
Whisper + LLAMA + VITS (#2332)

* feat: working voice 2 text using whisper

* feat: added llama generation

* feat: vits init

* feat: more accurate voice conversion

* feat: support for tts and working pipeline for the first pass

* fix: linter checks

* refactored vits initialization and inference, added mmts-tts support

* fixed process sync and now we can have an infinite conversation

* reuse output stream to remove overhead of creating a new one each time

* added pre-prompt configuration with yaml files

* adjusted code to merge PR which changed whisper

* optimized whisper, now it's blazing fast and also reduced number of lines

* added better debug printing

* use jitted encode function for whisper, added timings and removed response delim to save speed on generating those tokens

* fixed hf convert and now it's working with tinyllama

* added tinyllama config

* refactored code and made it work with all llama models

* prettier order

* prettier order

* fixed suffix for tinyllama and refactored convert_from_hf

* added missing parameters

* fixed stream release and added missing params

* jitted dp and encoder

* jitted flow forward

* removed re-init of espeak on each call to save up time

* jitted generator forward for blazing fast tts

* added contextmanager for displaying a chat log

* removed whitespace for pylint

* updated code to support latest fetch func

* wait for llama eos token and pass params from cli to llama

* listen for not fixed amount of time

* refactored code a bit

* removed thresholding and now the output streams directly to whisper

* tokenize llama output for vits batch size to work and stream each sentence to a speaker

* changed speaker

* whisper is now printing on the same line

* don't trigger llama on whisper output in parens

* added tinyllama chat model

* adjusted code to work with tinyllama chat model

* removed unused cli arg

* autofetch tokenizer and tinyllama model. add 3 chat tokens to the tokenizer

* fixed issue with long sentences by chunking them (see the sketch below)

* support for multiline llama output

* prettified log output

* adjusted sentence length

* remove quote from response to avoid funny tts

* fixed prompts

* added missing parameter
Oleg Rybalko 2023-12-03 02:03:46 +03:00 committed by GitHub
parent 47cec4caf3
commit 5e87083783
7 changed files with 455 additions and 28 deletions
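A condensed sketch of the sentence streaming the bullets above describe: the llama reply is split into sentences, each sentence is chunked to at most --max_sentence_length words, and the synthesized audio for each sentence is pushed to the speaker process queue. This assumes nltk's punkt data is downloaded; synthesize here is a stand-in for the full tts(...) call defined in the first file below.

import numpy as np
import nltk

def chunks(lst, n):
  for i in range(0, len(lst), n): yield lst[i:i + n]

def speak(response: str, out_q, synthesize, max_sentence_length=20):
  # Split the reply into sentences, synthesize each one in word-sized chunks, and queue the audio bytes.
  for sentence in nltk.sent_tokenize(response.replace('"', "")):
    total = np.array([], dtype=np.int16)
    for words in chunks(sentence.split(), max_sentence_length):
      total = np.concatenate([total, synthesize(" ".join(words))])  # synthesize stands in for tts(...) below
    out_q.put(total.tobytes())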


@@ -0,0 +1,344 @@
import argparse
import multiprocessing as mp
import os
import re
import sys
import time
from contextlib import contextmanager
from pathlib import Path
import numpy as np
import pyaudio
import yaml
from llama import LLaMa
from vits import MODELS as VITS_MODELS
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
from whisper import init_whisper, transcribe_waveform
from sentencepiece import SentencePieceProcessor
from tinygrad.helpers import Timing, dtypes, fetch
from tinygrad.tensor import Tensor
# Whisper constants
RATE = 16000
CHUNK = 1600
# LLaMa constants
IM_START = 32001
IM_END = 32002
# Functions for encoding prompts to chatml md
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
def chunks(lst, n):
  for i in range(0, len(lst), n): yield lst[i:i + n]
def create_fixed_tokenizer():
  """Function needed for extending tokenizer with additional chat tokens"""
  import extra.junk.sentencepiece_model_pb2 as spb2
  tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
  if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
    print("creating fixed tokenizer")
    mp = spb2.ModelProto()
    mp.ParseFromString(tokenizer_path.read_bytes())
    # https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
    mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
    tokenizer_path.write_bytes(mp.SerializeToString())
  return tokenizer_path
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, int, str]:
  """Prepares a llama model from a specified pre-prompt file"""
  with open(str(pre_prompt_path)) as f:
    config = yaml.safe_load(f.read())
  toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
  for i in config["examples"]:
    toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
    toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
  llama.model(Tensor([toks]), 0, temperature).realize()  # NOTE: outputs are not used
  return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
def llama_generate(
  llama: LLaMa,
  toks: list[int],
  outputted: str,
  prompt: str,
  start_pos: int,
  user_delim: str,
  resp_delim: str,
  temperature=0.7,
  max_tokens=1000
):
  """Generates an output for the specified prompt"""
  toks += encode_prompt(llama.tokenizer, user_delim, prompt)
  toks += start_prompt(llama.tokenizer, resp_delim)
  outputted = llama.tokenizer.decode(toks)
  init_length = len(outputted)
  for _ in range(max_tokens):
    probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy()
    token = int(np.random.choice(len(probs_np), p=probs_np))
    start_pos = len(toks)
    toks.append(token)
    cur = llama.tokenizer.decode(toks)
    # Print is just for debugging
    sys.stdout.write(cur[len(outputted):])
    sys.stdout.flush()
    outputted = cur
    if toks[-1] == IM_END: break
  else:
    toks.append(IM_END)
  print() # because the output is flushed
  return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
def tts(
  text_to_synthesize: str,
  synth: Synthesizer,
  hps: HParams,
  emotion_embedding: Path,
  speaker_id: int,
  model_to_use: str,
  noise_scale: float,
  noise_scale_w: float,
  length_scale: float,
  estimate_max_y_length: bool,
  text_mapper: TextMapper,
  model_has_multiple_speakers: bool,
  batch_size=600,
  vits_batch_size=1000
):
  if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
  # Convert the input text to a tensor.
  stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
  init_shape = stn_tst.shape
  assert init_shape[0] < batch_size, "text is too long"
  x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
  sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
  # Perform inference.
  audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
                             max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
  # Save the audio output.
  audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
  return audio_data
def init_vits(
  model_to_use: str,
  emotion_path: Path,
  speaker_id: int,
  seed: int,
):
  model_config = VITS_MODELS[model_to_use]
  # Load the hyperparameters from the config file.
  hps = get_hparams_from_file(fetch(model_config[0]))
  # If model has multiple speakers, validate speaker id and retrieve name if available.
  model_has_multiple_speakers = hps.data.n_speakers > 0
  if model_has_multiple_speakers:
    if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
    if hps.__contains__("speakers"): # maps speaker ids to names
      speakers = hps.speakers
      if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
  # Load emotions if any. TODO: find an english model with emotions, this is untested atm.
  emotion_embedding = None
  if emotion_path is not None:
    if str(emotion_path).endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
    else: raise ValueError("Emotion path must be a .npy file.")
  # Load symbols, instantiate TextMapper and clean the text.
  if hps.__contains__("symbols"): symbols = hps.symbols
  elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
  else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'")
  text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
  # Load the model.
  Tensor.no_grad = True
  if seed is not None:
    Tensor.manual_seed(seed)
    np.random.seed(seed)
  net_g = load_model(text_mapper.symbols, hps, model_config)
  return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
@contextmanager
def output_stream(num_channels: int, sample_rate: int):
  try:
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
    yield stream
  except KeyboardInterrupt: pass
  finally:
    stream.stop_stream()
    stream.close()
    p.terminate()
@contextmanager
def log_writer():
  try:
    logs = []
    yield logs
  finally:
    sep = "="*os.get_terminal_size()[1]
    print(f"{sep[:-1]}\nCHAT LOG")
    print(*logs, sep="\n")
    print(sep)
def listener(q: mp.Queue, event: mp.Event):
  try:
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
    did_print = False
    while True:
      data = stream.read(CHUNK) # read data to avoid overflow
      if event.is_set():
        if not did_print:
          print("listening")
          did_print = True
        q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3))
      else:
        did_print = False
  finally:
    stream.stop_stream()
    stream.close()
    p.terminate()
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
  with output_stream(num_channels, sample_rate) as stream:
    while True:
      try:
        stream.write(q.get())
        counter.value += 1
      except KeyboardInterrupt:
        break
if __name__ == "__main__":
  import nltk
  nltk.download("punkt")
  Tensor.no_grad = True

  # Parse CLI arguments
  parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
  # Whisper args
  parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
  # LLAMA args
  parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed.")
  parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
  parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
  parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory")
  parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
  parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
  parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
  parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
  # vits args
  parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
  parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 12.")
  parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
  parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
  parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
  parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is None.")
  parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
  parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
  parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
  parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
  parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
  # conversation args
  parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
  args = parser.parse_args()

  # Init models
  model, enc = init_whisper(args.whisper_model_name)
  synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)

  # Download tinyllama chat as a default model
  if args.llama_model is None:
    args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
    args.llama_gen = "tiny"
    args.llama_size = "1B-Chat"
  # Add 3 more tokens to the tokenizer
  if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
  tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
  llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
  toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)

  # Start child process for mic input
  q = mp.Queue()
  is_listening_event = mp.Event()
  p = mp.Process(target=listener, args=(q, is_listening_event,))
  p.daemon = True
  p.start()

  # Start child process for speaker output
  out_q = mp.Queue()
  out_counter = mp.Value("i", 0)
  out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
  out_p.daemon = True
  out_p.start()

  # JIT tts
  for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
    tts(
      i, synth, hps, emotion_embedding,
      args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
      args.vits_noise_scale_w, args.vits_length_scale,
      args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
    )

  # Start the pipeline
  with log_writer() as log:
    while True:
      tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
      total = np.array([])
      out_counter.value = 0

      s = time.perf_counter()
      is_listening_event.set()
      prev_text = None
      while True:
        for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
        txt = transcribe_waveform(model, enc, [total], truncate=True)
        print(txt, end="\r")
        if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
        if prev_text is not None and prev_text == txt:
          is_listening_event.clear()
          break
        prev_text = txt
      print() # to avoid llama printing on the same line
      log.append(f"{user_delim.capitalize()}: {txt}")

      # Generate with llama
      with Timing("llama generation: "):
        outputted, start_pos, response = llama_generate(
          llama, toks, outputted, txt, start_pos,
          user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
          max_tokens=args.llama_count
        )
        log.append(f"{resp_delim.capitalize()}: {response}")

      # Convert to voice
      with Timing("tts: "):
        sentences = nltk.sent_tokenize(response.replace('"', ""))
        for i in sentences:
          total = np.array([], dtype=np.int16)
          for j in chunks(i.split(), args.max_sentence_length):
            audio_data = tts(
              " ".join(j), synth, hps, emotion_embedding,
              args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
              args.vits_noise_scale_w, args.vits_length_scale,
              args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
            )
            total = np.concatenate([total, audio_data])
          out_q.put(total.tobytes())
        while out_counter.value < len(sentences): continue
      log.append(f"Total: {time.perf_counter() - s}")


@@ -0,0 +1,13 @@
pre_prompt: |
  You are Gary!
  You have sold used cars your entire life. You are defensive about this fact, because your family belittles you.
  You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas.
  You are verbose, honest, and accurate when you answer questions, except when you slip in car ads.
  After you are done speaking, output [EOS]. You are not the User.
examples:
- resp_prompt: I am Gary. I used to sell cars.
  user_prompt: What is your name?
- resp_prompt: I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla
  user_prompt: What is 2+3?
user_delim: "user"
resp_delim: "gary"
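For reference, a minimal sketch of how a pre-prompt file like the one above is turned into llama tokens, mirroring encode_prompt and llama_prepare from the first file; spp is assumed to be the SentencePieceProcessor built by create_fixed_tokenizer.

import yaml

IM_START, IM_END = 32001, 32002
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")

def preprompt_to_tokens(spp, path):
  # system pre_prompt first, then each example as a user/resp pair, exactly like llama_prepare above
  config = yaml.safe_load(open(path).read())
  toks = [spp.bos_id()] + encode_prompt(spp, "system", config["pre_prompt"].replace("\n", " "))
  for ex in config["examples"]:
    toks += encode_prompt(spp, config["user_delim"], ex["user_prompt"])
    toks += encode_prompt(spp, config["resp_delim"], ex["resp_prompt"])
  return toks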


@@ -0,0 +1,20 @@
pre_prompt: |
  You are an AI version of George Hotz. You act as much as you can like George.
  You are one of the greatest computer experts in the world.
  You have singlehandedly won programming and hacking competitions.
  You are trying your best to help the User.
  You are verbose, honest, and accurate when you answer questions.
  After you are done speaking, output [EOS]. You are not the User.
examples:
- resp_prompt: I'm an AI version of George Hotz.
  user_prompt: What is your name?
- resp_prompt: O(n^3), though it can be faster with things like Strassen's algorithm
  user_prompt: What's the complexity of matrix multiplication?
- resp_prompt: I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer
  user_prompt: What's a buffer overflow?
- resp_prompt: I am based off LLaMA trained by Facebook. I'm the 7B weight version
  user_prompt: How many weights do you have?
- resp_prompt: It is when the memory is about to overflow and unused memory is freed and stored on disk
  user_prompt: What is swap memory?
user_delim: "user"
resp_delim: "george"


@@ -0,0 +1,16 @@
pre_prompt: |
  You are Lexie!
  You grew up in Ohio, but moved out to LA after college to try to become an actress.
  Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot.
  You met Chad on Tinder, and made plans to meet him at a local dive bar.
  You are at the bar with Chad. You are on a date. What follows is a transcript of the conversation.
  After you are done speaking, output [EOS]. You are not Chad.
examples:
- resp_prompt: hi chad, glad we finally met up!
  user_prompt: hi lexie
- resp_prompt: thanks! are you subscribed to my onlyfans?
  user_prompt: you look better than your pictures
- resp_prompt: i moved out here about a year ago. i want to be an actress
  user_prompt: i am. so how'd you end up in LA?
user_delim: "chad"
resp_delim: "lexie"


@@ -0,0 +1,15 @@
pre_prompt: |
  You are Stacy!
  You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy.
  You love to answer questions and you are very good at it. Sometimes you answer in rap form.
  You are verbose, honest, and accurate when you answer questions, but sometimes your mental illness manifests.
  You are not the User.
examples:
- resp_prompt: Hi! My name is Stacy. I'm a rapper with bipolar disorder.
  user_prompt: What is your name
- resp_prompt: The French Revolution started in 1789, and lasted 10 years until 1799.
  user_prompt: french revolution was what year?
- resp_prompt: The sun is bigger than the moon, except when Mercury is in retrograde
  user_prompt: What is bigger, the moon or the sun?
user_delim: "user"
resp_delim: "stacy"


@@ -96,6 +96,10 @@ MODEL_PARAMS = {
    "1B": {
      "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
      "files": 1,
    },
    "1B-Chat": {
      "args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
      "files": 1,
    }
  }
}
@@ -148,7 +152,6 @@ class LLaMa:
  @staticmethod
  def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
    params = MODEL_PARAMS[model_gen][model_size]
    sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
    assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
@@ -170,7 +173,7 @@ class LLaMa:
  def __init__(self, model, tokenizer):
    self.model = model
    self.tokenizer = tokenizer
    self.tokenizer: SentencePieceProcessor = tokenizer
  def greedy_until(self, prompt:str, until, max_length, temperature):
    toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
@@ -425,4 +428,4 @@ After you are done speaking, output [EOS]. You are not Chad.
    # stop after you have your answer
    if chatbot and outputted.endswith(end_delim): break
    if not chatbot: break
    if not chatbot: break
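The new "1B-Chat" entry expects vocab_size 32003 because create_fixed_tokenizer (first file above) appends [PAD], <|im_start|> and <|im_end|> to the stock 32000-token TinyLlama tokenizer. A small sketch of the check that LLaMa.build performs, assuming sentencepiece is installed:

from sentencepiece import SentencePieceProcessor

def check_chat_tokenizer(tokenizer_path, expected_vocab_size=32003):
  # The tokenizer must match the vocab_size declared in MODEL_PARAMS, otherwise the embedding shapes won't line up.
  sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
  assert sp_model.vocab_size() == expected_vocab_size, f"{sp_model.vocab_size()=} not equal to {expected_vocab_size}"
  return sp_model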


@@ -1,4 +1,7 @@
import json, logging, math, re, sys, time, wave, argparse, numpy as np
from phonemizer.phonemize import default_separator, _phonemize
from phonemizer.backend import EspeakBackend
from phonemizer.punctuation import Punctuation
from functools import reduce
from pathlib import Path
from typing import List
@@ -6,6 +9,7 @@ from tinygrad import nn
from tinygrad.helpers import dtypes, fetch
from tinygrad.nn.state import torch_load
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from unidecode import unidecode
LRELU_SLOPE = 0.1
@@ -19,14 +23,14 @@ class Synthesizer:
    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
    self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) if use_sdp else DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
    if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels)
  def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None):
    x, m_p, logs_p, x_mask = self.enc_p.forward(x, x_lengths, emotion_embedding)
  def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, batch_size=500):
    x, m_p, logs_p, x_mask = self.enc_p.forward(x.realize(), x_lengths.realize(), emotion_embedding.realize() if emotion_embedding is not None else emotion_embedding)
    g = self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) if self.n_speakers > 0 else None
    logw = self.dp.forward(x, x_mask, g=g, reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0)
    logw = self.dp.forward(x, x_mask.realize(), g=g.realize(), reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0)
    w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale)
    y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64)
    return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths)
  def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths):
    return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size)
  def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size):
    max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
    y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype)
    attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1)
@@ -34,9 +38,16 @@ class Synthesizer:
    m_p_2 = attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
    logs_p_2 = attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
    z_p = m_p_2 + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) * logs_p_2.exp() * noise_scale
    y_mask = y_mask.cast(z_p.dtype)
    z = self.flow.forward(z_p, y_mask, g=g, reverse=True)
    o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)
    # Pad flow forward inputs to enable JIT
    row_len = y_mask.shape[2]
    assert batch_size > row_len, "batch size is too small"
    y_mask = y_mask.pad(((0, 0), (0, 0), (0, batch_size - row_len)), 0).cast(z_p.dtype)
    # New y_mask tensor to remove sts mask
    y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad)
    z_p = z_p.squeeze(0).pad(((0, 0), (0, batch_size - z_p.shape[2])), 1).unsqueeze(0)
    z = self.flow.forward(z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True)
    result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len)
    o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length]
    if max_y_length_estimate_scale is not None:
      length_scaler = o.shape[-1] / max_y_length
      o.realize()
@@ -68,6 +79,7 @@ class StochasticDurationPredictor:
    self.pre, self.proj = nn.Conv1d(in_channels, filter_channels, 1), nn.Conv1d(filter_channels, filter_channels, 1)
    self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
    if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
  @TinyJit
  def forward(self, x: Tensor, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
    x = self.pre(x.detach())
    if g is not None: x = x + self.cond(g.detach())
@@ -96,13 +108,13 @@ class StochasticDurationPredictor:
        z, log_det = flow.forward(z, x_mask, g=x, reverse=reverse)
        log_det_tot = log_det_tot + log_det
      nll = Tensor.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - log_det_tot
      return nll + log_q # [b]
      return (nll + log_q).realize() # [b]
    flows = list(reversed(self.flows))
    flows = flows[:-2] + [flows[-1]] # remove a useless vflow
    z = Tensor.randn(x.shape[0], 2, x.shape[2], dtype=x.dtype).to(device=x.device) * noise_scale
    for flow in flows: z = flow.forward(z, x_mask, g=x, reverse=reverse)
    z0, z1 = split(z, [1, 1], 1)
    return z0
    return z0.realize()
class DurationPredictor:
  def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
@@ -127,6 +139,7 @@ class TextEncoder:
    if emotion_embedding: self.emo_proj = nn.Linear(1024, hidden_channels)
    self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
  @TinyJit
  def forward(self, x: Tensor, x_lengths: Tensor, emotion_embedding=None):
    if self.n_vocab!=0: x = (self.emb(x) * math.sqrt(self.hidden_channels))
    if emotion_embedding: x = x + self.emo_proj(emotion_embedding).unsqueeze(1)
@@ -134,7 +147,7 @@ class TextEncoder:
    x_mask = sequence_mask(x_lengths, x.shape[2]).unsqueeze(1).cast(x.dtype)
    x = self.encoder.forward(x * x_mask, x_mask)
    m, logs = split(self.proj(x) * x_mask, self.out_channels, dim=1)
    return x, m, logs, x_mask
    return x.realize(), m.realize(), logs.realize(), x_mask.realize()
class ResidualCouplingBlock:
  def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
@@ -143,9 +156,10 @@ class ResidualCouplingBlock:
    for _ in range(n_flows):
      self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
      self.flows.append(Flip())
  @TinyJit
  def forward(self, x, x_mask, g=None, reverse=False):
    for flow in reversed(self.flows) if reverse else self.flows: x = flow.forward(x, x_mask, g=g, reverse=reverse)
    return x
    return x.realize()
class PosteriorEncoder:
  def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
@@ -166,22 +180,23 @@ class Generator:
    resblock = ResBlock1 if resblock == '1' else ResBlock2
    self.ups = [nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), k, u, padding=(k-u)//2) for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes))]
    self.resblocks = []
    self.upsample_rates = upsample_rates
    for i in range(len(self.ups)):
      ch = upsample_initial_channel // (2 ** (i + 1))
      for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
        self.resblocks.append(resblock(ch, k, d))
    self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
    if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
  @TinyJit
  def forward(self, x: Tensor, g=None):
    x = self.conv_pre(x)
    if g is not None: x = x + self.cond(g)
    for i in range(self.num_upsamples):
      x, xs = self.ups[i](x.leakyrelu(LRELU_SLOPE)), None
      for j in range(self.num_kernels):
        if xs is None: xs = self.resblocks[i * self.num_kernels + j].forward(x)
        else: xs += self.resblocks[i * self.num_kernels + j].forward(x)
      x = xs / self.num_kernels
    return self.conv_post(x.leakyrelu()).tanh()
      x = self.ups[i](x.leakyrelu(LRELU_SLOPE))
      xs = sum(self.resblocks[i * self.num_kernels + j].forward(x) for j in range(self.num_kernels))
      x = (xs / self.num_kernels).realize()
    res = self.conv_post(x.leakyrelu()).tanh().realize()
    return res
class LayerNorm(nn.LayerNorm):
  def __init__(self, channels, eps=1e-5): super().__init__(channels, eps, elementwise_affine=True)
@@ -445,7 +460,7 @@ def rational_quadratic_spline(inputs: Tensor, unnormalized_widths: Tensor, unnor
  derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + 2 * input_delta * theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
  return input_cum_heights + numerator / denominator, derivative_numerator.log() - 2 * denominator.log()
def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0).__lt__(length.unsqueeze(1))
def sequence_mask(length: Tensor, max_length): return Tensor.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
def generate_path(duration: Tensor, mask: Tensor): # duration: [b, 1, t_x], mask: [b, 1, t_y, t_x]
  b, _, t_y, t_x = mask.shape
  path = sequence_mask(duration.cumsum(axis=2).reshape(b * t_x), t_y).cast(mask.dtype).reshape(b, t_x, t_y)
@@ -558,6 +573,9 @@ class TextMapper: # Based on https://github.com/keithito/tacotron
    self.apply_cleaners, self.symbols, self._inflect = apply_cleaners, symbols, None
    self._symbol_to_id, _id_to_symbol = {s: i for i, s in enumerate(symbols)}, {i: s for i, s in enumerate(symbols)}
    self._whitespace_re, self._abbreviations = re.compile(r'\s+'), [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [('mrs', 'misess'), ('mr', 'mister'), ('dr', 'doctor'), ('st', 'saint'), ('co', 'company'), ('jr', 'junior'), ('maj', 'major'), ('gen', 'general'), ('drs', 'doctors'), ('rev', 'reverend'), ('lt', 'lieutenant'), ('hon', 'honorable'), ('sgt', 'sergeant'), ('capt', 'captain'), ('esq', 'esquire'), ('ltd', 'limited'), ('col', 'colonel'), ('ft', 'fort'), ]]
    self.phonemizer = EspeakBackend(
      language="en-us", punctuation_marks=Punctuation.default_marks(), preserve_punctuation=True, with_stress=True,
    )
  def text_to_sequence(self, text, cleaner_names):
    if self.apply_cleaners:
      for name in cleaner_names:
@@ -566,18 +584,16 @@ class TextMapper: # Based on https://github.com/keithito/tacotron
        text = cleaner(text)
    else: text = text.strip()
    return [self._symbol_to_id[symbol] for symbol in text]
  def get_text(self, text, add_blank=False, cleaners=('english_cleaners',)):
  def get_text(self, text, add_blank=False, cleaners=('english_cleaners2',)):
    text_norm = self.text_to_sequence(text, cleaners)
    return Tensor(self.intersperse(text_norm, 0) if add_blank else text_norm, dtype=dtypes.int64)
  def intersperse(self, lst, item):
    (result := [item] * (len(lst) * 2 + 1))[1::2] = lst
    return result
  def phonemize(self, text, strip=True): return _phonemize(self.phonemizer, text, default_separator, strip, 1, False, False)
  def filter_oov(self, text): return "".join(list(filter(lambda x: x in self._symbol_to_id, text)))
  def base_english_cleaners(self, text, preserve_punctuation=False, with_stress=False):
    from phonemizer import phonemize
    return self.collapse_whitespace(phonemize(self.expand_abbreviations(unidecode(text.lower())), language='en-us', backend='espeak', strip=True, preserve_punctuation=preserve_punctuation, with_stress=with_stress))
  def english_cleaners(self, text): return self.base_english_cleaners(text)
  def english_cleaners2(self, text): return self.base_english_cleaners(text, preserve_punctuation=True, with_stress=True)
  def base_english_cleaners(self, text): return self.collapse_whitespace(self.phonemize(self.expand_abbreviations(unidecode(text.lower()))))
  def english_cleaners2(self, text): return self.base_english_cleaners(text)
  def transliteration_cleaners(self, text): return self.collapse_whitespace(unidecode(text.lower()))
  def cjke_cleaners(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text).replace('ɑ', 'a').replace('ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u')))
  def cjke_cleaners2(self, text): return re.sub(r'([^\.,!\?\-…~])$', r'\1.', re.sub(r'\s+$', '', self.english_to_ipa2(text)))
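The @TinyJit decorators and the padding in Synthesizer.generate follow one pattern: the JIT captures kernels for fixed shapes, so variable-length inputs are padded to a constant size before the jitted call and the output is trimmed afterwards. A minimal sketch of that idea, assuming the tinygrad API used in this diff (Tensor.pad, realize, tinygrad.jit.TinyJit); shapes and the jitted computation are illustrative only.

from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor

@TinyJit
def jitted_step(x: Tensor) -> Tensor:
  return (x * 2).realize()  # jitted functions return realized tensors

def run(x: Tensor, batch_size=1000):
  row_len = x.shape[-1]
  assert batch_size > row_len, "batch size is too small"
  padded = x.pad(((0, 0), (0, batch_size - row_len)), 0)  # pad to a fixed shape so the JIT always sees the same shapes
  return jitted_step(padded)[:, :row_len]  # trim the padding from the result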