
Cleanup & fix llama.py (#2524)

* docs, cleanup crap

* comma AI

* fix 70B

* this is why lexical scope exists
pull/2527/head^2
Davi Silva 2023-12-01 04:00:17 +07:00 committed by GitHub
parent 7d26452305
commit ddeec24fa8
2 changed files with 32 additions and 44 deletions


@@ -13,105 +13,93 @@ from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from extra.models.llama import Transformer, convert_from_huggingface
+from sentencepiece import SentencePieceProcessor
MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
# **** files and arguments ****
+# calculating params:
+# traditionally, the MLP in the transformer architecture has hidden_dim = dim*4 [arxiv/1706.03762, 3.3]
+# however, Llama uses SwiGLU. in order to preserve param count to original transformer arch, hidden_dim must be = 2/3 * (dim*4) [arxiv/2002.05202]
+# for models using MQA (n_kv_heads != n_heads), preserving param count means hidden dim must be further multiplied by 1.3 [arxiv/2307.09288, A.2.1]
MODEL_PARAMS = {
"1": {
"7B": {
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000},
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000},
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
"files": 2,
},
"30B": {
"args": {"dim": 6656, "multiple_of": 256, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000},
"args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
"files": 4,
},
"65B": {
"args": {"dim": 8192, "multiple_of": 256, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
"args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
"files": 8,
},
},
"2": {
"7B": {
"args": {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000},
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "multiple_of": 256, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000},
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
"files": 2,
},
"70B": {
"args": {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000},
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
"files": 8,
},
},
"code": {
"7B": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
"files": 1,
},
"7B-Python": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
"files": 1,
},
"7B-Instruct": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
"files": 2,
},
"13B-Python": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
"files": 2,
},
"13B-Instruct": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32016},
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
"files": 2,
},
"34B": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"files": 4,
},
"34B-Python": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"files": 4,
},
"34B-Instruct": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "multiple_of": 256, "ffn_dim_multiplier": 1.0, "norm_eps": 1e-5, "rope_theta": 1000000, "vocab_size": 32000},
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"files": 4,
},
},
"tiny": {
"1B": {
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "multiple_of": 256, "norm_eps": 1e-05, "vocab_size": 32000},
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
"files": 1,
}
}
}
-# fix up MODEL_PARAMS to have hidden_dim
-for model_gen in MODEL_PARAMS.values():
-  for model_type in model_gen.values():
-    model_args = model_type['args']
-    hidden_dim = model_args['dim'] * 4
-    multiple_of = model_args['multiple_of']
-    # TODO: what is this?
-    hidden_dim = int(2 * hidden_dim / 3)
-    # custom dim factor multiplier
-    ffn_dim_multiplier = getattr(model_args, 'ffn_dim_multiplier', None)
-    if ffn_dim_multiplier is not None:
-      hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-      del model_args['ffn_dim_multiplier']
-    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-    model_args['hidden_dim'] = hidden_dim
-    del model_args['multiple_of']
# **** helper functions ****
def concat_weights(models):
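Aside (not part of the diff): the hard-coded hidden_dim values added above follow the rule described in the new comments, 2/3 of dim*4, times ffn_dim_multiplier where one applied, rounded up to the old multiple_of, which is what the deleted fix-up loop computed at import time. Note the deleted loop called getattr on a plain dict, which always returns the default, so the 1.3 multiplier appears never to have been applied; precomputing the values sidesteps that, which is likely the "fix 70B" in the commit message. A standalone sketch, with illustrative names not taken from the diff:

def round_up(n, k): return k * ((n + k - 1) // k)

def llama_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
  hidden = int(2 * (dim * 4) / 3)                    # SwiGLU MLP sized to match a dim*4 FFN [arxiv/2002.05202]
  if ffn_dim_multiplier is not None:
    hidden = int(ffn_dim_multiplier * hidden)        # e.g. 1.3 for the GQA 70B model [arxiv/2307.09288, A.2.1]
  return round_up(hidden, multiple_of)

assert llama_hidden_dim(4096, 256) == 11008          # 7B
assert llama_hidden_dim(8192, 4096, 1.3) == 28672    # 70B
assert llama_hidden_dim(2048, 256) == 5632           # tiny 1B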
@@ -159,11 +147,11 @@ class AbsmaxQuantizedLinear:
class LLaMa:
@staticmethod
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
-from sentencepiece import SentencePieceProcessor
-sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
-assert sp_model.vocab_size() == MODEL_PARAMS[model_gen][model_size]["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {MODEL_PARAMS[model_gen][model_size]['args']['vocab_size']}"
params = MODEL_PARAMS[model_gen][model_size]
+sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
+assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
if model_path.is_dir():
@@ -171,7 +159,7 @@ class LLaMa:
else:
weights = load(str(model_path))
if "model.embed_tokens.weight" in weights:
-weights = convert_from_huggingface(weights, model, model_args["n_heads"], model_args.get("n_kv_heads", model_args["n_heads"]))
+weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
if quantize:
weights = AbsmaxQuantizedLinear.quantize(weights)
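Aside on the "fix 70B" / "this is why lexical scope exists" notes in the commit message: the removed call above read model_args["n_heads"], but build() never defines model_args; the name resolved to the leftover loop variable from the deleted module-level fix-up loop, so the HuggingFace conversion silently used whichever entry that loop iterated last instead of the selected model. A minimal standalone sketch of that failure mode, with made-up names:

CONFIGS = {"70B": {"n_heads": 64}, "tiny": {"n_heads": 32}}

for name in CONFIGS:
  args = CONFIGS[name]     # module-level loop: 'args' keeps the value from the last iteration

def build(name):
  params = CONFIGS[name]
  return args["n_heads"]   # bug: reads the leaked module-level 'args' instead of params

print(build("70B"))        # prints 32, not 64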
@@ -437,4 +425,4 @@ After you are done speaking, output [EOS]. You are not Chad.
# stop after you have your answer
if chatbot and outputted.endswith(end_delim): break
-if not chatbot: break
+if not chatbot: break


@@ -41,7 +41,7 @@ class RMSNorm:
class Attention:
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
self.n_heads = n_heads
-self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
+self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.max_context = max_context
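The comment added above flags grouped/multi-query attention: when n_kv_heads < n_heads, each key/value head is shared by n_rep = n_heads // n_kv_heads query heads (64 // 8 = 8 for the 70B config), so the cached K/V tensors are effectively repeated along the head axis before attention. A rough numpy sketch with invented names, not code from this repo:

import numpy as np

def repeat_kv(x, n_rep):                     # x: (batch, seq_len, n_kv_heads, head_dim)
  return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)

kv = np.zeros((1, 16, 8, 128))               # 8 kv heads
assert repeat_kv(kv, 64 // 8).shape == (1, 16, 64, 128)  # expanded to match 64 query heads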
@@ -80,10 +80,10 @@ class FeedForward:
def __init__(self, dim, hidden_dim, linear=nn.Linear):
self.w1 = linear(dim, hidden_dim, bias=False)
self.w2 = linear(hidden_dim, dim, bias=False)
-self.w3 = linear(dim, hidden_dim, bias=False)
+self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
def __call__(self, x:Tensor) -> Tensor:
-return self.w2(self.w1(x).silu() * self.w3(x))
+return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
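The SwiGLU reference added above is eq. (5) of arxiv/2002.05202: the w1 branch is passed through SiLU (swish) and multiplied elementwise with the w3 branch before w2 projects back down to dim. A tiny self-contained sketch in plain numpy, with shapes and names of my own choosing:

import numpy as np

def silu(x): return x / (1.0 + np.exp(-x))   # equals x * sigmoid(x)

def swiglu_ffn(x, w1, w2, w3):               # w1, w3: (dim, hidden); w2: (hidden, dim)
  return (silu(x @ w1) * (x @ w3)) @ w2

dim, hidden = 8, 16
x = np.random.randn(4, dim)
out = swiglu_ffn(x, np.random.randn(dim, hidden), np.random.randn(hidden, dim), np.random.randn(dim, hidden))
assert out.shape == (4, dim)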
@@ -148,4 +148,4 @@ def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_he
elif "k_proj" in k:
v = permute(v, n_kv_heads)
sd[keymap[k]] = v
-return sd
+return sd