
Reformat, uh, everything, with black

Jeff Moe (deepcrayon) 2023-12-04 22:01:04 -07:00
parent 01503ca90d
commit 661dcc5ed0
236 changed files with 48096 additions and 26819 deletions
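For reference, a tree-wide Black pass like this commit boils down to running the formatter over every Python file; the exact command and configuration used here are not recorded in the diff, so the invocation and the default 88-character line length below are assumptions. Black can be run from the command line (python -m black .) or programmatically, and the one-liner splitting seen throughout this diff follows from its rule that a compound statement never shares a line with its header:

# Minimal sketch of the reformatting applied in this commit (assumed defaults, not the recorded command).
import black

# One of the one-liners from the abstractions walkthrough, as it looked before this commit.
src = "def log(self): return mlops.Log.apply(self)\n"

# black.Mode() uses the default 88-character line length; the repo may configure this differently.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)
# def log(self):
#     return mlops.Log.apply(self)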

View File

@ -4,13 +4,17 @@ import pathlib
from hexdump import hexdump
fxn = None
def disasm(buf):
global fxn
if fxn is None:
shared = pathlib.Path(__file__).parent / "disasm.so"
if not shared.is_file():
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
fxn = ctypes.CDLL(shared.as_posix())['disasm']
os.system(
f"cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so"
)
fxn = ctypes.CDLL(shared.as_posix())["disasm"]
# hexdump(buf)
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
buf = buf[0x510:] # this right?

View File

@ -23,21 +23,24 @@ from abc import ABC
# we will be using the clang backend
from tinygrad import Device
Device.DEFAULT = "CLANG"
# first, 2+3 as a Tensor, the highest level
from tinygrad.tensor import Tensor
a = Tensor([2])
b = Tensor([3])
result = a + b
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
assert result.numpy()[0] == 5.
assert result.numpy()[0] == 5.0
# %%
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
# it's worth reading tinygrad/tensor.py. it's pretty beautiful
import tinygrad.mlops as mlops
# this is the good old familiar Tensor class
class Tensor:
# these two are pretty straightforward
@ -51,10 +54,13 @@ class Tensor:
lazydata: LazyBuffer
# high level ops (hlops) are defined on this class. example: relu
def relu(self): return self.maximum(0)
def relu(self):
return self.maximum(0)
# log is an mlop, this is the wrapper function in Tensor
def log(self): return mlops.Log.apply(self)
def log(self):
return mlops.Log.apply(self)
# all the definitions of the derivatives are subclasses of Function (like mlops.Log)
# there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10)
@ -62,13 +68,18 @@ class Tensor:
# you can differentiate the world using the chain rule
class Function:
# example types of forward and backward
def forward(self, x:LazyBuffer) -> LazyBuffer: pass
def backward(self, x:LazyBuffer) -> LazyBuffer: pass
def forward(self, x: LazyBuffer) -> LazyBuffer:
pass
def backward(self, x: LazyBuffer) -> LazyBuffer:
pass
# %%
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
from tinygrad.helpers import DType
# this is where the properties live that you thought were a part of Tensor
# LazyBuffer is like a Tensor without derivatives, at the mlop layer
class LazyBuffer:
@ -91,6 +102,7 @@ class LazyBuffer:
# this LazyOp describes the computation needed to realize this LazyBuffer
op: Optional[LazyOp]
# LazyOp (in tinygrad/ops.py, code 4/10)
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
class LazyOp:
@ -98,13 +110,52 @@ class LazyOp:
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
arg: Optional[Any] = None # and an optional static argument
# there's currently 26 Ops you have to implement for an accelerator.
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
class ReduceOps(Enum): SUM = auto(); MAX = auto()
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
class UnaryOps(Enum):
EXP2 = auto()
LOG2 = auto()
CAST = auto()
SIN = auto()
SQRT = auto()
class BinaryOps(Enum):
ADD = auto()
SUB = auto()
MUL = auto()
DIV = auto()
CMPLT = auto()
MAX = auto()
class ReduceOps(Enum):
SUM = auto()
MAX = auto()
class MovementOps(Enum):
RESHAPE = auto()
PERMUTE = auto()
EXPAND = auto()
PAD = auto()
SHRINK = auto()
STRIDE = auto()
class TernaryOps(Enum):
MULACC = auto()
WHERE = auto()
class LoadOps(Enum):
EMPTY = auto()
CONST = auto()
FROM = auto()
CONTIGUOUS = auto()
CUSTOM = auto()
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
# you do not need to implement the MovementOps
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
@ -135,7 +186,9 @@ assert len(lazyop.src) == 2
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
assert lazyop.src[0].op.op == LoadOps.FROM
assert lazyop.src[0].op.src[0].device == "CPU"
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
assert (
lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2
), "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
# now we realize the LazyBuffer
@ -151,12 +204,15 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU,
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
# Interpreted backends are very simple (example: CPU and TORCH)
class Interpreted:
# and they have a lookup table to functions for the Ops
fxn_for_op: Dict[Op, Callable] = {
UnaryOps.EXP2: lambda x: np.exp2(x),
BinaryOps.ADD: lambda x,y: x+y}
BinaryOps.ADD: lambda x, y: x + y,
}
# Compiled backends take a little more (example: GPU and LLVM)
class Compiled:
@ -166,26 +222,40 @@ class Compiled:
# and a runtime, which runs the generated code
runtime: Type[Runtime]
# Runtime is what actually runs the kernels for a compiled backend
class Runtime(ABC):
# `name` is the name of the function, and `prg` is the code
# the constructor compiles the code
def __init__(self, name:str, prg:str): pass
def __init__(self, name: str, prg: str):
pass
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
def __call__(
self,
*bufs: List[Buffer],
global_size: Optional[List[int]],
local_size: Optional[List[int]],
):
pass
# %%
# == Buffer (in tinygrad/device.py, code 6/10) ==
import numpy as np
# Buffer is where the data is actually held. it's pretty close to just memory
class Buffer(ABC):
# create an empty rawbuffer that holds `size` elements of type `dtype`
# `opaque` is an opaque container class
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass
def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None):
pass
# toCPU converts the RawBuffer to a numpy array with shape (size,)
def toCPU(self) -> np.ndarray: pass
def toCPU(self) -> np.ndarray:
pass
# %%
# == Example: 2+3 in raw clang ==
@ -205,6 +275,7 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
# then we copy the numpy in to RawMallocBuffers
# last, we create an empty output buffer
from tinygrad.helpers import dtypes
input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4)
output = MallocAllocator.alloc(4)
@ -214,7 +285,9 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B"))
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
# compile the program, run it, and 2+3 does indeed equal 5
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
program = ClangProgram(
"add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
)
program(output, input_a, input_b)
numpy_out = np.empty(1, dtype=np.float32)
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
@ -229,7 +302,16 @@ np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
# the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST
# for that, we use the Linearizer, which turns an AST into a list of (linear) UOps
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto();
class UOps(Enum):
LOOP = auto()
DEFINE_LOCAL = auto()
LOAD = auto()
ALU = auto()
CONST = auto()
ENDLOOP = auto()
STORE = auto()
class UOp:
uop: UOps
@ -238,26 +320,34 @@ class UOp:
arg: Any
num: int # UOps are unique
class Linearizer:
# create the kernel with the AST
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
def __init__(self, ast:LazyOp): pass
def linearize(self): pass
def __init__(self, ast: LazyOp):
pass
def linearize(self):
pass
# when linearize is run, it fills in this list
uops: List[UOp]
from tinygrad.tensor import Tensor
result = Tensor(2).realize() + Tensor(3).realize()
# use the real Linearizer to linearize 2+3
from tinygrad.codegen.linearizer import Linearizer
sched = result.lazydata.schedule()
linearizer = Linearizer(sched[-1].ast)
linearizer.linearize()
# print the uops
for uop in linearizer.uops: print(uop)
for uop in linearizer.uops:
print(uop)
# output:
"""
@ -275,11 +365,13 @@ for uop in linearizer.uops: print(uop)
# here, we have an example where we fetch the generated code from the JIT
from tinygrad.tensor import Tensor
result = Tensor(2) + Tensor(3)
# we have a global cache used by the JIT
# from there, we can see the generated clang code
from tinygrad.jit import CacheCollector
CacheCollector.start() # enables the cache
result.realize() # create the program and runs it
cache_saved = CacheCollector.finish() # disable the cache
@ -319,7 +411,9 @@ print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
# we can then reshape it, and the strides change again
# note how the permute stays applied
a = a.reshape((5, 2, 5, 2))
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
print(
a
) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
a = a.reshape((100,))

View File

@ -49,14 +49,21 @@ b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I",
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
# describe the computation
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
ld_1 = LazyOp(
BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))
)
ld_2 = LazyOp(
BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))
)
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
st_0 = LazyOp(
BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))
)
# convert the computation to a "linearized" format (print the format)
lin = Device[DEVICE].get_linearizer(st_0).linearize()
for u in lin.uops: print(u)
for u in lin.uops:
print(u)
# compile a program (and print the source)
fxn = Device[DEVICE].to_program(lin)
@ -79,6 +86,7 @@ from tinygrad.realize import run_schedule
# allocate some values + load in values
# TODO: remove numpy here
import numpy as np
a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE)
b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE)
@ -87,10 +95,12 @@ out = a.e(BinaryOps.ADD, b)
# schedule the computation as a list of kernels
sched = out.schedule()
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
for si in sched:
print(si.ast.op) # NOTE: the first two convert it to CLANG
# DEBUGGING: print the compute ast as a tree
from tinygrad.graph import print_tree
print_tree(sched[-1].ast)
# NOTE: sched[-1].ast is the same as st_0 above

View File

@ -1,11 +1,14 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn, Variable
from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype
from tinygrad.helpers import (
dtypes,
) # TODO: wouldn't need this if argmax returned the right dtype
import gymnasium as gym
from tqdm import trange
import numpy as np # TODO: remove numpy import
class ActorCritic:
def __init__(self, in_features, out_features, hidden_state=32):
self.l1 = nn.Linear(in_features, hidden_state)
@ -20,6 +23,7 @@ class ActorCritic:
x = self.c1(obs).relu()
return act, self.c2(x)
def evaluate(model: ActorCritic, test_env: gym.Env) -> float:
(obs, _), terminated, truncated = test_env.reset(), False, False
total_rew = 0.0
@ -29,21 +33,26 @@ def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
total_rew += float(rew)
return total_rew
# TODO: time should be < 5s on M1 Max
if __name__ == "__main__":
env = gym.make('CartPole-v1')
env = gym.make("CartPole-v1")
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
@TinyJit
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
def train_step(
x: Tensor, selected_action: Tensor, reward: Tensor, old_log_dist: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
with Tensor.train():
log_dist, value = model(x)
# get advantage
advantage = reward.reshape(-1, 1) - value
mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
mask = selected_action.reshape(-1, 1) == Tensor.arange(
log_dist.shape[1]
).reshape(1, -1).expand(selected_action.shape[0], -1)
masked_advantage = mask * advantage.detach()
# PPO
@ -51,7 +60,9 @@ if __name__ == "__main__":
clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
entropy_loss = (
(log_dist.exp() * log_dist).sum(-1).mean()
) # this encourages diversity
critic_loss = advantage.square().mean()
opt.zero_grad()
(action_loss + entropy_loss * 0.0005 + critic_loss).backward()
@ -96,7 +107,11 @@ if __name__ == "__main__":
discounts = np.power(0.99, np.arange(len(rews)))
Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))]
Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
Xn, An, Rn = (
Xn[-MAX_REPLAY_BUFFER:],
An[-MAX_REPLAY_BUFFER:],
Rn[-MAX_REPLAY_BUFFER:],
)
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
# TODO: make this work
@ -105,10 +120,16 @@ if __name__ == "__main__":
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
for i in range(5):
samples = Tensor.randint(BS, high=X.shape[0]).realize() # TODO: remove the need for this
samples = Tensor.randint(
BS, high=X.shape[0]
).realize() # TODO: remove the need for this
# TODO: is this recompiling based on the shape?
action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")
action_loss, entropy_loss, critic_loss = train_step(
X[samples], A[samples], R[samples], old_log_dist[samples]
)
t.set_description(
f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}"
)
test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
test_rew = evaluate(model, gym.make("CartPole-v1", render_mode="human"))
print(f"test reward: {test_rew}")

View File

@ -4,18 +4,29 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from extra.datasets import fetch_mnist
from tqdm import trange
class Model:
def __init__(self):
self.layers: List[Callable[[Tensor], Tensor]] = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu,
nn.BatchNorm2d(32), Tensor.max_pool2d,
nn.Conv2d(32, 64, 3), Tensor.relu,
nn.Conv2d(64, 64, 3), Tensor.relu,
nn.BatchNorm2d(64), Tensor.max_pool2d,
lambda x: x.flatten(1), nn.Linear(576, 10)]
nn.Conv2d(1, 32, 5),
Tensor.relu,
nn.Conv2d(32, 32, 5),
Tensor.relu,
nn.BatchNorm2d(32),
Tensor.max_pool2d,
nn.Conv2d(32, 64, 3),
Tensor.relu,
nn.Conv2d(64, 64, 3),
Tensor.relu,
nn.BatchNorm2d(64),
Tensor.max_pool2d,
lambda x: x.flatten(1),
nn.Linear(576, 10),
]
def __call__(self, x: Tensor) -> Tensor:
return x.sequential(self.layers)
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
@ -29,17 +40,25 @@ if __name__ == "__main__":
with Tensor.train():
opt.zero_grad()
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
loss = (
model(X_train[samples])
.sparse_categorical_crossentropy(Y_train[samples])
.backward()
)
opt.step()
return loss.realize()
@TinyJit
def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
def get_test_acc() -> Tensor:
return ((model(X_test).argmax(axis=1) == Y_test).mean() * 100).realize()
test_acc = float('nan')
test_acc = float("nan")
for i in (t := trange(70)):
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
samples = Tensor.randint(512, high=X_train.shape[0]) # TODO: put this in the JIT when rand is fixed
samples = Tensor.randint(
512, high=X_train.shape[0]
) # TODO: put this in the JIT when rand is fixed
loss = train_step(samples)
if i%10 == 9: test_acc = get_test_acc().item()
if i % 10 == 9:
test_acc = get_test_acc().item()
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

View File

@ -10,9 +10,11 @@ from tinygrad.helpers import GlobalCounters
from tinygrad.helpers import getenv
from tinygrad.jit import CacheCollector
def tensors_allocated():
return sum(isinstance(x, Tensor) for x in gc.get_objects())
NUM = getenv("NUM", 2)
BS = getenv("BS", 8)
CNT = getenv("CNT", 10)
@ -25,9 +27,12 @@ if __name__ == "__main__":
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
parameters = get_parameters(model)
for p in parameters: p.realize()
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
else: optimizer = optim.SGD(parameters, lr=0.001)
for p in parameters:
p.realize()
if ADAM:
optimizer = optim.Adam(parameters, lr=0.001)
else:
optimizer = optim.SGD(parameters, lr=0.001)
Tensor.training = TRAINING
Tensor.no_grad = not BACKWARD
@ -42,7 +47,8 @@ if __name__ == "__main__":
st = time.monotonic()
out = model.forward(x_train)
loss = out.log_softmax().mul(y_train).mean()
if i == 2 and CLCACHE: CacheCollector.start()
if i == 2 and CLCACHE:
CacheCollector.start()
if BACKWARD:
optimizer.zero_grad()
loss.backward()
@ -54,7 +60,8 @@ if __name__ == "__main__":
et = time.monotonic()
else:
st = mt = time.monotonic()
for prg, args in cl_cache: prg(*args)
for prg, args in cl_cache:
prg(*args)
et = time.monotonic()
if i == 2 and CLCACHE:
@ -64,4 +71,6 @@ if __name__ == "__main__":
loss_cpu = loss.detach().numpy()
cl = time.monotonic()
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
print(
f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
)

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import os, sys, traceback
sys.path.append(os.getcwd())
from io import StringIO
@ -9,27 +10,46 @@ from tinygrad.helpers import Timing, colored, getenv, fetch
from extra.models.llama import Transformer, convert_from_huggingface
from sentencepiece import SentencePieceProcessor
def create_fixed_tokenizer(output_file):
print("creating fixed tokenizer")
import extra.junk.sentencepiece_model_pb2 as spb2
mp = spb2.ModelProto()
mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
mp.ParseFromString(
fetch(
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true"
).read_bytes()
)
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
with open(output_file, "wb") as f:
f.write(mp.SerializeToString())
# TODO: make loading bf16 fast so we can remove this
def create_model_cache(output_file, model):
print(f"creating model cache at {output_file}")
# TODO: add read only Tensors
with Timing("download weights: "):
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
part1 = nn.state.torch_load(
fetch(
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"
)
)
part2 = nn.state.torch_load(
fetch(
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"
)
)
with Timing("weights -> model: "):
nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)
nn.state.load_state_dict(
model, convert_from_huggingface(part1, model, 32, 8), strict=False
)
nn.state.load_state_dict(
model, convert_from_huggingface(part2, model, 32, 8), strict=False
)
with Timing("saving float16 cache: "):
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
@ -37,27 +57,44 @@ def create_model_cache(output_file, model):
print("cache created, rerun to use")
exit(0)
if __name__ == "__main__":
Tensor.no_grad = True
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
with Timing("create model: "):
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
model = Transformer(
4096,
14336,
n_heads=32,
n_layers=32,
norm_eps=1e-5,
vocab_size=32002,
n_kv_heads=8,
max_context=4096,
)
cached_model = "/tmp/cached_openhermes.safetensors"
if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
if not os.path.isfile(cached_model):
create_model_cache(cached_model, model)
with Timing("loading float16 cache: "):
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
if not os.path.isfile("/tmp/tokenizer.model"):
create_fixed_tokenizer("/tmp/tokenizer.model")
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
IM_END = 32000
IM_START = 32001
def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
def encode_prompt(k, v):
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
def start_prompt(k):
return [IM_START] + spp.encode(f"{k}\n")
def output(outputted, toks, color):
cur = spp.decode(toks)[len(outputted) :]
sys.stdout.write(colored(cur, color))
@ -67,7 +104,10 @@ if __name__ == "__main__":
# *** app below this line ***
toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
toks = [spp.bos_id()] + encode_prompt(
"system",
"You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input",
)
PROMPT = getenv("PROMPT", 1)
temperature = getenv("TEMP", 0.7)
@ -83,24 +123,34 @@ if __name__ == "__main__":
turn = not turn
old_output_len = len(outputted)
while 1:
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
tok = (
model(Tensor([toks[start_pos:]]), start_pos, temperature)
.multinomial()
.item()
)
start_pos = len(toks)
toks.append(tok)
outputted = output(outputted, toks, "blue" if not turn else "cyan")
if tok == IM_END: break
if tok == spp.eos_id(): break
if tok == IM_END:
break
if tok == spp.eos_id():
break
new_output = outputted[old_output_len:]
if new_output.endswith("```") and '```python\n' in new_output:
python_code = new_output.split('```python\n')[1].split("```")[0]
if new_output.endswith("```") and "```python\n" in new_output:
python_code = new_output.split("```python\n")[1].split("```")[0]
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
if (
input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower()
== "y"
):
my_stdout = StringIO()
try:
with redirect_stdout(my_stdout): exec(python_code)
with redirect_stdout(my_stdout):
exec(python_code)
result = my_stdout.getvalue()
except Exception as e:
result = ''.join(traceback.format_exception_only(e))
result = "".join(traceback.format_exception_only(e))
toks += spp.encode(f"\nOutput:\n```\n{result}```")
outputted = output(outputted, toks, "yellow")
old_output_len = len(outputted)

View File

@ -9,8 +9,16 @@ import ast
if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
mode = (
"clang"
if getenv("CLANG", "") != ""
else "webgpu"
if getenv("WEBGPU", "") != ""
else ""
)
prg, inp_sizes, out_sizes, state = export_model(
model, mode, Tensor.randn(1, 3, 224, 224)
)
dirname = Path(__file__).parent
if getenv("CLANG", "") == "":
safe_save(state, (dirname / "net.safetensors").as_posix())
@ -20,19 +28,33 @@ if __name__ == "__main__":
else:
cprog = [prg]
# image library!
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
cprog += [
"#define STB_IMAGE_IMPLEMENTATION",
fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h")
.read_text()
.replace("half", "_half"),
]
# imagenet labels, move to datasets?
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
lbls = ast.literal_eval(
fetch(
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
).read_text()
)
lbls = ['"' + lbls[i] + '"' for i in range(1000)]
inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
inputs = "\n".join(
[f"float {inp}[{inp_size}];" for inp, inp_size in inp_sizes.items()]
)
outputs = "\n".join(
[f"float {out}[{out_size}];" for out, out_size in out_sizes.items()]
)
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
cprog.append(inputs)
cprog.append(outputs)
# buffers (empty + weights)
cprog.append("""
cprog.append(
"""
int main(int argc, char* argv[]) {
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
int X=0, Y=0, chan=0;
@ -62,8 +84,9 @@ if __name__ == "__main__":
}
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
else printf("%s\\n", lbls[best_idx]);
}""")
}"""
)
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
# category : 281 (tabby, tabby cat) with 9.452788
print('\n'.join(cprog))
print("\n".join(cprog))

View File

@ -1,8 +1,9 @@
# An example to compile a small Tensorflow model to extremely portable C code
import os, sys
os.environ["CLANG"] = '1'
os.environ["GPU"] = '1'
os.environ["CLANG"] = "1"
os.environ["GPU"] = "1"
import numpy as np
import subprocess
@ -12,32 +13,42 @@ from examples.compile_efficientnet import compile_net
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor
def get_uncompiled_model2(dataset_size=32, output_size=4):
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x)
outputs = tf.keras.layers.Dense(
output_size, activation="sigmoid", name="predictions"
)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def create_onnx_model(keras_model):
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
input_signature = [tf.TensorSpec([1, 32], tf.float32, name="x")]
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
return onnx_model
def compile_onnx_model(onnx_model):
run_onnx = get_run_onnx(onnx_model)
from tinygrad.jit import TinyJit
@TinyJit
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()
def run(x):
return run_onnx({"x": x}, debug=False)["predictions"].realize()
the_input = Tensor.randn(1, 32)
the_output = run(the_input)
the_output = run(the_input)
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
special_names = {
id(the_input.lazydata.realized.cl): "input",
id(the_output.lazydata.realized.cl): "outputs",
}
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
@ -60,7 +71,8 @@ def compile_onnx_model(onnx_model):
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
# test program
cprog.append(f"""int main(int argc, char *argv[]) {{
cprog.append(
f"""int main(int argc, char *argv[]) {{
// read in the weights from disk
FILE *f = fopen("/tmp/tf_weights", "rb");
float *weights = (float *)malloc({len(weights)});
@ -75,25 +87,38 @@ def compile_onnx_model(onnx_model):
for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
float *outputs = infer(input);
printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
}}""")
}}"""
)
# ready the program
prg = '\n'.join(cprog)
prg = "\n".join(cprog)
print(prg)
# add test weights
subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8'))
subprocess.check_output(
["clang", "-O2", "-lm", "-fPIC", "-x", "c", "-", "-o", "/tmp/tf_test"],
input=prg.encode("utf-8"),
)
tinygrad_output = [x for x in the_output.numpy()[0]]
print("tinygrad:", tinygrad_output, file=sys.stderr)
c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n"
c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")]
c_input = " ".join(["%f" % x for x in the_input[0].numpy()]) + "\n"
c_output = [
float(x)
for x in subprocess.check_output(
["/tmp/tf_test"], input=c_input.encode("utf-8")
)
.decode("utf-8")
.strip()
.split(" ")
]
print("compiled:", c_output, file=sys.stderr)
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
return the_input.numpy(), c_output
if __name__ == "__main__":
keras_model = get_uncompiled_model2()
onnx_model = create_onnx_model(keras_model)
@ -101,4 +126,3 @@ if __name__ == "__main__":
tf_output = keras_model(test_input).numpy()[0]
print("keras: ", tf_output, file=sys.stderr)
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)

View File

@ -12,7 +12,14 @@ import pyaudio
import yaml
from llama import LLaMa
from vits import MODELS as VITS_MODELS
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
from vits import (
Y_LENGTH_ESTIMATE_SCALARS,
HParams,
Synthesizer,
TextMapper,
get_hparams_from_file,
load_model,
)
from whisper import init_whisper, transcribe_waveform
from sentencepiece import SentencePieceProcessor
@ -29,16 +36,26 @@ IM_END = 32002
# Functions for encoding prompts to chatml md
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
def encode_prompt(spp, k, v):
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
def start_prompt(spp, k):
return [IM_START] + spp.encode(f"{k}\n")
def chunks(lst, n):
for i in range(0, len(lst), n): yield lst[i:i + n]
for i in range(0, len(lst), n):
yield lst[i : i + n]
def create_fixed_tokenizer():
"""Function needed for extending tokenizer with additional chat tokens"""
import extra.junk.sentencepiece_model_pb2 as spb2
tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
tokenizer_path = fetch(
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model"
)
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
print("creating fixed tokenizer")
mp = spb2.ModelProto()
@ -50,16 +67,28 @@ def create_fixed_tokenizer():
tokenizer_path.write_bytes(mp.SerializeToString())
return tokenizer_path
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]:
def llama_prepare(
llama: LLaMa, temperature: float, pre_prompt_path: Path
) -> tuple[list[int], str, str, str]:
"""Prepares a llama model from a specified pre-prompt file"""
with open(str(pre_prompt_path)) as f:
config = yaml.safe_load(f.read())
toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
toks = [llama.tokenizer.bos_id()] + encode_prompt(
llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")
)
for i in config["examples"]:
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
return (
toks,
config["user_delim"],
config["resp_delim"],
len(toks),
llama.tokenizer.decode(toks),
)
def llama_generate(
llama: LLaMa,
@ -70,7 +99,7 @@ def llama_generate(
user_delim: str,
resp_delim: str,
temperature=0.7,
max_tokens=1000
max_tokens=1000,
):
"""Generates an output for the specified prompt"""
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
@ -79,7 +108,9 @@ def llama_generate(
outputted = llama.tokenizer.decode(toks)
init_length = len(outputted)
for _ in range(max_tokens):
probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy()
probs_np = llama.model(
Tensor([toks[start_pos:]]), start_pos, temperature
).numpy()
token = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)
toks.append(token)
@ -90,12 +121,14 @@ def llama_generate(
sys.stdout.write(cur[len(outputted) :])
sys.stdout.flush()
outputted = cur
if toks[-1] == IM_END: break
if toks[-1] == IM_END:
break
else:
toks.append(IM_END)
print() # because the output is flushed
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
def tts(
text_to_synthesize: str,
synth: Synthesizer,
@ -110,24 +143,45 @@ def tts(
text_mapper: TextMapper,
model_has_multiple_speakers: bool,
batch_size=600,
vits_batch_size=1000
vits_batch_size=1000,
):
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
if model_to_use == "mmts-tts":
text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
# Convert the input text to a tensor.
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
stn_tst = text_mapper.get_text(
text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners
)
init_shape = stn_tst.shape
assert init_shape[0] < batch_size, "text is too long"
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(
0
), Tensor([init_shape[0]], dtype=dtypes.int64)
sid = (
Tensor([speaker_id], dtype=dtypes.int64)
if model_has_multiple_speakers
else None
)
# Perform inference.
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
audio_tensor = synth.infer(
x_tst,
x_tst_lengths,
sid,
noise_scale,
length_scale,
noise_scale_w,
emotion_embedding=emotion_embedding,
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use]
if estimate_max_y_length
else None,
batch_size=vits_batch_size,
)[0, 0]
# Save the audio output.
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
return audio_data
def init_vits(
model_to_use: str,
emotion_path: Path,
@ -142,21 +196,44 @@ def init_vits(
# If model has multiple speakers, validate speaker id and retrieve name if available.
model_has_multiple_speakers = hps.data.n_speakers > 0
if model_has_multiple_speakers:
if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
if speaker_id >= hps.data.n_speakers:
raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
if hps.__contains__("speakers"): # maps speaker ids to names
speakers = hps.speakers
if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
if isinstance(speakers, list):
speakers = {speaker: i for i, speaker in enumerate(speakers)}
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
emotion_embedding = None
if emotion_path is not None:
if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
else: raise ValueError("Emotion path must be a .npy file.")
if emotion_path.endswith(".npy"):
emotion_embedding = Tensor(
np.load(emotion_path), dtype=dtypes.int64
).unsqueeze(0)
else:
raise ValueError("Emotion path must be a .npy file.")
# Load symbols, instantiate TextMapper and clean the text.
if hps.__contains__("symbols"): symbols = hps.symbols
elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'")
if hps.__contains__("symbols"):
symbols = hps.symbols
elif model_to_use == "mmts-tts":
symbols = [
x.replace("\n", "")
for x in fetch(
"https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"
)
.open(encoding="utf-8")
.readlines()
]
else:
symbols = (
["_"]
+ list(';:,.!?¡¿—…"«»“” ')
+ list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
+ list(
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
)
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
# Load the model.
@ -168,18 +245,23 @@ def init_vits(
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
@contextmanager
def output_stream(num_channels: int, sample_rate: int):
try:
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
stream = p.open(
format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True
)
yield stream
except KeyboardInterrupt: pass
except KeyboardInterrupt:
pass
finally:
stream.stop_stream()
stream.close()
p.terminate()
@contextmanager
def log_writer():
try:
@ -191,10 +273,17 @@ def log_writer():
print(*logs, sep="\n")
print(sep)
def listener(q: mp.Queue, event: mp.Event):
try:
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=RATE,
input=True,
frames_per_buffer=CHUNK,
)
did_print = False
while True:
data = stream.read(CHUNK) # read data to avoid overflow
@ -210,7 +299,10 @@ def listener(q: mp.Queue, event: mp.Event):
stream.close()
p.terminate()
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
def mp_output_stream(
q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int
):
with output_stream(num_channels, sample_rate) as stream:
while True:
try:
@ -219,8 +311,10 @@ def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_r
except KeyboardInterrupt:
break
if __name__ == "__main__":
import nltk
nltk.download("punkt")
Tensor.no_grad = True
# Parse CLI arguments
@ -230,75 +324,212 @@ if __name__ == "__main__":
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
# LLAMA args
parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory")
parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
parser.add_argument(
"--llama_pre_prompt_path",
type=Path,
default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml",
help="Path to yaml file which contains all pre-prompt data needed. ",
)
parser.add_argument(
"--llama_count", type=int, default=1000, help="Max number of tokens to generate"
)
parser.add_argument(
"--llama_temperature",
type=float,
default=0.7,
help="Temperature in the softmax",
)
parser.add_argument(
"--llama_quantize",
action="store_true",
help="Quantize the weights to int8 in memory",
)
parser.add_argument(
"--llama_model",
type=Path,
default=None,
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
)
parser.add_argument(
"--llama_gen",
type=str,
default="tiny",
required=False,
help="Generation of the model to use",
)
parser.add_argument(
"--llama_size",
type=str,
default="1B-Chat",
required=False,
help="Size of model to use",
)
parser.add_argument(
"--llama_tokenizer",
type=Path,
default=None,
required=False,
help="Path to llama tokenizer.model",
)
# vits args
parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
parser.add_argument(
"--vits_model_to_use",
default="vctk",
help="Specify the model to use. Default is 'vctk'.",
)
parser.add_argument(
"--vits_speaker_id",
type=int,
default=12,
help="Specify the speaker ID. Default is 6.",
)
parser.add_argument(
"--vits_noise_scale",
type=float,
default=0.667,
help="Specify the noise scale. Default is 0.667.",
)
parser.add_argument(
"--vits_noise_scale_w",
type=float,
default=0.8,
help="Specify the noise scale w. Default is 0.8.",
)
parser.add_argument(
"--vits_length_scale",
type=float,
default=1,
help="Specify the length scale. Default is 1.",
)
parser.add_argument(
"--vits_seed",
type=int,
default=None,
help="Specify the seed (set to None if no seed). Default is 1337.",
)
parser.add_argument(
"--vits_num_channels",
type=int,
default=1,
help="Specify the number of audio output channels. Default is 1.",
)
parser.add_argument(
"--vits_sample_width",
type=int,
default=2,
help="Specify the number of bytes per sample, adjust if necessary. Default is 2.",
)
parser.add_argument(
"--vits_emotion_path",
type=Path,
default=None,
help="Specify the path to emotion reference.",
)
parser.add_argument(
"--vits_estimate_max_y_length",
type=str,
default=False,
help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.",
)
parser.add_argument(
"--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary."
)
# conversation args
parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
parser.add_argument(
"--max_sentence_length",
type=int,
default=20,
help="Max words in one sentence to pass to vits",
)
args = parser.parse_args()
# Init models
model, enc = init_whisper(args.whisper_model_name)
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(
args.vits_model_to_use,
args.vits_emotion_path,
args.vits_speaker_id,
args.vits_seed,
)
# Download tinyllama chat as a default model
if args.llama_model is None:
args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
args.llama_model = fetch(
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors",
"tinyllamachat.safetensors",
)
args.llama_gen = "tiny"
args.llama_size = "1B-Chat"
# Add 3 more tokens to the tokenizer
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"):
args.llama_tokenizer = create_fixed_tokenizer()
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
llama = LLaMa.build(
args.llama_model,
tokenizer_path,
args.llama_gen,
args.llama_size,
args.llama_quantize,
)
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(
llama, args.llama_temperature, args.llama_pre_prompt_path
)
# Start child process for mic input
q = mp.Queue()
is_listening_event = mp.Event()
p = mp.Process(target=listener, args=(q, is_listening_event,))
p = mp.Process(
target=listener,
args=(
q,
is_listening_event,
),
)
p.daemon = True
p.start()
# Start child process for speaker output
out_q = mp.Queue()
out_counter = mp.Value("i", 0)
out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
out_p = mp.Process(
target=mp_output_stream,
args=(
out_q,
out_counter,
args.vits_num_channels,
hps.data.sampling_rate,
),
)
out_p.daemon = True
out_p.start()
# JIT tts
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
tts(
i, synth, hps, emotion_embedding,
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
args.vits_noise_scale_w, args.vits_length_scale,
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
i,
synth,
hps,
emotion_embedding,
args.vits_speaker_id,
args.vits_model_to_use,
args.vits_noise_scale,
args.vits_noise_scale_w,
args.vits_length_scale,
args.vits_estimate_max_y_length,
text_mapper,
model_has_multiple_speakers,
)
# Start the pipeline
with log_writer() as log:
while True:
tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
tokens = [
enc._special_tokens["<|startoftranscript|>"],
enc._special_tokens["<|notimestamps|>"],
]
total = np.array([])
out_counter.value = 0
@ -306,10 +537,12 @@ if __name__ == "__main__":
is_listening_event.set()
prev_text = None
while True:
for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
for _ in range(RATE // CHUNK):
total = np.concatenate([total, q.get()])
txt = transcribe_waveform(model, enc, [total], truncate=True)
print(txt, end="\r")
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()):
continue
if prev_text is not None and prev_text == txt:
is_listening_event.clear()
break
@ -320,9 +553,15 @@ if __name__ == "__main__":
# Generate with llama
with Timing("llama generation: "):
outputted, start_pos, response = llama_generate(
llama, toks, outputted, txt, start_pos,
user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
max_tokens=args.llama_count
llama,
toks,
outputted,
txt,
start_pos,
user_delim=user_delim,
resp_delim=resp_delim,
temperature=args.llama_temperature,
max_tokens=args.llama_count,
)
log.append(f"{resp_delim.capitalize()}: {response}")
@ -333,12 +572,21 @@ if __name__ == "__main__":
total = np.array([], dtype=np.int16)
for j in chunks(i.split(), args.max_sentence_length):
audio_data = tts(
" ".join(j), synth, hps, emotion_embedding,
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
args.vits_noise_scale_w, args.vits_length_scale,
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
" ".join(j),
synth,
hps,
emotion_embedding,
args.vits_speaker_id,
args.vits_model_to_use,
args.vits_noise_scale,
args.vits_noise_scale_w,
args.vits_length_scale,
args.vits_estimate_max_y_length,
text_mapper,
model_has_multiple_speakers,
)
total = np.concatenate([total, audio_data])
out_q.put(total.tobytes())
while out_counter.value < len(sentences): continue
while out_counter.value < len(sentences):
continue
log.append(f"Total: {time.perf_counter() - s}")

View File

@ -11,12 +11,14 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch, Timing
from tinygrad.jit import TinyJit
from extra.models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)
# TODO: you should be able to put these in the jitted function
bias = Tensor([0.485, 0.456, 0.406])
scale = Tensor([0.229, 0.224, 0.225])
@TinyJit
def _infer(model, img):
img = img.permute((2, 0, 1))
@ -25,10 +27,13 @@ def _infer(model, img):
img = img / scale.reshape((1, -1, 1, 1))
return model.forward(img).realize()
def infer(model, img):
# preprocess image
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = img.resize(
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
)
img = np.array(img)
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
@ -52,18 +57,28 @@ def infer(model, img):
"""
return out, retimg
if __name__ == "__main__":
# instantiate my net
model = EfficientNet(getenv("NUM", 0))
model.load_from_pretrained()
# category labels
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
lbls = ast.literal_eval(
fetch(
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
).read_text()
)
# load image and preprocess
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
if url == 'webcam':
url = (
sys.argv[1]
if len(sys.argv) >= 2
else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
)
if url == "webcam":
import cv2
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
while 1:
@ -72,12 +87,17 @@ if __name__ == "__main__":
img = Image.fromarray(frame[:, :, [2, 1, 0]])
lt = time.monotonic_ns()
out, retimg = infer(model, img)
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
print(
f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms",
np.argmax(out),
np.max(out),
lbls[np.argmax(out)],
)
SCALE = 3
simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
cv2.imshow('capture', retimg)
if cv2.waitKey(1) & 0xFF == ord('q'):
cv2.imshow("capture", retimg)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
cap.release()
cv2.destroyAllWindows()

View File

@ -3,6 +3,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad import Device
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
def bit_extract(x, s, e) -> Tensor:
# extract the top bits we don't want
@ -10,11 +11,16 @@ def bit_extract(x, s, e) -> Tensor:
x = (x - top_bits) / (1 << e)
return x.contiguous()
def u16_to_f16(x):
sign = bit_extract(x, 15, 15).float()
exponent = bit_extract(x, 14, 10).float()
fraction = bit_extract(x, 9, 0).float()
return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
return sign.where(-1, 1) * exponent.where(
(exponent - 15).exp2() * (1 + fraction / 0x400),
6.103515625e-5 * (fraction / 0x400),
)
def u32_to_f16(oo):
oo1 = (oo / 0x10000).floor().contiguous()
@ -24,6 +30,7 @@ def u32_to_f16(oo):
f2 = u16_to_f16(oo2)
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
if __name__ == "__main__":
# random float16
Tensor.manual_seed(2)

View File

@ -10,11 +10,20 @@ from tinygrad.shape.symbolic import Variable
from tinygrad.jit import TinyJit
import tiktoken
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
from tinygrad.helpers import (
GlobalCounters,
Timing,
DEBUG,
getenv,
fetch,
colored,
dtypes,
)
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
HALF = getenv("HALF")
class Attention:
def __init__(self, dim, n_heads):
self.c_attn = Linear(dim, 3 * dim, bias=True)
@ -23,18 +32,27 @@ class Attention:
self.dim = dim
self.head_dim = dim // n_heads
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
def __call__(
self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]
) -> Tensor:
if mask is not None:
# no symbolic shape qkv when consuming prompts
start_pos = start_pos.val
xqkv = self.c_attn(x)
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
xq, xk, xv = [
xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape(
xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim
)
for i in range(3)
]
bsz, seqlen, n_heads, head_dim = xq.shape
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
self.cache_k, self.cache_v = Tensor.zeros(
bsz, MAX_CONTEXT, self.n_heads, self.head_dim
), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
if HALF:
self.cache_k = self.cache_k.half()
self.cache_v = self.cache_v.half()
@ -43,11 +61,28 @@ class Attention:
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
# update the cache
self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
self.cache_k.assign(
keys.pad(
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
).contiguous()
).realize()
self.cache_v.assign(
values.pad(
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
).contiguous()
).realize()
xq, keys, values = (
xq.transpose(1, 2),
keys.transpose(1, 2),
values.transpose(1, 2),
)
return self.c_proj(
xq.scaled_dot_product_attention(keys, values, mask)
.transpose(1, 2)
.reshape(bsz, seqlen, -1)
)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
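As a rough sketch of the cache bookkeeping above (shapes only, NumPy instead of lazy tensors, constants assumed): the cache is preallocated to MAX_CONTEXT, the new keys/values are concatenated after the first start_pos cached entries, and the result is written back so the next step can shrink it out again.

import numpy as np

MAX_CONTEXT, n_heads, head_dim = 128, 12, 64   # assumed sizes
cache_k = np.zeros((1, MAX_CONTEXT, n_heads, head_dim))
cache_v = np.zeros((1, MAX_CONTEXT, n_heads, head_dim))

def cache_step(xk, xv, start_pos):
    # xk/xv hold only the new tokens: (bsz, seqlen, n_heads, head_dim)
    seqlen = xk.shape[1]
    keys = np.concatenate([cache_k[:, :start_pos], xk], axis=1)     # shrink + cat
    values = np.concatenate([cache_v[:, :start_pos], xv], axis=1)
    cache_k[:, :start_pos + seqlen] = keys                          # the pad + assign above
    cache_v[:, :start_pos + seqlen] = values
    return keys, values

k, _ = cache_step(np.ones((1, 4, n_heads, head_dim)), np.ones((1, 4, n_heads, head_dim)), 0)
k, _ = cache_step(np.ones((1, 1, n_heads, head_dim)), np.ones((1, 1, n_heads, head_dim)), 4)
assert k.shape[1] == 5   # 4 prompt tokens + 1 generated token visible to attention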
class FeedForward:
def __init__(self, dim, hidden_dim):
@ -57,6 +92,7 @@ class FeedForward:
def __call__(self, x: Tensor) -> Tensor:
return self.c_proj(self.c_fc(x).gelu())
class TransformerBlock:
def __init__(self, dim, n_heads, norm_eps):
self.attn = Attention(dim, n_heads)
@ -66,7 +102,8 @@ class TransformerBlock:
def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
h = x + self.attn(self.ln_1(x), start_pos, mask)
return (h + self.mlp(self.ln_2(h)))
return h + self.mlp(self.ln_2(h))
class Transformer:
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
@ -78,7 +115,8 @@ class Transformer:
self.forward_jit = TinyJit(self.forward)
def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
if not hasattr(self, "allpos"):
self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
_bsz, seqlen = tokens.shape
# NOTE: cannot convert token indices into half due to precision
@ -86,44 +124,73 @@ class Transformer:
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
h = tok_emb + pos_emb
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
mask = (
Tensor.full((1, 1, seqlen, start_pos.val + seqlen), float("-inf"))
.triu(start_pos.val + 1)
.realize()
if seqlen > 1
else None
)
if HALF:
h = h.half()
if mask is not None: mask = mask.half()
if mask is not None:
mask = mask.half()
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
for hi in self.h:
h = hi(h, start_pos=start_pos, mask=mask)
logits = self.lm_head(self.ln_f(h))
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize()
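The mask is the standard causal mask, shifted by start_pos so cached tokens stay visible. A small NumPy sketch with assumed seqlen=3 and start_pos=2:

import numpy as np

seqlen, start_pos = 3, 2
mask = np.triu(np.full((seqlen, start_pos + seqlen), float("-inf")), k=start_pos + 1)
print(mask)
# [[  0.   0.   0. -inf -inf]
#  [  0.   0.   0.   0. -inf]
#  [  0.   0.   0.   0.   0.]]
# query i (absolute position start_pos+i) attends to keys 0 .. start_pos+i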
# TODO: fix empty token
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
def __call__(
self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
) -> Tensor:
return (
self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward
)(tokens, start_pos, temperature)
VOCAB_SIZE = 50257
MODEL_PARAMS = {
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
"gpt2": dict(
n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE
), # 124M params
"gpt2-medium": dict(
n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE
), # 350M params
"gpt2-large": dict(
n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE
), # 774M params
"gpt2-xl": dict(
n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE
), # 1558M params
}
class GPT2:
@staticmethod
def build(model_size="gpt2"):
tokenizer = tiktoken.get_encoding("gpt2")
model = Transformer(**MODEL_PARAMS[model_size])
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
weights = torch_load(
fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin")
)
# special treatment for the Conv1D weights we need to transpose
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
transposed = [
"attn.c_attn.weight",
"attn.c_proj.weight",
"mlp.c_fc.weight",
"mlp.c_proj.weight",
]
for k in weights.keys():
if any(k.endswith(w) for w in transposed):
weights[k] = Tensor(weights[k].numpy().T)
# lm head and wte are tied
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())
weights["lm_head.weight"] = Tensor(weights["wte.weight"].numpy())
load_state_dict(model, weights)
return GPT2(model, tokenizer)
@ -132,42 +199,98 @@ class GPT2:
self.model = model
self.tokenizer = tokenizer
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
def greedy_until(
self,
prompt: str,
max_length: int,
temperature: float,
timing: bool = False,
batch_size: int = 1,
):
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
toks = [prompt_tokens[:] for _ in range(batch_size)]
start_pos = 0
for _ in trange(max_length, disable=(timing == True)):
GlobalCounters.reset()
if timing: print("")
if timing:
print("")
st = GlobalCounters.time_sum_s
with Timing("total ", enabled=timing):
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([x[start_pos:] for x in toks]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
with Timing(
"ran model in ",
on_exit=(
lambda et: (
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
if DEBUG >= 2
else ""
)
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
+ (
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s"
if DEBUG >= 2
else ""
)
)
if DEBUG
else None,
enabled=timing,
):
probs = self.model(
Tensor([x[start_pos:] for x in toks]),
Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(
start_pos
),
temperature,
)
# TODO: fix JIT rand so we can put this in the JIT
tok = probs.multinomial().flatten().numpy().tolist()
start_pos = len(toks[0])
for i,t in enumerate(tok): toks[i].append(t)
for i, t in enumerate(tok):
toks[i].append(t)
output = [self.tokenizer.decode(x) for x in toks]
return output
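The sampling step itself is a temperature-scaled softmax followed by a multinomial draw; a standalone NumPy sketch (with the max subtracted for numerical stability):

import numpy as np

def sample(logits: np.ndarray, temperature: float) -> int:
    # temperature -> 0 approaches greedy argmax; the 1e-10 matches the guard above
    scaled = (logits - logits.max()) / (temperature + 1e-10)
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(np.random.choice(len(probs), p=probs))

print(sample(np.array([1.0, 3.0, 0.5]), temperature=0.8))   # usually 1, the largest logit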
# **** main code ****
if __name__ == "__main__":
Tensor.no_grad = True
print(f"using {Device.DEFAULT} backend")
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with")
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
parser.add_argument('--timing', action='store_true', help="Print timing per token")
parser.add_argument('--seed', type=int, help="Set the random seed")
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
parser = argparse.ArgumentParser(
description="Run GPT2 in tinygrad",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--prompt",
type=str,
default="What is the answer to life, the universe, and everything?",
help="Phrase to start with",
)
parser.add_argument(
"--count", type=int, default=100, help="Max number of tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.8, help="Temperature in the softmax"
)
parser.add_argument(
"--model_size",
type=str,
default="gpt2-medium",
help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]",
)
parser.add_argument("--timing", action="store_true", help="Print timing per token")
parser.add_argument("--seed", type=int, help="Set the random seed")
parser.add_argument(
"--batch_size", type=int, default=1, help="Set the input batch size"
)
parser.add_argument(
"--benchmark",
type=int,
default=-1,
help="Benchmark GPT with the given number of tokens",
)
parser.add_argument("--noshow", action="store_true", help="Don't show the output")
args = parser.parse_args()
if args.seed is not None:
@ -182,11 +305,22 @@ if __name__ == "__main__":
l.assign(l.cast(dtypes.float16).realize())
if args.benchmark != -1:
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
gpt2.model(
Tensor.rand(args.batch_size, args.benchmark),
Variable("a", 0, MAX_CONTEXT).bind(0),
).realize()
else:
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
texts = gpt2.greedy_until(
args.prompt,
args.count,
args.temperature,
timing=args.timing,
batch_size=args.batch_size,
)
if not args.noshow:
print('Generating text...')
if len(texts) == 1: print(texts[0])
print("Generating text...")
if len(texts) == 1:
print(texts[0])
else:
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
for i, text in enumerate(texts):
print(colored(f"Response {i}:", "green"), text)

View File

@ -28,7 +28,8 @@ if __name__ == "__main__":
sched = [x for x in sched if x.ast.op not in LoadOps]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
if getenv("KERNEL", -1) >= 0:
sched = sched[getenv("KERNEL", -1) : getenv("KERNEL", -1) + 1]
# work with the schedule
total_tm = 0
@ -52,20 +53,33 @@ if __name__ == "__main__":
# try a beam search
if getenv("BEAM"):
lin = Linearizer(si.ast, device.linearizer_opts)
lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)))
lin = beam_search(
lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1))
)
lins.append(lin)
# benchmark the programs
choices = []
for lin in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
gflops = (
sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})
* 1e-9
/ tm
)
choices.append((tm, gflops, lin.linearize()))
# print all kernels
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
if DEBUG >= 1:
print(
f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
)
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
print(
f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
)
total_tm += tm
running_gflops += gflops * tm
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
print(
f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS"
)
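The GFLOPS figure is just the kernel's symbolic FLOP count divided by measured wall time; for example, with made-up numbers:

flops = 2 * 4096 ** 3      # e.g. a hypothetical 4096^3 matmul is ~2*N^3 FLOPs
tm = 0.05                  # measured kernel time in seconds
print(f"{flops * 1e-9 / tm:6.0f} GFLOPS")   # ~2749 GFLOPS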

View File

@ -2,6 +2,7 @@
# setup for distributed
from extra import dist
from tinygrad.helpers import getenv, dtypes
if __name__ == "__main__":
if getenv("DIST"):
dist.preinit()
@ -24,7 +25,7 @@ from tinygrad.shape.symbolic import Node
from extra.lr_scheduler import OneCycleLR
from tinygrad.jit import TinyJit
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv("EVAL_BS", 500), getenv("STEPS", 1000)
if getenv("HALF", 0):
Tensor.default_type = dtypes.float16
@ -33,16 +34,28 @@ else:
Tensor.default_type = dtypes.float32
np_dtype = np.float32
class BatchNorm(nn.BatchNorm2d):
def __init__(self, num_features):
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
super().__init__(
num_features,
track_running_stats=False,
eps=1e-12,
momentum=0.85,
affine=True,
)
self.weight.requires_grad = False
self.bias.requires_grad = True
class ConvGroup:
def __init__(self, channels_in, channels_out):
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
self.conv1 = nn.Conv2d(
channels_in, channels_out, kernel_size=3, padding=1, bias=False
)
self.conv2 = nn.Conv2d(
channels_out, channels_out, kernel_size=3, padding=1, bias=False
)
self.norm1 = BatchNorm(channels_out)
self.norm2 = BatchNorm(channels_out)
@ -63,6 +76,7 @@ class ConvGroup:
return x + residual
class SpeedyResNet:
def __init__(self, W):
self.whitening = W
@ -74,54 +88,58 @@ class SpeedyResNet:
ConvGroup(256, 512),
lambda x: x.max((2, 3)),
nn.Linear(512, 10, bias=False),
lambda x: x.mul(1./9)
lambda x: x.mul(1.0 / 9),
]
def __call__(self, x, training=True):
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
# TODO: remove the pad and instead let the kernel optimizer handle it itself
forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
forward = (
lambda x: x.conv2d(self.whitening).pad2d((1, 0, 0, 1)).sequential(self.net)
)
return (
forward(x) if training else forward(x) * 0.5 + forward(x[..., ::-1]) * 0.5
)
def train_cifar():
# hyper-parameters were exactly the same as the original repo
bias_scaler = 58
hyp: Dict[str, Any] = {
'seed' : 209,
'opt': {
'bias_lr': 1.76 * bias_scaler/512,
'non_bias_lr': 1.76 / 512,
'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
'non_bias_decay': 1.08 * 6.45e-4 * BS,
'final_lr_ratio': 0.025,
'initial_div_factor': 1e16,
'label_smoothing': 0.20,
'momentum': 0.85,
'percent_start': 0.23,
'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
"seed": 209,
"opt": {
"bias_lr": 1.76 * bias_scaler / 512,
"non_bias_lr": 1.76 / 512,
"bias_decay": 1.08 * 6.45e-4 * BS / bias_scaler,
"non_bias_decay": 1.08 * 6.45e-4 * BS,
"final_lr_ratio": 0.025,
"initial_div_factor": 1e16,
"label_smoothing": 0.20,
"momentum": 0.85,
"percent_start": 0.23,
"loss_scale_scaler": 1.0 / 128, # (range: ~1/512 - 16+, 1/128 w/ FP16)
},
'net': {
'kernel_size': 2, # kernel size for the whitening layer
'cutmix_size': 3,
'cutmix_steps': 499,
'pad_amount': 2
"net": {
"kernel_size": 2, # kernel size for the whitening layer
"cutmix_size": 3,
"cutmix_steps": 499,
"pad_amount": 2,
},
"ema": {
"steps": 399,
"decay_base": 0.95,
"decay_pow": 1.6,
"every_n_steps": 5,
},
'ema': {
'steps': 399,
'decay_base': .95,
'decay_pow': 1.6,
'every_n_steps': 5,
}
}
def set_seed(seed):
Tensor.manual_seed(getenv('SEED', seed))
random.seed(getenv('SEED', seed))
Tensor.manual_seed(getenv("SEED", seed))
random.seed(getenv("SEED", seed))
# ========== Model ==========
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
def whitening(X, kernel_size=hyp['net']['kernel_size']):
def whitening(X, kernel_size=hyp["net"]["kernel_size"]):
def _cov(X):
X = X / np.sqrt(X.shape[0] - 1)
return X.T @ X
@ -130,12 +148,18 @@ def train_cifar():
h, w = patch_size
c = data.shape[1]
axis: SupportsIndex = (2, 3) # type: ignore
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
return (
np.lib.stride_tricks.sliding_window_view(
data, window_shape=(h, w), axis=axis
)
.transpose((0, 3, 2, 1, 4, 5))
.reshape((-1, c, h, w))
)
def _eigens(patches):
n, c, h, w = patches.shape
Σ = _cov(patches.reshape(n, c * h * w))
Λ, V = np.linalg.eigh(Σ, UPLO='U')
Λ, V = np.linalg.eigh(Σ, UPLO="U")
return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0)
Λ, V = _eigens(_patches(X.numpy()))
@ -144,12 +168,16 @@ def train_cifar():
return Tensor(W.astype(np_dtype), requires_grad=False)
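The idea is the usual data-driven whitening: eigendecompose the covariance of flattened patches and scale the eigenvectors by 1/sqrt(eigenvalue), so the first conv decorrelates its input. A hedged NumPy sketch of that idea (the eps and the exact filter layout here are assumptions, not the values used above):

import numpy as np

patches = np.random.randn(1000, 3 * 2 * 2)          # n patches, c*h*w features
X = patches / np.sqrt(patches.shape[0] - 1)
cov = X.T @ X                                        # same _cov as above
eigvals, eigvecs = np.linalg.eigh(cov)
W = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + 1e-2)) @ eigvecs.T   # ZCA-style whitening
white = patches @ W
print(np.round(np.diag(np.cov(white, rowvar=False)), 2))           # each feature now has ~unit variance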
# ========== Loss ==========
def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
def cross_entropy(
x: Tensor, y: Tensor, reduction: str = "mean", label_smoothing: float = 0.0
) -> Tensor:
divisor = y.shape[1]
assert not isinstance(divisor, Node), "sint not supported as divisor"
y = (1 - label_smoothing) * y + label_smoothing / divisor
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
if reduction == "none":
return -x.log_softmax(axis=1).mul(y).sum(axis=1)
if reduction == "sum":
return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
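The label-smoothing line just redistributes probability mass from the one-hot target to every class; a quick numeric check with assumed values:

import numpy as np

label_smoothing, num_classes = 0.2, 10
y = np.eye(num_classes)[3]
y_smooth = (1 - label_smoothing) * y + label_smoothing / num_classes
print(y_smooth[3], y_smooth[0])   # 0.82 for the true class, 0.02 for every other class
print(y_smooth.sum())             # still sums to 1.0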
# ========== Preprocessing ==========
@ -161,12 +189,20 @@ def train_cifar():
p = padding[3]
s = X.shape[3]
X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
X_lr = X[..., :, 1 : 1 + p[0]].flip(3).pad(
((0, 0), (0, 0), (0, 0), (0, s + p[0]))
) + X[..., :, -1 - p[1] : -1].flip(3).pad(
((0, 0), (0, 0), (0, 0), (s + p[1], 0))
)
X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr
p = padding[2]
s = X.shape[2]
X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
X_lr = X[..., 1 : 1 + p[0], :].flip(2).pad(
((0, 0), (0, 0), (0, s + p[0]), (0, 0))
) + X[..., -1 - p[1] : -1, :].flip(2).pad(
((0, 0), (0, 0), (s + p[1], 0), (0, 0))
)
X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr
return X
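pad_reflect builds reflection padding out of flip + pad + add so it stays inside the lazy ops; the effect matches NumPy's reflect mode on a toy image:

import numpy as np

x = np.arange(16).reshape(1, 1, 4, 4)
print(np.pad(x, ((0, 0), (0, 0), (2, 2), (2, 2)), mode="reflect")[0, 0])
# the 4x4 block sits in the middle, mirrored (without repeating the edge) on all sides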
@ -176,10 +212,18 @@ def train_cifar():
is_even = int(mask_size % 2 == 0)
center_max = shape[-2] - mask_size // 2 - is_even
center_min = mask_size // 2 - is_even
center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
center_x = (
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
).floor()
center_y = (
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
).floor()
d_x = Tensor.arange(0, shape[-1]).reshape(
(1, 1, 1, shape[-1])
) - center_x.reshape((-1, 1, 1, 1))
d_y = Tensor.arange(0, shape[-2]).reshape(
(1, 1, shape[-2], 1)
) - center_y.reshape((-1, 1, 1, 1))
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
mask = d_y * d_x
@ -200,7 +244,7 @@ def train_cifar():
Y_patch = Tensor(Y.numpy()[order])
X_cutmix = Tensor.where(mask, X_patch, X)
mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1])
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
Y_cutmix = mix_portion * Y_patch + (1.0 - mix_portion) * Y
return X_cutmix, Y_cutmix
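The cutmix mask is a square of ones built purely from coordinate differences against a random centre, and the labels are mixed by the patch's area fraction. A one-image NumPy sketch with an assumed centre (odd mask_size only):

import numpy as np

mask_size, H, W = 3, 8, 8
cy, cx = 2, 4                                   # assumed patch centre (row, col)
dx = np.arange(W).reshape(1, W) - cx
dy = np.arange(H).reshape(H, 1) - cy
half = mask_size // 2
mask = ((dx >= -half) & (dx <= half)) & ((dy >= -half) & (dy <= half))
print(mask.astype(int))                         # a 3x3 block of ones centred at (2, 4)
print(float(mask_size ** 2) / (H * W))          # label mix portion: 9/64 = 0.140625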
# the operations that remain inside the batch fetcher are the ones that involve random operations
@ -213,11 +257,16 @@ def train_cifar():
random.shuffle(order)
if is_train:
X = random_crop(X, crop_size=32)
X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
X = Tensor.where(
Tensor.rand(X.shape[0], 1, 1, 1) < 0.5, X[..., ::-1], X
) # flip LR
if step >= hyp["net"]["cutmix_steps"]:
X, Y = cutmix(X, Y, mask_size=hyp["net"]["cutmix_size"])
X, Y = X.numpy(), Y.numpy()
et = time.monotonic()
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
print(
f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})"
)
for i in range(0, X.shape[0], BS):
# pad the last batch
batch_end = min(i + BS, Y.shape[0])
@ -226,18 +275,24 @@ def train_cifar():
step += 1
yield x, y
cnt += 1
if not is_train: break
if not is_train:
break
transform = [
lambda x: x / 255.0,
lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1))
lambda x: (
x.reshape((-1, 3, 32, 32)) - Tensor(cifar_mean).reshape((1, 3, 1, 1))
)
/ Tensor(cifar_std).reshape((1, 3, 1, 1)),
]
class modelEMA():
class modelEMA:
def __init__(self, w, net):
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpicklable pyopencl._cl.Buffer
self.net_ema = SpeedyResNet(w)
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
for net_ema_param, net_param in zip(
get_state_dict(self.net_ema).values(), get_state_dict(net).values()
):
net_ema_param.requires_grad = False
net_ema_param.assign(net_param.numpy())
@ -245,23 +300,37 @@ def train_cifar():
def update(self, net, decay):
# TODO with Tensor.no_grad()
Tensor.no_grad = True
for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
for net_ema_param, (param_name, net_param) in zip(
get_state_dict(self.net_ema).values(), get_state_dict(net).items()
):
# batchnorm currently is not being tracked
if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
if not ("num_batches_tracked" in param_name) and not (
"running" in param_name
):
net_ema_param.assign(
net_ema_param.detach() * decay
+ net_param.detach() * (1.0 - decay)
).realize()
Tensor.no_grad = False
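The EMA rule itself is the standard ema <- decay*ema + (1-decay)*param, applied every every_n_steps with decay = decay_base**every_n_steps (the actual decay used above is further ramped by (i/STEPS)**decay_pow). A plain-Python sketch with assumed values:

decay_base, every_n_steps = 0.95, 5
decay = decay_base ** every_n_steps   # ~0.774 per EMA update

ema, param = 1.0, 0.0
for _ in range(10):                   # ten EMA updates toward the new weight value
    ema = ema * decay + param * (1.0 - decay)
print(round(ema, 3))                  # ~0.077, most of the old weight has decayed away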
set_seed(hyp['seed'])
set_seed(hyp["seed"])
# this import needs to be done here because this is running in a subprocess
from extra.dist import OOB
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
X_train, Y_train, X_test, Y_test = fetch_cifar()
# load data and label into GPU and convert to dtype accordingly
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
X_train, X_test = (
X_train.to(device=Device.DEFAULT).float(),
X_test.to(device=Device.DEFAULT).float(),
)
Y_train, Y_test = (
Y_train.to(device=Device.DEFAULT).float(),
Y_test.to(device=Device.DEFAULT).float(),
)
# one-hot encode labels
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
# preprocess data
@ -274,10 +343,15 @@ def train_cifar():
model = SpeedyResNet(W)
# padding is not timed in the original repo since it can be done all at once
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
X_train = pad_reflect(X_train, size=hyp["net"]["pad_amount"])
# Convert data and labels to the default dtype
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
X_train, Y_train, X_test, Y_test = (
X_train.cast(Tensor.default_type),
Y_train.cast(Tensor.default_type),
X_test.cast(Tensor.default_type),
Y_test.cast(Tensor.default_type),
)
# parse the training params into bias and non-bias
params_dict = get_state_dict(model)
@ -285,26 +359,60 @@ def train_cifar():
params_non_bias = []
for params in params_dict:
if params_dict[params].requires_grad is not False:
if 'bias' in params:
if "bias" in params:
params_bias.append(params_dict[params])
else:
params_non_bias.append(params_dict[params])
opt_bias = optim.SGD(params_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
opt_bias = optim.SGD(
params_bias,
lr=0.01,
momentum=hyp["opt"]["momentum"],
nesterov=True,
weight_decay=hyp["opt"]["bias_decay"],
)
opt_non_bias = optim.SGD(
params_non_bias,
lr=0.01,
momentum=hyp["opt"]["momentum"],
nesterov=True,
weight_decay=hyp["opt"]["non_bias_decay"],
)
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
initial_div_factor = hyp['opt']['initial_div_factor']
final_lr_ratio = hyp['opt']['final_lr_ratio']
pct_start = hyp['opt']['percent_start']
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
initial_div_factor = hyp["opt"]["initial_div_factor"]
final_lr_ratio = hyp["opt"]["final_lr_ratio"]
pct_start = hyp["opt"]["percent_start"]
lr_sched_bias = OneCycleLR(
opt_bias,
max_lr=hyp["opt"]["bias_lr"],
pct_start=pct_start,
div_factor=initial_div_factor,
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
total_steps=STEPS,
)
lr_sched_non_bias = OneCycleLR(
opt_non_bias,
max_lr=hyp["opt"]["non_bias_lr"],
pct_start=pct_start,
div_factor=initial_div_factor,
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
total_steps=STEPS,
)
loss_batchsize_scaler = 512 / BS
@TinyJit
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
out = model(X)
loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
loss = (
cross_entropy(
out, Y, reduction="none", label_smoothing=hyp["opt"]["label_smoothing"]
)
.mul(hyp["opt"]["loss_scale_scaler"] * loss_batchsize_scaler)
.sum()
.div(hyp["opt"]["loss_scale_scaler"])
)
if not getenv("DISABLE_BACKWARD"):
# index 0 for bias and 1 for non-bias
@ -316,11 +424,16 @@ def train_cifar():
# sync gradients across ranks
bucket, offset = [], 0
for _, v in params_dict.items():
if v.grad is not None: bucket.append(v.grad.flatten())
if v.grad is not None:
bucket.append(v.grad.flatten())
grads = collectives.allreduce(Tensor.cat(*bucket))
for _, v in params_dict.items():
if v.grad is not None:
v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
v.grad.assign(
grads[offset : offset + v.grad.numel()].reshape(
*v.grad.shape
)
)
offset += v.grad.numel()
optimizer[0].step()
@ -331,9 +444,10 @@ def train_cifar():
def eval_step(model, X, Y):
out = model(X, training=False)
loss = cross_entropy(out, Y, reduction='mean')
loss = cross_entropy(out, Y, reduction="mean")
correct = out.argmax(axis=1) == Y.argmax(axis=1)
return correct.realize(), loss.realize()
eval_step_jitted = TinyJit(eval_step)
eval_step_ema_jitted = TinyJit(eval_step)
@ -347,7 +461,7 @@ def train_cifar():
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
model_ema: Optional[modelEMA] = None
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
projected_ema_decay_val = hyp["ema"]["decay_base"] ** hyp["ema"]["every_n_steps"]
i = 0
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
with Tensor.train():
@ -363,24 +477,37 @@ def train_cifar():
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
# further split batch if distributed
if getenv("DIST"):
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
Xt, Yt = (
Xt.chunk(min(world_size, 5), 0)[min(rank, 4)],
Yt.chunk(min(world_size, 5), 0)[min(rank, 4)],
)
correct, loss = eval_step_jitted(model, Xt, Yt)
losses.append(loss.numpy().tolist())
corrects.extend(correct.numpy().tolist())
if model_ema:
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
correct_ema, loss_ema = eval_step_ema_jitted(
model_ema.net_ema, Xt, Yt
)
losses_ema.append(loss_ema.numpy().tolist())
corrects_ema.extend(correct_ema.numpy().tolist())
# collect accuracy across ranks
correct_sum, correct_len = sum(corrects), len(corrects)
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
if model_ema:
correct_sum_ema, correct_len_ema = sum(corrects_ema), len(
corrects_ema
)
if getenv("DIST"):
if rank == 0:
for j in range(1, min(world_size, 5)):
if model_ema:
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
(
recv_sum,
recv_len,
recv_sum_ema,
recv_len_ema,
) = OOB.recv(j)
else:
recv_sum, recv_len = OOB.recv(j)
correct_sum += recv_sum
@ -390,55 +517,95 @@ def train_cifar():
correct_len_ema += recv_len_ema
elif rank < min(world_size, 5):
if model_ema:
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
OOB.send(
(
correct_sum,
correct_len,
correct_sum_ema,
correct_len_ema,
),
0,
)
else:
OOB.send((correct_sum, correct_len), 0)
# only rank 0 prints
if rank == 0:
acc = correct_sum / correct_len * 100.0
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
if model_ema:
acc_ema = correct_sum_ema / correct_len_ema * 100.0
print(
f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)"
)
if model_ema:
print(
f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}"
)
if STEPS == 0 or i==STEPS: break
if STEPS == 0 or i == STEPS:
break
X, Y = next(batcher)
if getenv("DIST"):
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
GlobalCounters.reset()
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
loss = train_step_jitted(
model,
[opt_bias, opt_non_bias],
[lr_sched_bias, lr_sched_non_bias],
X,
Y,
)
et = time.monotonic()
loss_cpu = loss.numpy()
# EMA for network weights
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
if i > hyp["ema"]["steps"] and (i + 1) % hyp["ema"]["every_n_steps"] == 0:
if model_ema is None:
model_ema = modelEMA(W, model)
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
model_ema.update(
model,
Tensor(
[
projected_ema_decay_val
* (i / STEPS) ** hyp["ema"]["decay_pow"]
]
),
)
cl = time.monotonic()
if not getenv("DIST"):
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
print(
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
)
else:
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
print(
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
)
st = cl
i += 1
if __name__ == "__main__":
if not getenv("DIST"):
train_cifar()
else: # distributed
if getenv("HIP"):
from tinygrad.runtime.ops_hip import HIP
devices = [f"hip:{i}" for i in range(HIP.device_count)]
else:
from tinygrad.runtime.ops_gpu import CLDevice
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
world_size = len(devices)
# ensure that the batch size is divisible by the number of devices
assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
assert (
BS % world_size == 0
), f"batch size {BS} is not divisible by world size {world_size}"
# ensure that the evaluation batch size is divisible by the number of devices
assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
assert (
EVAL_BS % min(world_size, 5) == 0
), f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
# init out-of-band communication
dist.init_oob(world_size)
@ -447,4 +614,5 @@ if __name__ == "__main__":
processes = []
for rank, device in enumerate(devices):
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
for p in processes: p.join()
for p in processes:
p.join()

View File

@ -6,6 +6,7 @@
from pathlib import Path
import sys, argparse, json
import numpy as np
np.set_printoptions(linewidth=200)
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
from tinygrad import Device
@ -24,84 +25,225 @@ MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
MODEL_PARAMS = {
"1": {
"7B": {
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
"args": {
"dim": 4096,
"n_heads": 32,
"n_layers": 32,
"norm_eps": 1e-06,
"vocab_size": 32000,
"hidden_dim": 11008,
},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
"args": {
"dim": 5120,
"n_heads": 40,
"n_layers": 40,
"norm_eps": 1e-06,
"vocab_size": 32000,
"hidden_dim": 13824,
},
"files": 2,
},
"30B": {
"args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
"args": {
"dim": 6656,
"n_heads": 52,
"n_layers": 60,
"norm_eps": 1e-06,
"vocab_size": 32000,
"hidden_dim": 17920,
},
"files": 4,
},
"65B": {
"args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
"args": {
"dim": 8192,
"n_heads": 64,
"n_layers": 80,
"norm_eps": 1e-05,
"vocab_size": 32000,
"hidden_dim": 22016,
},
"files": 8,
},
},
"2": {
"7B": {
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
"args": {
"dim": 4096,
"n_heads": 32,
"n_layers": 32,
"norm_eps": 1e-05,
"vocab_size": 32000,
"hidden_dim": 11008,
},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
"args": {
"dim": 5120,
"n_heads": 40,
"n_layers": 40,
"norm_eps": 1e-05,
"vocab_size": 32000,
"hidden_dim": 13824,
},
"files": 2,
},
"70B": {
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
"args": {
"dim": 8192,
"n_heads": 64,
"n_kv_heads": 8,
"n_layers": 80,
"norm_eps": 1e-05,
"vocab_size": 32000,
"hidden_dim": 28672,
},
"files": 8,
},
},
"code": {
"7B": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
"args": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32016,
"hidden_dim": 11008,
},
"files": 1,
},
"7B-Python": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
"args": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32000,
"hidden_dim": 11008,
},
"files": 1,
},
"7B-Instruct": {
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
"args": {
"dim": 4096,
"n_layers": 32,
"n_heads": 32,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32016,
"hidden_dim": 11008,
},
"files": 1,
},
"13B": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
"args": {
"dim": 5120,
"n_layers": 40,
"n_heads": 40,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32016,
"hidden_dim": 13824,
},
"files": 2,
},
"13B-Python": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
"args": {
"dim": 5120,
"n_layers": 40,
"n_heads": 40,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32000,
"hidden_dim": 13824,
},
"files": 2,
},
"13B-Instruct": {
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
"args": {
"dim": 5120,
"n_layers": 40,
"n_heads": 40,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32016,
"hidden_dim": 13824,
},
"files": 2,
},
"34B": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"args": {
"dim": 8192,
"n_layers": 48,
"n_heads": 64,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32000,
"hidden_dim": 22016,
},
"files": 4,
},
"34B-Python": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"args": {
"dim": 8192,
"n_layers": 48,
"n_heads": 64,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32000,
"hidden_dim": 22016,
},
"files": 4,
},
"34B-Instruct": {
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
"args": {
"dim": 8192,
"n_layers": 48,
"n_heads": 64,
"n_kv_heads": 8,
"norm_eps": 1e-05,
"rope_theta": 1000000,
"vocab_size": 32000,
"hidden_dim": 22016,
},
"files": 4,
},
},
"tiny": {
"1B": {
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
"args": {
"dim": 2048,
"n_layers": 22,
"n_heads": 32,
"n_kv_heads": 4,
"norm_eps": 1e-05,
"vocab_size": 32000,
"hidden_dim": 5632,
},
"files": 1,
},
"1B-Chat": {
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
"args": {
"dim": 2048,
"n_layers": 22,
"n_heads": 32,
"n_kv_heads": 4,
"norm_eps": 1e-05,
"vocab_size": 32003,
"hidden_dim": 5632,
},
"files": 1,
}
}
},
},
}
@ -111,21 +253,37 @@ def concat_weights(models):
disk_tensors = [model[name] for model in models]
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
return disk_tensors[0].to(device=Device.DEFAULT)
axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
axis = (
1
if name.startswith("tok_embeddings.")
or name.endswith(".attention.wo.weight")
or name.endswith(".feed_forward.w2.weight")
else 0
)
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
return {name: convert(name) for name in {name: None for model in models for name in model}}
return {
name: convert(name)
for name in {name: None for model in models for name in model}
}
def load(fn: str):
if fn.endswith('.index.json'):
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
if fn.endswith(".index.json"):
with open(fn) as fp:
weight_map = json.load(fp)["weight_map"]
parts = {
n: load(str(Path(fn).parent / Path(n).name))
for n in set(weight_map.values())
}
return {k: parts[n][k] for k, n in weight_map.items()}
elif fn.endswith(".safetensors"):
return safe_load(fn)
else:
return torch_load(fn)
class AbsmaxQuantizedLinear:
def __init__(self, in_features, out_features, bias=False):
assert bias == False
@ -139,34 +297,63 @@ class AbsmaxQuantizedLinear:
def quantize(tensors):
new_tensors = {}
for name, v in tensors.items():
if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
if (
"feed_forward" in name
or ("attention.w") in name
or name == "output.weight"
):
scale = v.abs().max(axis=1) / 127.0
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)
new_tensors[name] = int8_weight
new_tensors[name.replace('weight', 'scale')] = scale
new_tensors[name.replace("weight", "scale")] = scale
else:
new_tensors[name] = v
return new_tensors
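The quantization scheme is plain absmax int8: one scale per output row, the weight stored as int8, and the result rescaled by the stored scale. A NumPy sketch (rounding here, whereas the int8 cast above truncates):

import numpy as np

w = np.random.randn(4, 8).astype(np.float32)
scale = np.abs(w).max(axis=1) / 127.0                     # per-row scale
w_int8 = np.round(w / scale[:, None]).astype(np.int8)     # stored weight
w_back = w_int8.astype(np.float32) * scale[:, None]       # what the layer effectively uses
print(np.abs(w - w_back).max())                           # error bounded by ~scale/2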
class LLaMa:
@staticmethod
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
def build(
model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False
):
params = MODEL_PARAMS[model_gen][model_size]
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
assert (
sp_model.vocab_size() == params["args"]["vocab_size"]
), f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
model = (
Transformer(
**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT
)
if quantize
else Transformer(**params["args"], max_context=MAX_CONTEXT)
)
if model_path.is_dir():
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
weights = concat_weights(
[
load(filename)
for filename in [
f"{model_path}/consolidated.{i:02d}.pth"
for i in range(params["files"])
]
]
)
else:
weights = load(str(model_path))
if "model.embed_tokens.weight" in weights:
weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
weights = convert_from_huggingface(
weights,
model,
params["args"]["n_heads"],
params["args"].get("n_kv_heads", params["args"]["n_heads"]),
)
if quantize:
weights = AbsmaxQuantizedLinear.quantize(weights)
for _,v in weights.items(): v.realize()
for _, v in weights.items():
v.realize()
load_state_dict(model, weights, strict=False)
return LLaMa(model, sp_model)
@ -179,18 +366,23 @@ class LLaMa:
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
start_pos = 0
for i in range(max_length):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
probs = llama.model(
Tensor([toks[start_pos:]]), start_pos, temperature
).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
start_pos = len(toks)
toks.append(tok)
if tok == self.tokenizer.eos_id(): break
if tok == self.tokenizer.eos_id():
break
output = self.tokenizer.decode(toks)
for s in until:
if output.endswith(s): return output[0:-len(s)]
if output.endswith(s):
return output[0 : -len(s)]
return output
# **** main code ****
"""
test:
@ -256,21 +448,58 @@ if __name__ == "__main__":
Tensor.no_grad = True
print(f"using {Device.DEFAULT} backend")
parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
parser = argparse.ArgumentParser(
description="Run LLaMA in tinygrad",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--prompt",
type=str,
default=None,
help="Phrase to start with. Without this, it goes into chatbot mode",
)
parser.add_argument(
"--count", type=int, default=1000, help="Max number of tokens to generate"
)
parser.add_argument(
"--personality",
type=str,
default="Stacy",
help="Personality, can be Stacy, George, Gary, or Lexie",
)
parser.add_argument(
"--temperature", type=float, default=0.7, help="Temperature in the softmax"
)
parser.add_argument("--timing", action="store_true", help="Print timing per token")
parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""")
parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""")
parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
parser.add_argument(
"--profile", action="store_true", help="Output profile data to out.prof"
)
parser.add_argument(
"--gen",
default="1",
help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""",
)
parser.add_argument(
"--size",
type=str,
default=None,
help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""",
)
parser.add_argument(
"--quantize", action="store_true", help="Quantize the weights to int8 in memory"
)
parser.add_argument(
"--model",
type=Path,
default=None,
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
)
args = parser.parse_args()
if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation")
if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
if args.gen not in MODEL_PARAMS:
raise ValueError("Invalid model generation")
if args.size is None:
args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
chatbot = args.prompt == None
# *** prompt engineers work here ****
@ -294,9 +523,13 @@ After you are done speaking, output [EOS]. You are not the User.
user_delim = "\nUser: "
resp_delim = "Stacy: "
end_delim = " [EOS]\n"
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
pre_prompt += "".join(
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
)
elif args.personality.lower() == "george":
print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.")
print(
"WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter."
)
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
You are an AI version of George Hotz. You act as much as you can like George.
You are one of the greatest computer experts in the world.
@ -312,13 +545,15 @@ After you are done speaking, output [EOS]. You are not the User.
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk"
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk",
}
user_delim = "\nUser: "
resp_delim = "George: "
end_delim = " [EOS]\n"
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
pre_prompt += "".join(
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
)
elif args.personality.lower() == "gary":
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
You are Gary!
@ -331,13 +566,15 @@ After you are done speaking, output [EOS]. You are not the User.
"""
examples = {
"What is your name?": "I am Gary. I used to sell cars.",
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla"
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla",
}
user_delim = "\nUser: "
resp_delim = "Gary: "
end_delim = " [EOS]\n"
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
pre_prompt += "".join(
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
)
elif args.personality.lower() == "lexie":
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
You are Lexie!
@ -352,21 +589,34 @@ After you are done speaking, output [EOS]. You are not Chad.
examples = {
"hi lexie": "hi chad, glad we finally met up!",
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress"
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress",
}
user_delim = "\nChad: "
resp_delim = "Lexie: "
end_delim = " [EOS]\n"
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
pre_prompt += "".join(
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
)
# *** prompt engineers stop here ****
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
MODEL_PATH = (
args.model
or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
)
TOKENIZER_PATH = (
MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent
) / "tokenizer.model"
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
llama = LLaMa.build(
MODEL_PATH,
TOKENIZER_PATH,
model_gen=args.gen,
model_size=args.size,
quantize=args.quantize,
)
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
if chatbot:
@ -375,7 +625,9 @@ After you are done speaking, output [EOS]. You are not Chad.
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
with Timing():
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used
llama.model(
Tensor([toks]), 0, args.temperature
).realize() # NOTE: outputs are not used
start_pos = len(toks)
else:
# non chat bot mode
@ -403,14 +655,37 @@ After you are done speaking, output [EOS]. You are not Chad.
for i in range(args.count):
GlobalCounters.reset()
if args.timing or args.profile: print("")
if args.timing or args.profile:
print("")
st = GlobalCounters.time_sum_s
with Profiling(enabled=args.profile):
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
with Timing(
"total ",
enabled=args.timing,
on_exit=lambda x: f", {1e9/x:.2f} tok/sec",
):
with Timing(
"ran model in ",
on_exit=(
lambda et: (
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
if DEBUG >= 2
else ""
)
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
+ (
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s"
if DEBUG >= 2
else ""
)
)
if DEBUG
else None,
enabled=args.timing,
):
probs = llama.model(
Tensor([toks[start_pos:]]), start_pos, args.temperature
).realize()
# TODO: fix JIT rand so we can put this in the JIT
tok = probs.multinomial().item()
@ -427,5 +702,7 @@ After you are done speaking, output [EOS]. You are not Chad.
outputted = cur
# stop after you have your answer
if chatbot and outputted.endswith(end_delim): break
if not chatbot: break
if chatbot and outputted.endswith(end_delim):
break
if not chatbot:
break

View File

@ -63,21 +63,23 @@ class Normalize:
image = Ft.normalize(image, mean=self.mean, std=self.std)
return image
transforms = lambda size_scale: T.Compose(
[
Resize(int(800 * size_scale), int(1333 * size_scale)),
T.ToTensor(),
Normalize(
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_bgr255=True
),
]
)
def expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
w_half *= scale
h_half *= scale
@ -118,7 +120,7 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
mask = mask.expand((1, 1, -1, -1))
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
mask = mask[0][0]
if thresh >= 0:
@ -169,11 +171,13 @@ class Masker:
masker = Masker(threshold=0.5, padding=1)
def select_top_predictions(predictions, confidence_threshold=0.9):
scores = predictions.get_field("scores").numpy()
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
return predictions[keep]
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
image = transforms(size_scale)(original_image).numpy()
image = Tensor(image, requires_grad=False)
@ -189,6 +193,7 @@ def compute_prediction(original_image, model, confidence_threshold, size_scale=1
prediction.add_field("mask", masks)
return prediction
def compute_prediction_batched(batch, model, size_scale=1.0):
imgs = []
for img in batch:
@ -198,21 +203,25 @@ def compute_prediction_batched(batch, model, size_scale=1.0):
del image
return predictions
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
def findContours(*args, **kwargs):
if cv2.__version__.startswith('4'):
if cv2.__version__.startswith("4"):
contours, hierarchy = cv2.findContours(*args, **kwargs)
elif cv2.__version__.startswith('3'):
elif cv2.__version__.startswith("3"):
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
return contours, hierarchy
def compute_colors_for_labels(labels):
l = labels[:, None]
colors = l * palette
colors = (colors % 255).astype("uint8")
return colors
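A minimal usage sketch of the label-coloring trick above: each integer label is multiplied by the palette constants and wrapped at 255, so every class id maps to a fixed pseudo-random RGB triple (numpy is assumed imported as np, as elsewhere in this file).
labels = np.array([1, 2, 3])
colors = compute_colors_for_labels(labels)
# colors.shape == (3, 3), dtype uint8; the same label always yields the same color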
def overlay_mask(image, predictions):
image = np.asarray(image)
masks = predictions.get_field("mask").numpy()
@ -231,17 +240,92 @@ def overlay_mask(image, predictions):
return composite
CATEGORIES = [
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
"__background",
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
]
def overlay_boxes(image, predictions):
labels = predictions.get_field("labels").numpy()
boxes = predictions.bbox
@ -258,6 +342,7 @@ def overlay_boxes(image, predictions):
return image
def overlay_class_names(image, predictions):
scores = predictions.get_field("scores").numpy().tolist()
labels = predictions.get_field("labels").numpy().tolist()
@ -269,26 +354,35 @@ def overlay_class_names(image, predictions):
x, y = box[:2]
s = template.format(label, score)
x, y = int(x), int(y)
cv2.putText(
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
)
cv2.putText(image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
return image
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--image', type=str, help="Path of the image to run")
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run MaskRCNN",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--image", type=str, help="Path of the image to run")
parser.add_argument(
"--threshold", type=float, default=0.7, help="Detector threshold"
)
parser.add_argument(
"--size_scale", type=float, default=1.0, help="Image resize multiplier"
)
parser.add_argument(
"--out", type=str, default="/tmp/rendered.png", help="Output filename"
)
args = parser.parse_args()
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
model_tiny = MaskRCNN(resnet)
model_tiny.load_from_pretrained()
img = Image.open(args.image)
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
top_result_tiny = compute_prediction(
img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale
)
bbox_image = overlay_boxes(img, top_result_tiny)
mask_image = overlay_mask(bbox_image, top_result_tiny)
final_image = overlay_class_names(mask_image, top_result_tiny)

View File

@ -3,6 +3,7 @@ import unicodedata
import numpy as np
from scipy import signal
def gaussian_kernel(n, std):
gaussian_1d = signal.gaussian(n, std)
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
@ -12,14 +13,18 @@ def gaussian_kernel(n, std):
gaussian_3d /= gaussian_3d.max()
return gaussian_3d
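A rough sketch of what the helper above produces, assuming the elided line builds the 3D kernel with another outer product and a reshape, as the 2D step suggests: a cube normalized to a peak of 1.0 that falls off toward the corners, later used to weight overlapping sliding-window patches.
k = gaussian_kernel(8, 1.0)
# expected: k.shape == (8, 8, 8), k.max() == 1.0 near the center, small values at the corners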
def prepare_arrays(image, roi_shape=(128, 128, 128)):
assert len(roi_shape) == 3 and any(roi_shape)
image_shape = list(image.shape[2:])
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
norm_map = np.zeros_like(result)
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(
norm_map.dtype
)
return result, norm_map, norm_patch
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
assert len(roi_shape) == 3 and any(roi_shape)
assert 0 < overlap_factor < 1
@ -31,25 +36,35 @@ def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
for k in range(0, strides[2] * size[2], strides[2]):
yield i, j, k
def _get_best_indices(logits, n_best_size):
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
def _is_punctuation(char):
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
if (
(cp := ord(char)) in range(33, 48)
or cp in range(58, 65)
or cp in range(91, 97)
or cp in range(123, 127)
):
return True
return unicodedata.category(char).startswith("P")
def _is_whitespace(char):
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
return unicodedata.category(char) == "Zs"
def _is_control(char):
if char == "\t" or char == "\n" or char == "\r":
return False
return unicodedata.category(char).startswith("C")
def _run_split_on_punc(text):
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
return [text]
@ -66,6 +81,7 @@ def _run_split_on_punc(text):
output[-1].append(char)
return ["".join(x) for x in output]
def _run_strip_accents(text):
output = []
for char in unicodedata.normalize("NFD", text):
@ -73,13 +89,15 @@ def _run_strip_accents(text):
output.append(char)
return "".join(output)
def _clean_text(text):
output = []
for char in text:
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
if not ((cp := ord(char)) == 0 or cp == 0xFFFD or _is_control(char)):
output.append(" " if _is_whitespace(char) else char)
return "".join(output)
def _get_final_text(pred_text, orig_text):
def _strip_spaces(text):
ns_text = ""
@ -128,32 +146,46 @@ def _get_final_text(pred_text, orig_text):
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text
def get_bert_qa_prediction(features, example, start_end_logits):
prelim_predictions = []
for i, feature in enumerate(features):
for start_index in _get_best_indices(start_end_logits[i][0], 20):
for end_index in _get_best_indices(start_end_logits[i][1], 20):
if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
if start_index >= len(feature["tokens"]) or end_index >= len(
feature["tokens"]
):
continue
if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
if (
start_index not in feature["token_to_orig_map"]
or end_index not in feature["token_to_orig_map"]
):
continue
if not feature["token_is_max_context"].get(start_index, False):
continue
if end_index < start_index or end_index - start_index + 1 > 30:
continue
prelim_predictions.append({
prelim_predictions.append(
{
"feature_index": i,
"start_index": start_index,
"end_index": end_index,
"start_logit": start_end_logits[i][0, start_index],
"end_logit": start_end_logits[i][1, end_index]
})
predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
"end_logit": start_end_logits[i][1, end_index],
}
)
predictions = sorted(
prelim_predictions,
key=lambda x: (x["start_logit"] + x["end_logit"]),
reverse=True,
)
if len(predictions) > 0:
feature = features[predictions[0]["feature_index"]]
tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
tok_tokens = feature["tokens"][
predictions[0]["start_index"] : (predictions[0]["end_index"] + 1)
]
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)]

View File

@ -3,6 +3,7 @@ import string
from collections import Counter
import numpy as np
def levenshtein(a, b):
n, m = len(a), len(b)
if n > m:
@ -20,6 +21,7 @@ def levenshtein(a, b):
return current[n]
def word_error_rate(x, y):
scores = words = 0
for h, r in zip(x, y):
@ -29,12 +31,14 @@ def word_error_rate(x, y):
scores += levenshtein(h_list, r_list)
return float(scores) / words, float(scores), words
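A small usage sketch of the metric above (hedged: the elided lines are assumed to split both strings into word lists and accumulate reference word counts into words):
wer, errors, n_words = word_error_rate(["the cat sat"], ["the cat sat down"])
# one word-level edit against a 4-word reference: wer == 0.25, errors == 1.0, n_words == 4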
def one_hot(arr, num_classes=3):
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
arr = res.reshape(list(arr.shape) + [num_classes])
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
return arr
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
prediction = prediction.argmax(axis=channel_axis)
@ -42,14 +46,18 @@ def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr
intersection = np.sum(prediction * target, axis=reduce_axis)
target_sum = np.sum(target, axis=reduce_axis)
prediction_sum = np.sum(prediction, axis=reduce_axis)
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
result = (2.0 * intersection + smooth_nr) / (
target_sum + prediction_sum + smooth_dr
)
return result[0]
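For reference, the value above is the standard smoothed Dice coefficient; a toy single-class version computed with the same sums as the lines above:
import numpy as np
pred = np.array([[[1, 1], [0, 0]], [[0, 0], [0, 0]]], dtype=np.float32)
tgt = np.array([[[1, 0], [0, 0]], [[0, 0], [0, 0]]], dtype=np.float32)
intersection = (pred * tgt).sum()  # 1.0
dice = (2.0 * intersection + 1e-6) / (tgt.sum() + pred.sum() + 1e-6)  # ~0.667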
def normalize_string(s):
s = "".join(c for c in s.lower() if c not in string.punctuation)
s = re.sub(r'\b(a|an|the)\b', ' ', s)
s = re.sub(r"\b(a|an|the)\b", " ", s)
return " ".join(s.split())
def f1_score(x, y):
xt = normalize_string(x).split()
yt = normalize_string(y).split()

View File

@ -6,15 +6,18 @@ from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv, dtypes, GlobalCounters
from examples.mlperf import helpers
def eval_resnet():
# Resnet50-v1.5
from tinygrad.jit import TinyJit
from extra.models.resnet import ResNet50
mdl = ResNet50()
mdl.load_from_pretrained()
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
def input_fixup(x):
x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0
x -= input_mean
@ -48,14 +51,18 @@ def eval_resnet():
et = time.perf_counter()
n += (t == y).sum()
d += len(t)
print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS")
print(
f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS"
)
st = time.perf_counter()
def eval_unet3d():
# UNet3D
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, sliding_window_inference
from examples.mlperf.metrics import get_dice_score
mdl = UNet3D()
mdl.load_from_pretrained()
s = 0
@ -69,15 +76,18 @@ def eval_unet3d():
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
st = time.perf_counter()
def eval_retinanet():
# RetinaNet with ResNeXt50_32X4D
from extra.models.resnet import ResNeXt50_32X4D
from extra.models.retinanet import RetinaNet
mdl = RetinaNet(ResNeXt50_32X4D())
mdl.load_from_pretrained()
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
def input_fixup(x):
x = x.permute([0, 3, 1, 2]) / 255.0
x -= input_mean
@ -88,11 +98,18 @@ def eval_retinanet():
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from contextlib import redirect_stdout
coco = COCO(openimages())
coco_eval = COCOeval(coco, iouType="bbox")
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
coco_evalimgs, evaluated_imgs, ncats, narea = (
[],
[],
len(coco_eval.params.catIds),
len(coco_eval.params.areaRng),
)
from tinygrad.jit import TinyJit
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
n, bs = 0, 8
@ -106,19 +123,35 @@ def eval_retinanet():
mdlrun.jit_cache = None
outs = mdl(input_fixup(dat)).numpy()
et = time.perf_counter()
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
predictions = mdl.postprocess_detections(
outs,
input_size=dat.shape[1:3],
orig_image_sizes=[t["image_size"] for t in targets],
)
ext = time.perf_counter()
n += len(targets)
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
print(
f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing"
)
img_ids = [t["image_id"] for t in targets]
coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
coco_results = [
{
"image_id": targets[i]["image_id"],
"category_id": label,
"bbox": box,
"score": score,
}
for i, prediction in enumerate(predictions)
for box, score, label in zip(*prediction.values())
]
with redirect_stdout(None):
coco_eval.cocoDt = coco.loadRes(coco_results)
coco_eval.params.imgIds = img_ids
coco_eval.evaluate()
evaluated_imgs.extend(img_ids)
coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
coco_evalimgs.append(
np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))
)
st = time.perf_counter()
coco_eval.params.imgIds = evaluated_imgs
@ -127,16 +160,47 @@ def eval_retinanet():
coco_eval.accumulate()
coco_eval.summarize()
def eval_rnnt():
# RNN-T
from extra.models.rnnt import RNNT
mdl = RNNT()
mdl.load_from_pretrained()
from extra.datasets.librispeech import iterate
from examples.mlperf.metrics import word_error_rate
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
LABELS = [
" ",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
"'",
]
c = 0
scores = 0
@ -149,16 +213,20 @@ def eval_rnnt():
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
for n, t in enumerate(tt):
tnp = np.array(t)
_, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
_, scores_, words_ = word_error_rate(
["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]]
)
scores += scores_
words += words_
c += len(tt)
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
st = time.perf_counter()
def eval_bert():
# Bert-QA
from extra.models.bert import BertForQuestionAnswering
mdl = BertForQuestionAnswering()
mdl.load_from_pretrained()
@ -180,9 +248,17 @@ def eval_bert():
mt = time.perf_counter()
outs = []
for x in X:
outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
outs.append(
run(
Tensor(x["input_ids"]),
Tensor(x["input_mask"]),
Tensor(x["segment_ids"]),
).numpy()
)
et = time.perf_counter()
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
print(
f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features"
)
pred = get_bert_qa_prediction(X, Y, outs)
print(f"pred: {pred}\nans: {Y['answers']}")
@ -192,17 +268,27 @@ def eval_bert():
st = time.perf_counter()
def eval_mrcnn():
from tqdm import tqdm
from extra.models.mask_rcnn import MaskRCNN
from extra.models.resnet import ResNet
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
from extra.datasets.coco import (
BASEDIR,
images,
convert_prediction_to_coco_bbox,
convert_prediction_to_coco_mask,
accumulate_predictions_for_coco,
evaluate_predictions_on_coco,
iterate,
)
from examples.mask_rcnn import compute_prediction_batched, Image
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
mdl.load_from_pretrained()
bbox_output = '/tmp/results_bbox.json'
mask_output = '/tmp/results_mask.json'
bbox_output = "/tmp/results_bbox.json"
mask_output = "/tmp/results_mask.json"
accumulate_predictions_for_coco([], bbox_output, rm=True)
accumulate_predictions_for_coco([], mask_output, rm=True)
@ -213,12 +299,12 @@ def eval_mrcnn():
for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs):
batch_imgs = []
for image_row in batch:
image_name = image_row['file_name']
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
image_name = image_row["file_name"]
img = Image.open(BASEDIR / f"val2017/{image_name}").convert("RGB")
batch_imgs.append(img)
batch_result = compute_prediction_batched(batch_imgs, mdl)
for image_row, result in zip(batch, batch_result):
image_name = image_row['file_name']
image_name = image_row["file_name"]
box_pred = convert_prediction_to_coco_bbox(image_name, result)
mask_pred = convert_prediction_to_coco_mask(image_name, result)
accumulate_predictions_for_coco(box_pred, bbox_output)
@ -226,8 +312,9 @@ def eval_mrcnn():
del batch_imgs
del batch_result
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
evaluate_predictions_on_coco(mask_output, iou_type='segm')
evaluate_predictions_on_coco(bbox_output, iou_type="bbox")
evaluate_predictions_on_coco(mask_output, iou_type="segm")
if __name__ == "__main__":
# inference only

View File

@ -3,46 +3,60 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters, getenv
import numpy as np
def test_model(model, *inputs):
GlobalCounters.reset()
out = model(*inputs)
if isinstance(out, Tensor): out = out.numpy()
if isinstance(out, Tensor):
out = out.numpy()
# TODO: return event future to still get the time_sum_s without DEBUG=2
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
print(
f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms"
)
def spec_resnet():
# Resnet50-v1.5
from extra.models.resnet import ResNet50
mdl = ResNet50()
img = Tensor.randn(1, 3, 224, 224)
test_model(mdl, img)
def spec_retinanet():
# Retinanet with ResNet backbone
from extra.models.resnet import ResNet50
from extra.models.retinanet import RetinaNet
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
img = Tensor.randn(1, 3, 224, 224)
test_model(mdl, img)
def spec_unet3d():
# 3D UNET
from extra.models.unet3d import UNet3D
mdl = UNet3D()
# mdl.load_from_pretrained()
img = Tensor.randn(1, 1, 128, 128, 128)
test_model(mdl, img)
def spec_rnnt():
from extra.models.rnnt import RNNT
mdl = RNNT()
# mdl.load_from_pretrained()
x = Tensor.randn(220, 1, 240)
y = Tensor.randn(1, 220)
test_model(mdl, x, y)
def spec_bert():
from extra.models.bert import BertForQuestionAnswering
mdl = BertForQuestionAnswering()
# mdl.load_from_pretrained()
x = Tensor.randn(1, 384)
@ -50,13 +64,16 @@ def spec_bert():
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
test_model(mdl, x, am, tt)
def spec_mrcnn():
from extra.models.mask_rcnn import MaskRCNN, ResNet
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
# mdl.load_from_pretrained()
x = Tensor.randn(3, 224, 224)
test_model(mdl, [x])
if __name__ == "__main__":
# inference only for now
Tensor.training = False
@ -67,4 +84,3 @@ if __name__ == "__main__":
if nm in globals():
print(f"testing {m}")
globals()[nm]()

View File

@ -1,36 +1,43 @@
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
def train_resnet():
# TODO: Resnet50-v1.5
pass
def train_retinanet():
# TODO: Retinanet
pass
def train_unet3d():
# TODO: Unet3d
pass
def train_rnnt():
# TODO: RNN-T
pass
def train_bert():
# TODO: BERT
pass
def train_maskrcnn():
# TODO: Mask RCNN
pass
if __name__ == "__main__":
with Tensor.train():
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(
","
):
nm = f"train_{m}"
if nm in globals():
print(f"training {m}")
globals()[nm]()

View File

@ -9,6 +9,7 @@ from tinygrad.helpers import getenv
from tinygrad.nn import optim
from extra.datasets import fetch_mnist
class LinearGen:
def __init__(self):
self.l1 = Tensor.scaled_uniform(128, 256)
@ -23,6 +24,7 @@ class LinearGen:
x = x.dot(self.l4).tanh()
return x
class LinearDisc:
def __init__(self):
self.l1 = Tensor.scaled_uniform(784, 1024)
@ -38,16 +40,21 @@ class LinearDisc:
x = x.dot(self.l4).log_softmax()
return x
def make_batch(images):
sample = np.random.randint(0, len(images), size=(batch_size))
image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0
return Tensor(image_b)
def make_labels(bs, col, val=-2.0):
y = np.zeros((bs, 2), np.float32)
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
y[
range(bs), [col] * bs
] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
return Tensor(y)
def train_discriminator(optimizer, data_real, data_fake):
real_labels = make_labels(batch_size, 1)
fake_labels = make_labels(batch_size, 0)
@ -61,6 +68,7 @@ def train_discriminator(optimizer, data_real, data_fake):
optimizer.step()
return (loss_real + loss_fake).numpy()
def train_generator(optimizer, data_fake):
real_labels = make_labels(batch_size, 1)
optimizer.zero_grad()
@ -70,6 +78,7 @@ def train_generator(optimizer, data_fake):
optimizer.step()
return loss.numpy()
if __name__ == "__main__":
# data for training and validation
images_real = np.vstack(fetch_mnist()[::2])
@ -85,7 +94,9 @@ if __name__ == "__main__":
output_dir = Path(".").resolve() / "outputs"
output_dir.mkdir(exist_ok=True)
# optimizers
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_g = optim.Adam(
get_parameters(generator), lr=0.0002, b1=0.5
) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
# training loop
for epoch in (t := trange(epochs)):
@ -102,6 +113,11 @@ if __name__ == "__main__":
if (epoch + 1) % sample_interval == 0:
fake_images = generator.forward(ds_noise).detach().numpy()
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
save_image(
make_grid(torch.tensor(fake_images)),
output_dir / f"image_{epoch+1}.jpg",
)
t.set_description(
f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}"
)
print("Training Completed!")

View File

@ -9,10 +9,12 @@ from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
GPU = getenv("GPU")
QUICK = getenv("QUICK")
DEBUG = getenv("DEBUG")
class SqueezeExciteBlock2D:
def __init__(self, filters):
self.filters = filters
@ -22,7 +24,9 @@ class SqueezeExciteBlock2D:
self.bias2 = Tensor.scaled_uniform(1, self.filters)
def __call__(self, input):
se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
se = input.avg_pool2d(
kernel_size=(input.shape[2], input.shape[3])
) # GlobalAveragePool2D
se = se.reshape(shape=(-1, self.filters))
se = se.dot(self.weight1) + self.bias1
se = se.relu()
@ -31,12 +35,16 @@ class SqueezeExciteBlock2D:
se = input.mul(se)
return se
class ConvBlock:
def __init__(self, h, w, inp, filters=128, conv=3):
self.h, self.w = h, w
self.inp = inp
# init weights
self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
self.cweights = [
Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv)
for i in range(3)
]
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
# init layers
self._bn = BatchNorm2d(128)
@ -50,9 +58,14 @@ class ConvBlock:
x = self._seb(x)
return x
class BigConvNet:
def __init__(self):
self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
self.conv = [
ConvBlock(28, 28, 1),
ConvBlock(28, 28, 128),
ConvBlock(14, 14, 128),
]
self.weight1 = Tensor.scaled_uniform(128, 10)
self.weight2 = Tensor.scaled_uniform(128, 10)
@ -63,19 +76,19 @@ class BigConvNet:
for par in pars:
print(par.shape)
no_pars += np.prod(par.shape)
print('no of parameters', no_pars)
print("no of parameters", no_pars)
return pars
else:
return get_parameters(self)
def save(self, filename):
with open(filename+'.npy', 'wb') as f:
with open(filename + ".npy", "wb") as f:
for par in get_parameters(self):
# if par.requires_grad:
np.save(f, par.numpy())
def load(self, filename):
with open(filename+'.npy', 'rb') as f:
with open(filename + ".npy", "rb") as f:
for par in get_parameters(self):
# if par.requires_grad:
try:
@ -83,7 +96,7 @@ class BigConvNet:
if GPU:
par.gpu()
except:
print('Could not load parameter')
print("Could not load parameter")
def forward(self, x):
x = self.conv[0](x)
@ -102,7 +115,10 @@ if __name__ == "__main__":
BS = 32
lmbd = 0.00025
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
lossfn = (
lambda out, y: out.sparse_categorical_crossentropy(y)
+ lmbd * (model.weight1.abs() + model.weight2.abs()).sum()
)
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
@ -133,4 +149,4 @@ if __name__ == "__main__":
X_aug = X_train if epoch == 1 else augment_img(X_train)
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
accuracy = evaluate(model, X_test, Y_test, BS=BS)
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
model.save(f"examples/checkpoint{accuracy * 1e6:.0f}")

View File

@ -6,14 +6,14 @@ from tinygrad.nn.state import get_parameters
if __name__ == "__main__":
with Tensor.train():
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
x = Tensor.uniform(BS, C1, H, W)
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
bn = BatchNorm2d(C2, track_running_stats=False)
for t in get_parameters([x, conv, bn]): t.realize()
for t in get_parameters([x, conv, bn]):
t.realize()
print("running network")
x.sequential([conv, bn]).numpy()

File diff suppressed because it is too large

View File

@ -7,9 +7,17 @@ import soundfile
import numpy as np
import parselmouth
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = (
hop_length,
f0_min,
f0_max,
sampling_rate,
"pm",
)
def interpolate_f0(self, f0):
vuv_vector = np.zeros_like(f0, dtype=np.float32)
vuv_vector[f0 > 0.0] = 1.0
@ -19,79 +27,142 @@ class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
nzindex = nzindex.astype(np.float32)
time_org = self.hop_length / self.sampling_rate * nzindex
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
if data.shape[0] <= 0:
return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector
if data.shape[0] == 1:
return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
return f0, vuv_vector
def compute_f0(self, wav, p_len=None):
x = wav
if p_len is None: p_len = x.shape[0]//self.hop_length
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
if p_len is None:
p_len = x.shape[0] // self.hop_length
else:
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
time_step = self.hop_length / self.sampling_rate * 1000
f0 = parselmouth.Sound(x, self.sampling_rate) \
.to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
.selected_array['frequency']
f0 = (
parselmouth.Sound(x, self.sampling_rate)
.to_pitch_ac(
time_step=time_step / 1000,
voicing_threshold=0.6,
pitch_floor=self.f0_min,
pitch_ceiling=self.f0_max,
)
.selected_array["frequency"]
)
pad_size = (p_len - len(f0) + 1) // 2
if(pad_size>0 or p_len - len(f0) - pad_size>0):
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
f0, uv = self.interpolate_f0(f0)
return f0
def compute_f0_uv(self, wav, p_len=None):
x = wav
if p_len is None: p_len = x.shape[0]//self.hop_length
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
if p_len is None:
p_len = x.shape[0] // self.hop_length
else:
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
time_step = self.hop_length / self.sampling_rate * 1000
f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
time_step=time_step / 1000, voicing_threshold=0.6,
pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
f0 = (
parselmouth.Sound(x, self.sampling_rate)
.to_pitch_ac(
time_step=time_step / 1000,
voicing_threshold=0.6,
pitch_floor=self.f0_min,
pitch_ceiling=self.f0_max,
)
.selected_array["frequency"]
)
pad_size = (p_len - len(f0) + 1) // 2
if(pad_size>0 or p_len - len(f0) - pad_size>0):
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
f0, uv = self.interpolate_f0(f0)
return f0, uv
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
def __init__(
self,
sr: int,
threshold: float = -40.0,
min_length: int = 5000,
min_interval: int = 300,
hop_size: int = 20,
max_sil_kept: int = 5000,
):
if not min_length >= min_interval >= hop_size:
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
raise ValueError(
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
)
if not max_sil_kept >= hop_size:
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
raise ValueError(
"The following condition must be satisfied: max_sil_kept >= hop_size"
)
min_interval = sr * min_interval / 1000
self.threshold = 10 ** (threshold / 20.)
self.threshold = 10 ** (threshold / 20.0)
self.hop_size = round(sr * hop_size / 1000)
self.win_size = min(round(min_interval), 4 * self.hop_size)
self.min_length = round(sr * min_length / 1000 / self.hop_size)
self.min_interval = round(min_interval / self.hop_size)
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
def _apply_slice(self, waveform, begin, end):
if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
if len(waveform.shape) > 1:
return waveform[
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
]
else:
return waveform[
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
]
def slice(self, waveform):
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
if samples.shape[0] <= self.min_length:
return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
rms_list = librosa.feature.rms(
y=samples, frame_length=self.win_size, hop_length=self.hop_size
).squeeze(0)
sil_tags, silence_start, clip_start = [], None, 0
for i, rms in enumerate(rms_list):
if rms < self.threshold: # Keep looping while frame is silent.
if silence_start is None: # Record start of silent frames.
silence_start = i
continue
if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded.
if silence_start is None:
continue # Keep looping while frame is not silent and silence start has not been recorded.
# Clear recorded silence start if interval is not enough or clip is too short
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
need_slice_middle = (
i - silence_start >= self.min_interval
and i - clip_start >= self.min_length
)
if not is_leading_silence and not need_slice_middle:
silence_start = None
continue
if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed.
if (
i - silence_start <= self.max_sil_kept
): # Need slicing. Record the range of silent frames to be removed.
pos = rms_list[silence_start : i + 1].argmin() + silence_start
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
clip_start = pos
elif i - silence_start <= self.max_sil_kept * 2:
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
pos = rms_list[
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
].argmin()
pos += i - self.max_sil_kept
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
pos_l = (
rms_list[
silence_start : silence_start + self.max_sil_kept + 1
].argmin()
+ silence_start
)
pos_r = (
rms_list[i - self.max_sil_kept : i + 1].argmin()
+ i
- self.max_sil_kept
)
if silence_start == 0:
sil_tags.append((0, pos_r))
clip_start = pos_r
@ -99,41 +170,105 @@ class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
clip_start = max(pos_r, pos)
else:
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
pos_l = (
rms_list[
silence_start : silence_start + self.max_sil_kept + 1
].argmin()
+ silence_start
)
pos_r = (
rms_list[i - self.max_sil_kept : i + 1].argmin()
+ i
- self.max_sil_kept
)
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
clip_start = pos_r
silence_start = None
total_frames = rms_list.shape[0]
if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence.
if (
silence_start is not None
and total_frames - silence_start >= self.min_interval
): # Deal with trailing silence.
silence_end = min(total_frames, silence_start + self.max_sil_kept)
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
sil_tags.append((pos, total_frames + 1))
if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices.
if len(sil_tags) == 0:
return {
"0": {"slice": False, "split_time": f"0,{len(waveform)}"}
} # Apply and return slices.
chunks = []
if sil_tags[0][0]:
chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
chunks.append(
{
"slice": False,
"split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}",
}
)
for i in range(0, len(sil_tags)):
if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
if i:
chunks.append(
{
"slice": False,
"split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}",
}
)
chunks.append(
{
"slice": True,
"split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}",
}
)
if sil_tags[-1][1] * self.hop_size < len(waveform):
chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
chunks.append(
{
"slice": False,
"split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}",
}
)
chunk_dict = {}
for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
for i in range(len(chunks)):
chunk_dict[str(i)] = chunks[i]
return chunk_dict
# sinc_interp_hann audio resampling
class Resample:
def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
def __init__(
self,
orig_freq: int = 16000,
new_freq: int = 16000,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
dtype: Optional[dtypes] = None,
):
(
self.orig_freq,
self.new_freq,
self.lowpass_filter_width,
self.rolloff,
self.beta,
) = (orig_freq, new_freq, lowpass_filter_width, rolloff, beta)
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
self.kernel, self.width = (
self._get_sinc_resample_kernel(dtype)
if self.orig_freq != self.new_freq
else (None, None)
)
def __call__(self, waveform: Tensor) -> Tensor:
if self.orig_freq == self.new_freq: return waveform
if self.orig_freq == self.new_freq:
return waveform
return self._apply_sinc_resample_kernel(waveform)
def _apply_sinc_resample_kernel(self, waveform: Tensor):
if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
if not waveform.is_floating_point():
raise TypeError(
f"Waveform tensor expected to be of type float, but received {waveform.dtype}."
)
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
int(self.new_freq) // self.gcd
)
shape = waveform.shape
waveform = waveform.reshape(-1, shape[-1]) # pack batch
num_wavs, length = waveform.shape
@ -144,34 +279,58 @@ class Resample:
resampled = resampled[..., :target_length]
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
return resampled
def _get_sinc_resample_kernel(self, dtype=None):
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
int(self.new_freq) // self.gcd
)
if self.lowpass_filter_width <= 0:
raise ValueError("Low pass filter width should be positive.")
base_freq = min(orig_freq, new_freq)
base_freq *= self.rolloff
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
idx = (
Tensor.arange(
-width,
width + orig_freq,
dtype=(dtype if dtype is not None else dtypes.float32),
)[None, None]
/ orig_freq
)
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
t *= base_freq
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
t *= math.pi
scale = base_freq / orig_freq
kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
kernels = Tensor.where(
t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t
)
kernels *= window * scale
if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
if dtype is None:
kernels = kernels.cast(dtype=dtypes.float32)
return kernels, width
def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
def sinc_interp_resample(
x: Tensor,
orig_freq: int = 16000,
new_freq: int = 1600,
lowpass_filter_width: int = 6,
rolloff: float = 0.99,
beta: Optional[float] = None,
):
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
return resamp(x)
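A hedged usage sketch of the resampler above; the exact output length depends on the kernel width, so it is shown only to illustrate the call.
from tinygrad.tensor import Tensor
import numpy as np
wav_16k = Tensor(np.random.randn(1, 16000).astype(np.float32))  # ~1 s of audio at 16 kHz
wav_8k = sinc_interp_resample(wav_16k, orig_freq=16000, new_freq=8000)
# wav_8k.shape is roughly (1, 8000)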
def cut(audio_path, db_thresh=-30, min_len=5000):
audio, sr = librosa.load(audio_path, sr=None)
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
chunks = slicer.slice(audio)
return chunks
def chunks2audio(audio_path, chunks):
chunks = dict(chunks)
audio, sr = load_audiofile(audio_path)
@ -185,19 +344,30 @@ def chunks2audio(audio_path, chunks):
result.append((v["slice"], audio[int(tag[0]) : int(tag[1])]))
return result, sr
def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
def load_audiofile(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
channels_first: bool = True,
):
with soundfile.SoundFile(filepath, "r") as file_:
frames = file_._prepare_read(frame_offset, None, num_frames)
waveform = file_.read(frames, "float32", always_2d=True)
sample_rate = file_.samplerate
waveform = Tensor(waveform)
if channels_first: waveform = waveform.transpose(0, 1)
if channels_first:
waveform = waveform.transpose(0, 1)
return waveform, sample_rate
def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
def get_unit_f0(
wav: Tensor, tran, hop_length, target_sample, f0_filter=False
) -> Tuple[Tensor, Tensor, Tensor]:
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
if f0_filter and sum(f0) == 0:
raise RuntimeError("No voice detected")
f0 = Tensor(f0.astype(np.float32)).float()
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)

View File

@ -14,6 +14,7 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit
class AttnBlock:
def __init__(self, in_channels):
self.norm = GroupNorm(32, in_channels)
@ -30,22 +31,32 @@ class AttnBlock:
# compute attention
b, c, h, w = q.shape
q, k, v = [x.reshape(b, c, h * w).transpose(1, 2) for x in (q, k, v)]
h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
h_ = (
Tensor.scaled_dot_product_attention(q, k, v)
.transpose(1, 2)
.reshape(b, c, h, w)
)
return x + self.proj_out(h_)
class ResnetBlock:
def __init__(self, in_channels, out_channels=None):
self.norm1 = GroupNorm(32, in_channels)
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
self.norm2 = GroupNorm(32, out_channels)
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
self.nin_shortcut = (
Conv2d(in_channels, out_channels, 1)
if in_channels != out_channels
else lambda x: x
)
def __call__(self, x):
h = self.conv1(self.norm1(x).swish())
h = self.conv2(self.norm2(h).swish())
return self.nin_shortcut(x) + h
class Mid:
def __init__(self, block_in):
self.block_1 = ResnetBlock(block_in, block_in)
@ -55,6 +66,7 @@ class Mid:
def __call__(self, x):
return x.sequential([self.block_1, self.attn_1, self.block_2])
class Decoder:
def __init__(self):
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
@ -63,11 +75,17 @@ class Decoder:
arr = []
for i, s in enumerate(sz):
arr.append({"block":
[ResnetBlock(s[1], s[0]),
arr.append(
{
"block": [
ResnetBlock(s[1], s[0]),
ResnetBlock(s[0], s[0]),
ResnetBlock(s[0], s[0])]})
if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
ResnetBlock(s[0], s[0]),
]
}
)
if i != 0:
arr[-1]["upsample"] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
self.up = arr
self.norm_out = GroupNorm(32, 128)
@ -79,16 +97,22 @@ class Decoder:
for l in self.up[::-1]:
print("decode", x.shape)
for b in l['block']: x = b(x)
if 'upsample' in l:
for b in l["block"]:
x = b(x)
if "upsample" in l:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
bs, c, py, px = x.shape
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
x = l['upsample']['conv'](x)
x = (
x.reshape(bs, c, py, 1, px, 1)
.expand(bs, c, py, 2, px, 2)
.reshape(bs, c, py * 2, px * 2)
)
x = l["upsample"]["conv"](x)
x.realize()
return self.conv_out(self.norm_out(x).swish())
class Encoder:
def __init__(self):
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
@ -96,10 +120,11 @@ class Encoder:
arr = []
for i, s in enumerate(sz):
arr.append({"block":
[ResnetBlock(s[0], s[1]),
ResnetBlock(s[1], s[1])]})
if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
if i != 3:
arr[-1]["downsample"] = {
"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))
}
self.down = arr
self.mid = Mid(512)
@ -111,12 +136,15 @@ class Encoder:
for l in self.down:
print("encode", x.shape)
for b in l['block']: x = b(x)
if 'downsample' in l: x = l['downsample']['conv'](x)
for b in l["block"]:
x = b(x)
if "downsample" in l:
x = l["downsample"]["conv"](x)
x = self.mid(x)
return self.conv_out(self.norm_out(x).swish())
class AutoencoderKL:
def __init__(self):
self.encoder = Encoder()
@ -132,25 +160,27 @@ class AutoencoderKL:
latent = self.post_quant_conv(latent)
return self.decoder(latent)
# not to be confused with ResnetBlock
class ResBlock:
def __init__(self, channels, emb_channels, out_channels):
self.in_layers = [
GroupNorm(32, channels),
Tensor.silu,
Conv2d(channels, out_channels, 3, padding=1)
]
self.emb_layers = [
Tensor.silu,
Linear(emb_channels, out_channels)
Conv2d(channels, out_channels, 3, padding=1),
]
self.emb_layers = [Tensor.silu, Linear(emb_channels, out_channels)]
self.out_layers = [
GroupNorm(32, out_channels),
Tensor.silu,
lambda x: x, # needed for weights loading code to work
Conv2d(out_channels, out_channels, 3, padding=1)
Conv2d(out_channels, out_channels, 3, padding=1),
]
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
self.skip_connection = (
Conv2d(channels, out_channels, 1)
if channels != out_channels
else lambda x: x
)
def __call__(self, x, emb):
h = x.sequential(self.in_layers)
@ -160,6 +190,7 @@ class ResBlock:
ret = self.skip_connection(x) + h
return ret
class CrossAttention:
def __init__(self, query_dim, context_dim, n_heads, d_head):
self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
@ -172,11 +203,15 @@ class CrossAttention:
def __call__(self, x, context=None):
context = x if context is None else context
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
q, k, v = [
y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1, 2)
for y in (q, k, v)
]
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
return h_.sequential(self.to_out)
class GEGLU:
def __init__(self, dim_in, dim_out):
self.proj = Linear(dim_in, dim_out * 2)
@ -186,17 +221,19 @@ class GEGLU:
x, gate = self.proj(x).chunk(2, dim=-1)
return x * gate.gelu()
class FeedForward:
def __init__(self, dim, mult=4):
self.net = [
GEGLU(dim, dim * mult),
lambda x: x, # needed for weights loading code to work
Linear(dim*mult, dim)
Linear(dim * mult, dim),
]
def __call__(self, x):
return x.sequential(self.net)
class BasicTransformerBlock:
def __init__(self, dim, context_dim, n_heads, d_head):
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
@ -212,12 +249,15 @@ class BasicTransformerBlock:
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer:
def __init__(self, channels, context_dim, n_heads, d_head):
self.norm = GroupNorm(32, channels)
assert channels == n_heads * d_head
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
self.transformer_blocks = [
BasicTransformerBlock(channels, context_dim, n_heads, d_head)
]
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
def __call__(self, x, context=None):
@ -232,6 +272,7 @@ class SpatialTransformer:
ret = self.proj_out(x) + x_in
return ret
class Downsample:
def __init__(self, channels):
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
@ -239,21 +280,28 @@ class Downsample:
def __call__(self, x):
return self.op(x)
class Upsample:
def __init__(self, channels):
self.conv = Conv2d(channels, channels, 3, padding=1)
def __call__(self, x):
bs, c, py, px = x.shape
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
x = (
x.reshape(bs, c, py, 1, px, 1)
.expand(bs, c, py, 2, px, 2)
.reshape(bs, c, py * 2, px * 2)
)
return self.conv(x)
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
args = timesteps * freqs
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
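For intuition, the sinusoidal timestep embedding above reproduced as a standalone NumPy sketch (same formula, hypothetical dim and timestep values):
import math
import numpy as np
dim, t, max_period = 8, 50, 10000
half = dim // 2
freqs = np.exp(-math.log(max_period) * np.arange(half) / half)
args = t * freqs
emb = np.concatenate([np.cos(args), np.sin(args)]).reshape(1, -1)  # shape (1, 8)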
class UNetModel:
def __init__(self):
self.time_embed = [
@ -273,12 +321,12 @@ class UNetModel:
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[Downsample(1280)],
[ResBlock(1280, 1280, 1280)],
[ResBlock(1280, 1280, 1280)]
[ResBlock(1280, 1280, 1280)],
]
self.middle_block = [
ResBlock(1280, 1280, 1280),
SpatialTransformer(1280, 768, 8, 160),
ResBlock(1280, 1280, 1280)
ResBlock(1280, 1280, 1280),
]
self.output_blocks = [
[ResBlock(2560, 1280, 1280)],
@ -286,10 +334,18 @@ class UNetModel:
[ResBlock(2560, 1280, 1280), Upsample(1280)],
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
[
ResBlock(1920, 1280, 1280),
SpatialTransformer(1280, 768, 8, 160),
Upsample(1280),
],
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
[ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
[
ResBlock(960, 1280, 640),
SpatialTransformer(640, 768, 8, 80),
Upsample(640),
],
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
@ -297,7 +353,7 @@ class UNetModel:
self.out = [
GroupNorm(32, 320),
Tensor.silu,
Conv2d(320, 4, kernel_size=3, padding=1)
Conv2d(320, 4, kernel_size=3, padding=1),
]
def __call__(self, x, timesteps=None, context=None):
@ -306,9 +362,12 @@ class UNetModel:
emb = t_emb.sequential(self.time_embed)
def run(x, bb):
if isinstance(bb, ResBlock): x = bb(x, emb)
elif isinstance(bb, SpatialTransformer): x = bb(x, context)
else: x = bb(x)
if isinstance(bb, ResBlock):
x = bb(x, emb)
elif isinstance(bb, SpatialTransformer):
x = bb(x, context)
else:
x = bb(x)
return x
saved_inputs = []
@ -326,6 +385,7 @@ class UNetModel:
x = run(x, bb)
return x.sequential(self.out)
class CLIPMLP:
def __init__(self):
self.fc1 = Linear(768, 3072)
@ -337,6 +397,7 @@ class CLIPMLP:
hidden_states = self.fc2(hidden_states)
return hidden_states
class CLIPAttention:
def __init__(self):
self.embed_dim = 768
@ -349,10 +410,22 @@ class CLIPAttention:
def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
q, k, v = (
self.q_proj(hidden_states),
self.k_proj(hidden_states),
self.v_proj(hidden_states),
)
q, k, v = [
x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
for x in (q, k, v)
]
attn_output = Tensor.scaled_dot_product_attention(
q, k, v, attn_mask=causal_attention_mask
)
return self.out_proj(
attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
)
class CLIPEncoderLayer:
def __init__(self):
@ -374,6 +447,7 @@ class CLIPEncoderLayer:
return hidden_states
class CLIPEncoder:
def __init__(self):
self.layers = [CLIPEncoderLayer() for i in range(12)]
@ -383,6 +457,7 @@ class CLIPEncoder:
hidden_states = l(hidden_states, causal_attention_mask)
return hidden_states
class CLIPTextEmbeddings:
def __init__(self):
self.token_embedding = Embedding(49408, 768)
@ -391,6 +466,7 @@ class CLIPTextEmbeddings:
def __call__(self, input_ids, position_ids):
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
class CLIPTextTransformer:
def __init__(self):
self.embeddings = CLIPTextEmbeddings()
@ -402,9 +478,15 @@ class CLIPTextTransformer:
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1))
return self.final_layer_norm(x)
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
def default_bpe():
return fetch(
"https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz",
"bpe_simple_vocab_16e6.txt.gz",
)
def get_pairs(word):
"""Return set of symbol pairs in a word.
@ -417,11 +499,13 @@ def get_pairs(word):
prev_char = char
return pairs
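A worked example of the pair extraction above, assuming the elided loop walks adjacent symbols as in the upstream CLIP tokenizer:
pairs = get_pairs(("h", "e", "l", "l", "o</w>"))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o</w>')}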
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
@ -432,7 +516,11 @@ def bytes_to_unicode():
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
@ -443,33 +531,40 @@ def bytes_to_unicode():
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
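A small usage sketch: the table maps all 256 byte values to printable unicode characters, which is how arbitrary UTF-8 text is fed to the BPE below without whitespace or control characters.
byte_encoder = bytes_to_unicode()
assert len(byte_encoder) == 256
token = "".join(byte_encoder[b] for b in "café".encode("utf-8"))  # every byte gets a visible stand-in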
class ClipTokenizer:
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
self.cache = {
"<|startoftext|>": "<|startoftext|>",
"<|endoftext|>": "<|endoftext|>",
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""",
re.IGNORECASE,
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
return token + "</w>"
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
@ -495,7 +590,7 @@ class ClipTokenizer:
if len(word) == 1:
break
pairs = get_pairs(word)
word = ' '.join(word)
word = " ".join(word)
self.cache[token] = word
return word
@ -503,19 +598,28 @@ class ClipTokenizer:
bpe_tokens = []
text = whitespace_clean(text.strip()).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
)
# Truncation, keeping two slots for start and end tokens.
if len(bpe_tokens) > 75:
bpe_tokens = bpe_tokens[:75]
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
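# Illustrative sketch, not from the original file: encode always produces the
# fixed 77-token CLIP context: <|startoftext|> (49406), at most 75 BPE tokens,
# then <|endoftext|> (49407) padding out the rest. A minimal usage check,
# assuming the ClipTokenizer class above is in scope (its constructor fetches
# the BPE vocab):
tokenizer = ClipTokenizer()
tokens = tokenizer.encode("a horse sized cat eating a bagel")
assert len(tokens) == 77
assert tokens[0] == 49406 and tokens[-1] == 49407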
class StableDiffusion:
def __init__(self):
self.alphas_cumprod = Tensor.empty(1000)
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
diffusion_model=UNetModel()
)
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
transformer=namedtuple("Transformer", ["text_model"])(
text_model=CLIPTextTransformer()
)
)
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
temperature = 1
@ -526,17 +630,30 @@ class StableDiffusion:
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
return x_prev, pred_x0
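# Illustrative sketch, not from the original file: with sigma_t = 0 this is the
# deterministic DDIM update: first estimate the clean latent x0 from the
# predicted noise e_t, then step toward the previous timestep along e_t. The
# same formula in plain numpy with made-up scalar alphas:
import numpy as np
x = np.random.randn(4, 64, 64).astype(np.float32)    # current latent
e_t = np.random.randn(4, 64, 64).astype(np.float32)  # predicted noise
a_t, a_prev = 0.8, 0.85                               # example alphas_cumprod values
pred_x0 = (x - np.sqrt(1.0 - a_t) * e_t) / np.sqrt(a_t)
x_prev = np.sqrt(a_prev) * pred_x0 + np.sqrt(1.0 - a_prev) * e_t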
def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
def get_model_output(
self,
unconditional_context,
context,
latent,
timestep,
unconditional_guidance_scale,
):
# put into diffuser
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
latents = self.model.diffusion_model(
latent.expand(2, *latent.shape[1:]),
timestep,
unconditional_context.cat(context, dim=0),
)
unconditional_latent, latent = latents[0:1], latents[1:2]
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
e_t = unconditional_latent + unconditional_guidance_scale * (
latent - unconditional_latent
)
return e_t
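# Illustrative sketch, not from the original file: the UNet is run once on a
# batch of two (unconditional + conditional) and the two noise predictions are
# blended by classifier-free guidance. The same blend in plain numpy with
# made-up values:
import numpy as np
uncond = np.array([0.10, -0.20], dtype=np.float32)  # noise predicted for the empty prompt
cond = np.array([0.30, 0.10], dtype=np.float32)     # noise predicted for the real prompt
scale = 7.5
e_t = uncond + scale * (cond - uncond)  # == (1 - scale) * uncond + scale * cond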
def decode(self, x):
@ -548,14 +665,26 @@ class StableDiffusion:
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
def __call__(
self,
unconditional_context,
context,
latent,
timestep,
alphas,
alphas_prev,
guidance,
):
e_t = self.get_model_output(
unconditional_context, context, latent, timestep, guidance
)
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
# e_t_next = get_model_output(x_prev)
# e_t_prime = (e_t + e_t_next) / 2
# x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
return x_prev.realize()
# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
@ -573,22 +702,48 @@ class StableDiffusion:
# cond_stage_model.transformer.text_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
parser.add_argument('--timing', action='store_true', help="Print timing per step")
parser.add_argument('--seed', type=int, help="Set the random latent seed")
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
parser = argparse.ArgumentParser(
description="Run Stable Diffusion",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--steps", type=int, default=5, help="Number of steps in diffusion"
)
parser.add_argument(
"--prompt",
type=str,
default="a horse sized cat eating a bagel",
help="Phrase to render",
)
parser.add_argument(
"--out",
type=str,
default=Path(tempfile.gettempdir()) / "rendered.png",
help="Output filename",
)
parser.add_argument("--noshow", action="store_true", help="Don't show the image")
parser.add_argument(
"--fp16", action="store_true", help="Cast the weights to float16"
)
parser.add_argument("--timing", action="store_true", help="Print timing per step")
parser.add_argument("--seed", type=int, help="Set the random latent seed")
parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
args = parser.parse_args()
Tensor.no_grad = True
model = StableDiffusion()
# load in weights
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
load_state_dict(
model,
torch_load(
fetch(
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
"sd-v1-4.ckpt",
)
)["state_dict"],
strict=False,
)
if args.fp16:
for l in get_state_dict(model).values():
@ -601,7 +756,9 @@ if __name__ == "__main__":
print("got CLIP context", context.shape)
prompt = Tensor([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
unconditional_context = model.cond_stage_model.transformer.text_model(
prompt
).realize()
print("got unconditional CLIP context", unconditional_context.shape)
# done with clip model
@ -613,21 +770,37 @@ if __name__ == "__main__":
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
# start with random noise
if args.seed is not None: Tensor._seed = args.seed
if args.seed is not None:
Tensor._seed = args.seed
latent = Tensor.randn(1, 4, 64, 64)
@TinyJit
def run(model, *x): return model(*x).realize()
def run(model, *x):
return model(*x).realize()
# this is diffusion
with Context(BEAM=getenv("LATEBEAM")):
for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
GlobalCounters.reset()
t.set_description("%3d %3d" % (index, timestep))
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
with Timing(
"step in ",
enabled=args.timing,
on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB",
):
tid = Tensor([index])
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
if args.timing: Device[Device.DEFAULT].synchronize()
latent = run(
model,
unconditional_context,
context,
latent,
Tensor([timestep]),
alphas[tid],
alphas_prev[tid],
Tensor([args.guidance]),
)
if args.timing:
Device[Device.DEFAULT].synchronize()
del run
# upsample latent space to image with autoencoder
@ -637,8 +810,10 @@ if __name__ == "__main__":
# save image
from PIL import Image
import numpy as np
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
print(f"saving {args.out}")
im.save(args.out)
# Open image.
if not args.noshow: im.show()
if not args.noshow:
im.show()

View File

@ -10,6 +10,7 @@ from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar
from extra.models.efficientnet import EfficientNet
class TinyConvNet:
def __init__(self, classes=10):
conv = 3
@ -24,6 +25,7 @@ class TinyConvNet:
x = x.reshape(shape=[x.shape[0], -1])
return x.dot(self.l1)
if __name__ == "__main__":
IMAGENET = getenv("IMAGENET")
classes = 1000 if IMAGENET else 10
@ -47,12 +49,14 @@ if __name__ == "__main__":
if IMAGENET:
from extra.datasets.imagenet import fetch_batch
def loader(q):
while 1:
try:
q.put(fetch_batch(BS))
except Exception:
traceback.print_exc()
q = Queue(16)
for i in range(2):
p = Process(target=loader, args=(q,))
@ -97,9 +101,17 @@ if __name__ == "__main__":
finish_time = (time.time() - st) * 1000.0
# printing
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
(loss, accuracy,
fp_time, bp_time, opt_time, finish_time,
fp_time + bp_time + opt_time + finish_time))
t.set_description(
"loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f"
% (
loss,
accuracy,
fp_time,
bp_time,
opt_time,
finish_time,
fp_time + bp_time + opt_time + finish_time,
)
)
del out, y, loss

View File

@ -19,27 +19,30 @@ class ComposeTransforms:
x = t(x)
return x
if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
classes = 10
TRANSFER = getenv('TRANSFER')
model = ResNet(getenv('NUM', 18), num_classes=classes)
TRANSFER = getenv("TRANSFER")
model = ResNet(getenv("NUM", 18), num_classes=classes)
if TRANSFER:
model.load_from_pretrained()
lr = 5e-3
transform = ComposeTransforms([
lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x],
transform = ComposeTransforms(
[
lambda x: [Image.fromarray(xx, mode="L").resize((64, 64)) for xx in x],
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
lambda x: x / 255.0,
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
])
]
)
for _ in range(5):
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
lr /= 1.2
print(f'reducing lr to {lr:.7f}')
print(f"reducing lr to {lr:.7f}")

View File

@ -7,13 +7,16 @@ from tinygrad.nn.optim import Adam
from extra.training import train, evaluate
from extra.models.transformer import Transformer
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
def make_dataset():
ds = []
for i in range(100):
for j in range(100):
s = i + j
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
ds.append(
[i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10]
)
random.shuffle(ds)
ds = np.array(ds).astype(np.float32)
ds_X = ds[:, 0:6]
@ -22,6 +25,7 @@ def make_dataset():
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
if __name__ == "__main__":
model = Transformer(10, 6, 2, 128, 4, 32)
X_train, Y_train, X_test, Y_test = make_dataset()
@ -29,14 +33,23 @@ if __name__ == "__main__":
for i in range(10):
optim = Adam(get_parameters(model), lr=lr)
train(model, X_train, Y_train, optim, 50, BS=64)
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
acc, Y_test_preds = evaluate(
model, X_test, Y_test, num_classes=10, return_predict=True
)
lr /= 1.2
print(f'reducing lr to {lr:.4f}')
print(f"reducing lr to {lr:.4f}")
if acc > 0.998:
wrong = 0
for k in range(len(Y_test_preds)):
if (Y_test_preds[k] != Y_test[k]).any():
wrong += 1
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
a, b, c, x = (
X_test[k, :2],
X_test[k, 2:4],
Y_test[k, -3:],
Y_test_preds[k, -3:],
)
print(
f"{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})"
)
print(f"Wrong predictions: {wrong}, acc = {acc:.4f}")

View File

@ -12,6 +12,7 @@ from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
# amount of context erased by model
CONTEXT = 7
def get_sample_count(samples_dir):
try:
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
@ -21,18 +22,24 @@ def get_sample_count(samples_dir):
except:
return 0
def set_sample_count(samples_dir, sc):
with open(samples_dir + "/sample_count.txt", "w") as file:
file.write(str(sc) + "\n")
if len(sys.argv) < 2:
print("python3 -m examples.vgg7 import MODELJSON MODEL")
print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
print(
" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json"
)
print(" into a safetensors file")
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
print(" *this format is used by most other commands in this program*")
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
print(
" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors"
)
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
print(" output image has 7 pixels removed on all edges")
@ -53,7 +60,9 @@ if len(sys.argv) < 2:
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
print(" won't pad or tile, so keep image sizes sane")
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
print(
" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training"
)
print(" maintains/creates samples_count.txt automatically")
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
sys.exit(1)
@ -61,9 +70,13 @@ if len(sys.argv) < 2:
cmd = sys.argv[1]
vgg7 = Vgg7()
def nansbane(p):
if numpy.isnan(numpy.min(p.numpy())):
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
raise Exception(
"A NaN in the model has been detected. This model will not be interacted with to prevent further damage."
)
def load_and_save(path, save):
if save:
@ -77,6 +90,7 @@ def load_and_save(path, save):
for v in vgg7.get_parameters():
nansbane(v)
if cmd == "import":
src = sys.argv[2]
model = sys.argv[3]
@ -158,7 +172,9 @@ elif cmd == "train":
sample_idx = 0
try:
sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
sample_idx = numpy.random.choice(
samples_count, p=sample_probs / sample_probs.sum()
)
except:
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
sample_idx = random.randint(0, samples_count - 1)
@ -204,7 +220,7 @@ elif cmd == "train":
rnum = rnum + 1
# Probability management
# there must always be a probability, no matter how slim, even if loss goes to 0
sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
sample_probs[sample_idx] = max(loss_indicator, 1.0e-10)
# if we were told to save every round, we already saved
if rounds_per_save != 1:
@ -237,8 +253,12 @@ elif cmd == "samplify":
samples_added = 0
# actual patch extraction
for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
for posy in range(
CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size
):
for posx in range(
CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size
):
# this is a viable patch location, add it
# note the ranges here:
# + there are always CONTEXT pixels *before* the point
@ -247,7 +267,12 @@ elif cmd == "samplify":
# + additionally, there are sample_size - 1 additional sample pixels
# + additionally, there are CONTEXT additional pixels
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
patch_x = a_img[
:,
:,
posy - CONTEXT : posy + CONTEXT + sample_size,
posx - CONTEXT : posx + CONTEXT + sample_size,
]
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)

View File

@ -11,6 +11,7 @@ from tinygrad.helpers import fetch
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
def image_load(path) -> numpy.ndarray:
"""
Loads an image in the shape expected by other functions in this module.
@ -29,6 +30,7 @@ def image_load(path) -> numpy.ndarray:
na = na.astype("float32") / 255.0
return na
def image_save(path, na: numpy.ndarray):
"""
Saves an image of the shape expected by other functions in this module.
@ -44,12 +46,15 @@ def image_save(path, na: numpy.ndarray):
# file
Image.fromarray(na).save(path)
# The Model
class Conv3x3Biased:
"""
A 3x3 convolution layer with some utility functions.
"""
def __init__(self, inC, outC, last=False):
# The properties must be named as "W" and "b".
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
@ -80,9 +85,12 @@ class Conv3x3Biased:
# Not outChannel,inChannel,Y,X.
# Therefore, transpose it before assignment.
# I have long since forgotten how I worked this out.
self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3))
self.W.assign(
Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3)
)
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
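# Illustrative sketch, not from the original file: the loader assumes the waifu2x
# JSON stores each kernel with its two spatial axes swapped relative to tinygrad's
# (outC, inC, H, W) layout, hence the reshape followed by .transpose(2, 3). The
# same reshuffle in numpy on a toy 2-out, 1-in layer:
import numpy as np
json_weight = np.arange(2 * 1 * 3 * 3, dtype=np.float32)   # flat values as found in the JSON
w = json_weight.reshape(2, 1, 3, 3).transpose(0, 1, 3, 2)  # swap H and W
assert w.shape == (2, 1, 3, 3)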
class Vgg7:
"""
The 'vgg7' waifu2x network.
@ -115,14 +123,31 @@ class Vgg7:
return x
def get_parameters(self) -> list:
return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
return (
self.conv1.get_parameters()
+ self.conv2.get_parameters()
+ self.conv3.get_parameters()
+ self.conv4.get_parameters()
+ self.conv5.get_parameters()
+ self.conv6.get_parameters()
+ self.conv7.get_parameters()
)
def load_from_pretrained(self, intent="art", subtype="scale2.0x"):
"""
Downloads a nagadomi/waifu2x JSON weight file and loads it.
"""
import json
data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes())
data = json.loads(
fetch(
"https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/"
+ intent
+ "/"
+ subtype
+ "_model.json"
).read_bytes()
)
self.load_waifu2x_json(data)
def load_waifu2x_json(self, data: list):
@ -157,7 +182,9 @@ class Vgg7:
# Padding next. Note that this padding is done on the whole image.
# Padding the tiles would lose critical context, cause seams, etc.
image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
image = numpy.pad(
image, [[0, 0], [0, 0], [context, context], [context, context]], mode="edge"
)
# Now for tiling.
# The output tile size is the usable output from an input tile (tile_size).
@ -187,7 +214,8 @@ class Vgg7:
tile_t = Tensor(tile)
tile_fwd_t = self.forward(tile_t)
# Replace tile.
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
image_out[
:, :, out_y : out_y + out_h, out_x : out_x + out_w
] = tile_fwd_t.numpy()
return image_out

View File

@ -4,6 +4,7 @@ from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch
from extra.models.vit import ViT
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
@ -22,7 +23,11 @@ else:
m.load_from_pretrained()
# category labels
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
lbls = ast.literal_eval(
fetch(
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
).read_text()
)
# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
@ -30,7 +35,9 @@ url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-1
# junk
img = Image.open(fetch(url))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = img.resize(
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
)
img = np.array(img)
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
img = img[y0 : y0 + 224, x0 : x0 + 224]

File diff suppressed because it is too large

View File

@ -1,7 +1,13 @@
import os
from extra.export_model import compile_net, jit_model
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.nn.state import (
get_state_dict,
safe_save,
safe_load_metadata,
torch_load,
load_state_dict,
)
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.helpers import fetch
@ -10,10 +16,13 @@ from pathlib import Path
import argparse
import numpy as np
def convert_f32_to_f16(input_file, output_file):
with open(input_file, 'rb') as f:
with open(input_file, "rb") as f:
metadata_length_bytes = f.read(8)
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
metadata_length = int.from_bytes(
metadata_length_bytes, byteorder="little", signed=False
)
metadata_json_bytes = f.read(metadata_length)
float32_values = np.fromfile(f, dtype=np.float32)
@ -22,12 +31,13 @@ def convert_f32_to_f16(input_file, output_file):
front_float16_values = float32_values[:num_elements].astype(np.float16)
rest_float32_values = float32_values[num_elements:]
with open(output_file, 'wb') as f:
with open(output_file, "wb") as f:
f.write(metadata_length_bytes)
f.write(metadata_json_bytes)
front_float16_values.tofile(f)
rest_float32_values.tofile(f)
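# Illustrative sketch, not from the original file: a safetensors file is an 8-byte
# little-endian header length, then that many bytes of JSON metadata (tensor name ->
# dtype/shape/data_offsets), then the raw tensor bytes. convert_f32_to_f16 above
# relies on exactly this layout. A minimal header reader under that assumption:
import json

def read_safetensors_header(path):
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), byteorder="little", signed=False)
        return json.loads(f.read(header_len))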
def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
@ -35,7 +45,7 @@ def split_safetensor(fn):
for k in metadata:
# safetensor is in fp16, except for the text model
if (metadata[k]["data_offsets"][0] < text_model_offset):
if metadata[k]["data_offsets"][0] < text_model_offset:
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2)
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2)
@ -43,35 +53,43 @@ def split_safetensor(fn):
part_end_offsets = []
for k in metadata:
offset = metadata[k]['data_offsets'][0]
offset = metadata[k]["data_offsets"][0]
if offset == text_model_offset:
break
part_offset = offset - last_offset
if (part_offset >= chunk_size):
if part_offset >= chunk_size:
part_end_offsets.append(8 + json_len + offset)
last_offset = offset
text_model_start = int(text_model_offset / 2)
net_bytes = bytes(open(fn, 'rb').read())
net_bytes = bytes(open(fn, "rb").read())
part_end_offsets.append(text_model_start + 8 + json_len)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
with open(f'./net_part{i}.safetensors', "wb+") as f:
with open(f"./net_part{i}.safetensors", "wb+") as f:
f.write(net_bytes[cur_pos:end_pos])
cur_pos = end_pos
with open(f'./net_textmodel.safetensors', "wb+") as f:
with open(f"./net_textmodel.safetensors", "wb+") as f:
f.write(net_bytes[text_model_start + 8 + json_len :])
return part_end_offsets
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
parser = argparse.ArgumentParser(
description="Run Stable Diffusion",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--remoteweights",
action="store_true",
help="Use safetensors from Huggingface, or from local",
)
args = parser.parse_args()
Device.DEFAULT = "WEBGPU"
@ -79,7 +97,16 @@ if __name__ == "__main__":
model = StableDiffusion()
# load in weights
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
load_state_dict(
model,
torch_load(
fetch(
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
"sd-v1-4.ckpt",
)
)["state_dict"],
strict=False,
)
class Step(NamedTuple):
name: str = ""
@ -87,9 +114,25 @@ if __name__ == "__main__":
forward: Any = None
sub_steps = [
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
Step(
name="textModel",
input=[Tensor.randn(1, 77)],
forward=model.cond_stage_model.transformer.text_model,
),
Step(
name="diffusor",
input=[
Tensor.randn(1, 77, 768),
Tensor.randn(1, 77, 768),
Tensor.randn(1, 4, 64, 64),
Tensor.rand(1),
Tensor.randn(1),
Tensor.randn(1),
Tensor.randn(1),
],
forward=model,
),
Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode),
]
prg = ""
@ -99,12 +142,47 @@ if __name__ == "__main__":
functions, statements, bufs, _ = compile_net(run, special_names)
state = get_state_dict(model)
weights = {id(x.lazydata.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
kernel_code = "\n\n".join(
[
f"const {key} = `{code.replace(key, 'main')}`;"
for key, code in functions.items()
]
)
kernel_names = ", ".join([name for (name, _, _, _) in statements])
kernel_calls = "\n ".join(
[
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
for i, (_name, args, global_size, _local_size) in enumerate(statements)
]
)
bufs = "\n ".join(
[
f"const {name} = "
+ (
f"createEmptyBuf(device, {size});"
if _key not in weights
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))"
)
+ ";"
for name, (size, dtype, _key) in bufs.items()
]
)
gpu_write_bufs = "\n ".join(
[
f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
for i, (_, value) in enumerate(special_names.items())
if "output" not in value
]
)
input_writer = "\n ".join(
[
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
+ f"data{i});"
+ f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);"
for i, (_, value) in enumerate(special_names.items())
if value != "output0"
]
)
return f"""\n var {step.name} = function() {{
{kernel_code}
@ -143,7 +221,7 @@ if __name__ == "__main__":
"""
for step in sub_steps:
print(f'Executing step={step.name}')
print(f"Executing step={step.name}")
prg += compile_step(model, step)
if step.name == "diffusor":
@ -151,7 +229,9 @@ if __name__ == "__main__":
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
else:
state = get_state_dict(model)
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
safe_save(
state, os.path.join(os.path.dirname(__file__), "net.safetensors")
)
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
split_safetensor("./net_conv.safetensors")
os.remove("net.safetensors")

View File

@ -15,8 +15,15 @@ from tinygrad.tensor import Tensor
import itertools
import librosa
class MultiHeadAttention:
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
def __init__(
self,
n_state,
n_head,
kv_caching: Literal["cross", "self"] = None,
max_self_attn_cache_len=None,
):
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
@ -26,11 +33,17 @@ class MultiHeadAttention:
self.kv_caching = kv_caching
self.max_self_attn_cache_len = max_self_attn_cache_len
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
if self.kv_caching == 'cross':
def __call__(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
len: Union[Variable, int] = None,
):
if self.kv_caching == "cross":
if xa is not None:
k, v = self.key(xa), self.value(xa)
if not hasattr(self, 'cache_k'):
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = k, v
else:
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
@ -40,50 +53,84 @@ class MultiHeadAttention:
k, v = self.cache_k, self.cache_v
else:
k, v = self.key(x), self.value(x)
if self.kv_caching == 'self':
if not hasattr(self, 'cache_k'):
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
if self.kv_caching == "self":
if not hasattr(self, "cache_k"):
self.cache_k = Tensor.zeros(
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
)
self.cache_v = Tensor.zeros(
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
)
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
padding = self.max_self_attn_cache_len - len - x.shape[1]
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
self.cache_k.assign(
k.pad((None, (0, padding), None)).contiguous()
).realize()
self.cache_v.assign(
v.pad((None, (0, padding), None)).contiguous()
).realize()
q = self.query(x)
n_ctx = q.shape[1]
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
head_dim = q.shape[-1] // self.n_head
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
attn = Tensor.scaled_dot_product_attention(
q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
)
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
return self.out(wv)
class ResidualAttentionBlock:
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
def __init__(
self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None
):
self.attn = MultiHeadAttention(
n_state,
n_head,
kv_caching="self" if is_decoder_block else None,
max_self_attn_cache_len=max_self_attn_cache_len,
)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
self.cross_attn = (
MultiHeadAttention(n_state, n_head, kv_caching="cross")
if is_decoder_block
else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
self.mlp = [
nn.Linear(n_state, n_state * 4),
Tensor.gelu,
nn.Linear(n_state * 4, n_state),
]
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
x = x + self.mlp_ln(x).sequential(self.mlp)
return x.realize()
class AudioEncoder:
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
def __init__(
self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_
):
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
self.conv2 = nn.Conv1d(
n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1
)
self.blocks = [
ResidualAttentionBlock(n_audio_state, n_audio_head)
for _ in range(n_audio_layer)
]
self.ln_post = nn.LayerNorm(n_audio_state)
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
self.encode = TinyJit(self.__call__)
@ -97,14 +144,27 @@ class AudioEncoder:
x = self.ln_post(x)
return x.realize()
class TextDecoder:
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
def __init__(
self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_
):
self.max_tokens_to_sample = n_text_ctx // 2
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
self.max_self_attn_cache_len = (
self.max_tokens_to_sample * 2 + 5
) # roughly prompt + start toks + max_tokens_to_sample
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
self.blocks = [
ResidualAttentionBlock(
n_text_state,
n_text_head,
is_decoder_block=True,
max_self_attn_cache_len=self.max_self_attn_cache_len,
)
for _ in range(n_text_layer)
]
self.ln = nn.LayerNorm(n_text_state)
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
@ -117,18 +177,23 @@ class TextDecoder:
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
if pos == 0:
for block in (self.blocks if streaming else self.blocks_start_tok):
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
for block in self.blocks if streaming else self.blocks_start_tok:
x = block(
x, xa=encoded_audio, mask=self.mask, len=0
) # pass xa for cross attn kv caching
return self.output_tok(x) if streaming else self.start_output_tok(x)
else:
for block in self.blocks_after_start_tok:
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
len_v = Variable(
"self_attn_cache_len", 1, self.max_self_attn_cache_len
).bind(pos)
x = block(x, mask=self.mask, len=len_v)
return self.after_start_output_tok(x)
def output_tok(self, x):
return (self.ln(x) @ self.token_embedding.weight.T).realize()
class Whisper:
def __init__(self, dims, batch_size=1):
self.encoder = AudioEncoder(**dims)
@ -145,7 +210,10 @@ HOP_LENGTH = 160
N_MELS = 80
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
def prep_audio(
waveforms: List[np.ndarray], batch_size: int, truncate=False
) -> np.ndarray:
"""
:param waveforms: A list of possibly variable length 16000Hz audio samples
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
@ -153,24 +221,30 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
:return: mel spectrogram of the given waveforms
"""
def pad_or_trim(arr, target_len):
curr_len = len(arr)
if curr_len == target_len:
return arr
elif curr_len < target_len:
return np.pad(arr, (0, target_len - curr_len), 'constant')
return np.pad(arr, (0, target_len - curr_len), "constant")
else:
return arr[:target_len]
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
max_len += SAMPLES_PER_SEGMENT - r
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
assert waveforms.shape[0] <= batch_size
if waveforms.shape[0] < batch_size:
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
waveforms = np.pad(
waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))
)
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
stft = librosa.stft(
waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
)
magnitudes = np.absolute(stft[..., :-1]) ** 2
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
@ -180,22 +254,118 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
return log_spec
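# Illustrative sketch, not from the original file: with truncate=True every waveform
# is padded or trimmed to a single 30s segment, so the log-mel spectrogram comes back
# with a fixed shape of (batch_size, N_MELS, FRAMES_PER_SEGMENT). A quick shape check
# on two seconds of silence, assuming prep_audio and the constants above are in scope:
import numpy as np
silence = np.zeros(2 * RATE, dtype=np.float32)
log_spec = prep_audio([silence], batch_size=1, truncate=True)
assert log_spec.shape == (1, N_MELS, FRAMES_PER_SEGMENT)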
LANGUAGES = {
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
}
def get_encoding(encoding_name):
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
with fetch(
f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
).open() as f:
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in f if line)
}
n_vocab = len(ranks)
specials = [
"<|endoftext|>",
@ -212,12 +382,15 @@ def get_encoding(encoding_name):
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
n_vocab += len(specials)
import tiktoken
return tiktoken.Encoding(
name=encoding_name,
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens)
special_tokens=special_tokens,
)
MODEL_URLS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@ -232,23 +405,28 @@ MODEL_URLS = {
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
def init_whisper(model_name="tiny.en", batch_size=1):
assert MODEL_URLS[model_name] is not None
filename = fetch(MODEL_URLS[model_name])
state = torch_load(filename)
model = Whisper(state['dims'], batch_size)
load_state_dict(model, state['model_state_dict'], strict=False)
model = Whisper(state["dims"], batch_size)
load_state_dict(model, state["model_state_dict"], strict=False)
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
return model, enc
def load_file_waveform(filename):
waveform, _ = librosa.load(filename, sr=RATE)
return waveform
def transcribe_file(model, enc, filename):
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
def transcribe_waveform(model, enc, waveforms, truncate=False):
"""
Expects an array of shape (N,S) where N is the number of waveforms to transcribe in parallel and S is the number of 16000Hz samples
@ -260,12 +438,18 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
raise Exception("Multi-segment transcription not supported with batch audio input")
raise Exception(
"Multi-segment transcription not supported with batch audio input"
)
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
if model.is_multilingual:
# TODO detect language
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
language_token = (
enc._special_tokens["<|startoftranscript|>"]
+ 1
+ tuple(LANGUAGES.keys()).index("en")
)
start_tokens.append(language_token)
start_tokens.append(enc._special_tokens["<|transcribe|>"])
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
@ -274,52 +458,83 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
encoded_audio = model.encoder.encode(
Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
)
pos = 0
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
if curr_frame > 0:
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt = np.concatenate((
prompt = np.concatenate(
(
[enc._special_tokens["<|startofprev|>"]],
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
start_tokens))
start_tokens,
)
)
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
transcription_start_index = len(curr_segment_tokens[0])
for i in range(model.decoder.max_tokens_to_sample):
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
out = model.decoder(
Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
pos,
encoded_audio,
streaming=curr_frame > 0,
)
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
curr_segment_tokens = np.concatenate(
(curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
)
pos = curr_segment_tokens.shape[-1] - 1
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
if DEBUG >= 1:
print(
i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))
)
if (curr_segment_tokens[:, -1] == eot).all():
break
for i, t in enumerate(curr_segment_tokens):
eot_index = np.where(t == eot)[0]
eot_index = None if len(eot_index) == 0 else eot_index[0]
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
transcription_tokens[i] = np.concatenate(
(transcription_tokens[i], t[transcription_start_index:eot_index])
)
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
transcriptions = list(
map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)
)
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
CHUNK = 1600
RECORD_SECONDS = 10
def listener(q):
import pyaudio
p = pyaudio.PyAudio()
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=RATE,
input=True,
frames_per_buffer=CHUNK,
)
print("listening")
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
data = stream.read(CHUNK)
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
q.put(waveform)
print("done listening")
if __name__ == "__main__":
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
model, enc = init_whisper(
"small.en" if getenv("SMALL") else "tiny.en", batch_size=1
)
if len(sys.argv) > 1:
print(transcribe_file(model, enc, sys.argv[1]))
@ -330,20 +545,29 @@ if __name__ == "__main__":
p.daemon = True
p.start()
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
lst = [
enc._special_tokens["<|startoftranscript|>"],
enc._special_tokens["<|notimestamps|>"],
]
total = None
did_read = False
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
while not q.empty() or total is None:
waveform = q.get()
if total is None: total = waveform
else: total = np.concatenate([total, waveform])
if total is None:
total = waveform
else:
total = np.concatenate([total, waveform])
did_read = True
if did_read:
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
log_spec = prep_audio(
total.reshape(1, -1), model.batch_size, truncate=True
)
encoded_audio = model.encoder.encode(Tensor(log_spec))
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
out = model.decoder(
Tensor([lst]), 0, encoded_audio, streaming=True
).realize()
idx = int(out[0, -1].argmax().numpy().item())
lst.append(idx)
dec = enc.decode(lst)

View File

@ -10,11 +10,14 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, Conv2d
from tinygrad.helpers import fetch
def show_labels(prediction, confidence=0.5, num_classes=80):
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes()
coco_labels = coco_labels.decode('utf-8').split('\n')
coco_labels = fetch(
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
).read_bytes()
coco_labels = coco_labels.decode("utf-8").split("\n")
prediction = prediction.detach().numpy()
conf_mask = (prediction[:,:,4] > confidence)
conf_mask = prediction[:, :, 4] > confidence
prediction *= np.expand_dims(conf_mask, 2)
labels = []
# Iterate over batches
@ -30,16 +33,22 @@ def show_labels(prediction, confidence=0.5, num_classes=80):
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
for index, coco_class in enumerate(classes):
label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100
label, probability = (
coco_labels[int(coco_class)],
image_pred_[indexes[index]][4] * 100,
)
print(f"Detected {label} {probability:.2f}")
labels.append(label)
return labels
def add_boxes(img, prediction):
if isinstance(prediction, int): # no predictions
return img
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
coco_labels = coco_labels.decode('utf-8').split('\n')
coco_labels = fetch(
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
).read_bytes()
coco_labels = coco_labels.decode("utf-8").split("\n")
height, width = img.shape[0:2]
scale_factor = 608 / width
prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2
@ -55,9 +64,18 @@ def add_boxes(img, prediction):
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
img = cv2.putText(
img,
label,
(corner1[0], corner1[1] + t_size[1] + 4),
cv2.FONT_HERSHEY_PLAIN,
1,
[225, 255, 255],
1,
)
return img
def bbox_iou(box1, box2):
"""
Returns the IoU of two bounding boxes
@ -74,24 +92,27 @@ def bbox_iou(box1, box2):
inter_rect_x2 = np.minimum(b1_x2, b2_x2)
inter_rect_y2 = np.minimum(b1_y2, b2_y2)
# Intersection area
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999)
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(
inter_rect_y2 - inter_rect_y1 + 1, 0, 99999
)
# Union Area
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
iou = inter_area / (b1_area + b2_area - inter_area)
return iou
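# Illustrative sketch, not from the original file: a worked IoU example. Two 10x10
# boxes, one shifted 5 px along x, share a 5x10 strip, so with the inclusive (+1)
# pixel convention used above:
inter_area = 5 * 10
union_area = 10 * 10 + 10 * 10 - inter_area
iou_example = inter_area / union_area  # = 1/3: half-overlapping boxes score well below 0.5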
def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
prediction = prediction.detach().numpy()
conf_mask = (prediction[:,:,4] > confidence)
conf_mask = prediction[:, :, 4] > confidence
conf_mask = np.expand_dims(conf_mask, 2)
prediction = prediction * conf_mask
# Non max suppression
box_corner = prediction
box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
write = False
# Process img
@ -121,7 +142,10 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
for i in range(image_pred_class.shape[0]):
# Get the IOUs of all boxes that come after the one we are looking at in the loop
try:
ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:])
ious = bbox_iou(
np.expand_dims(image_pred_class[i], axis=0),
image_pred_class[i + 1 :],
)
except:
break
# Zero out all the detections that have IoU > threshold
@ -139,6 +163,7 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
output = np.concatenate((output, out))
return output
def infer(model, img):
img = np.array(Image.fromarray(img).resize((608, 608)))
img = img[:, :, ::-1].transpose((2, 0, 1))
@ -149,9 +174,9 @@ def infer(model, img):
def parse_cfg(cfg):
# Return a list of blocks
lines = cfg.decode("utf-8").split('\n')
lines = cfg.decode("utf-8").split("\n")
lines = [x for x in lines if len(x) > 0]
lines = [x for x in lines if x[0] != '#']
lines = [x for x in lines if x[0] != "#"]
lines = [x.rstrip().lstrip() for x in lines]
block, blocks = {}, []
for line in lines:
@ -166,6 +191,7 @@ def parse_cfg(cfg):
blocks.append(block)
return blocks
# TODO: Speed up this function, avoid copying stuff from GPU to CPU
def predict_transform(prediction, inp_dim, anchors, num_classes):
batch_size = prediction.shape[0]
@ -173,9 +199,13 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
grid_size = inp_dim // stride
bbox_attrs = 5 + num_classes
num_anchors = len(anchors)
prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size))
prediction = prediction.reshape(
shape=(batch_size, bbox_attrs * num_anchors, grid_size * grid_size)
)
prediction = prediction.transpose(1, 2)
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
prediction = prediction.reshape(
shape=(batch_size, grid_size * grid_size * num_anchors, bbox_attrs)
)
prediction_cpu = prediction.numpy()
for i in (0, 1, 4):
prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i]))
@ -193,7 +223,9 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
anchors = np.expand_dims(anchors, 0)
prediction_cpu[:, :, :2] += x_y_offset
prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors
prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes]))
prediction_cpu[:, :, 5 : 5 + num_classes] = 1 / (
1 + np.exp(-prediction_cpu[:, :, 5 : 5 + num_classes])
)
prediction_cpu[:, :, :4] *= stride
return Tensor(prediction_cpu)
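# Illustrative sketch, not from the original file: per grid cell the raw outputs
# (tx, ty, tw, th) decode to a box roughly as
#   bx = (sigmoid(tx) + cx) * stride,   bw = anchor_w * exp(tw)
# which is what the offset/exp/stride lines above implement. One cell by hand,
# with made-up raw values and a typical yolov3 anchor:
import numpy as np
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
tx, ty, tw, th = 0.2, -0.1, 0.5, 0.3
cx, cy = 3, 4                 # grid cell indices
anchor_w, anchor_h = 116, 90  # anchor in input pixels
stride = 32                   # 608 input / 19x19 grid
bx = (sigmoid(tx) + cx) * stride
by = (sigmoid(ty) + cy) * stride
bw = anchor_w * np.exp(tw)
bh = anchor_h * np.exp(th)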
@ -222,18 +254,33 @@ class Darknet:
filters = int(x["filters"])
padding = int(x["pad"])
pad = (int(x["size"]) - 1) // 2 if padding else 0
module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias))
module.append(
Conv2d(
prev_filters,
filters,
int(x["size"]),
int(x["stride"]),
pad,
bias=bias,
)
)
# BatchNorm2d
if batch_normalize:
module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
module.append(
BatchNorm2d(filters, eps=1e-05, track_running_stats=True)
)
# LeakyReLU activation
if activation == "leaky":
module.append(lambda x: x.leakyrelu(0.1))
elif module_type == "maxpool":
size, stride = int(x["size"]), int(x["stride"])
module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))
module.append(
lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)
)
elif module_type == "upsample":
module.append(lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1)))
module.append(
lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1))
)
elif module_type == "route":
x["layers"] = x["layers"].split(",")
# Start of route
@ -243,11 +290,15 @@ class Darknet:
end = int(x["layers"][1])
except:
end = 0
if start > 0: start -= index
if end > 0: end -= index
if start > 0:
start -= index
if end > 0:
end -= index
module.append(lambda x: x)
if end < 0:
filters = output_filters[index + start] + output_filters[index + end]
filters = (
output_filters[index + start] + output_filters[index + end]
)
else:
filters = output_filters[index + start]
# Shortcut corresponds to skip connection
@ -256,7 +307,9 @@ class Darknet:
elif module_type == "yolo":
mask = list(map(int, x["mask"].split(",")))
anchors = [int(a) for a in x["anchors"].split(",")]
anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
anchors = [
(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)
]
module.append([anchors[i] for i in mask])
# Append to module_list
module_list.append(module)
@ -308,8 +361,12 @@ class Darknet:
# Cast the loaded weights into dims of model weights
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape))
bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape))
bn_running_mean = bn_running_mean.reshape(
shape=tuple(bn.running_mean.shape)
)
bn_running_var = bn_running_var.reshape(
shape=tuple(bn.running_var.shape)
)
# Copy data
bn.bias = bn_biases
bn.weight = bn_weights
@ -337,7 +394,7 @@ class Darknet:
outputs = {} # Cached outputs for route layer
detections, write = None, False
for i, module in enumerate(modules):
module_type = (module["type"])
module_type = module["type"]
if module_type == "convolutional" or module_type == "upsample":
for layer in self.module_list[i]:
x = layer(x)
@ -349,7 +406,8 @@ class Darknet:
if len(layers) == 1:
x = outputs[i + (layers[0])]
else:
if (layers[1]) > 0: layers[1] = layers[1] - i
if (layers[1]) > 0:
layers[1] = layers[1] - i
map1 = outputs[i + layers[0]]
map2 = outputs[i + layers[1]]
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
@ -364,19 +422,26 @@ class Darknet:
if not write:
detections, write = x, True
else:
detections = Tensor(np.concatenate((detections.numpy(), x.numpy()), axis=1))
detections = Tensor(
np.concatenate((detections.numpy(), x.numpy()), axis=1)
)
outputs[i] = x
return detections
if __name__ == "__main__":
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg'))
model = Darknet(
fetch(
"https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg"
)
)
print("Loading weights file (237MB). This might take a while…")
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
model.load_weights("https://pjreddie.com/media/files/yolov3.weights")
if len(sys.argv) > 1:
url = sys.argv[1]
else:
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
if url == 'webcam':
if url == "webcam":
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
while 1:
@ -386,21 +451,21 @@ if __name__ == "__main__":
img = Image.fromarray(frame[:, :, [2, 1, 0]])
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
cv2.imshow('yolo', boxes)
if cv2.waitKey(1) & 0xFF == ord('q'):
cv2.imshow("yolo", boxes)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
cap.release()
cv2.destroyAllWindows()
elif url.startswith('http'):
elif url.startswith("http"):
img_stream = io.BytesIO(fetch(url))
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
else:
img = cv2.imread(url)
st = time.time()
print('running inference…')
print("running inference…")
prediction = infer(model, img)
print(f'did inference in {(time.time() - st):2f}s')
print(f"did inference in {(time.time() - st):2f}s")
show_labels(prediction)
prediction = process_results(prediction)
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
cv2.imwrite('boxes.jpg', boxes)
cv2.imwrite("boxes.jpg", boxes)

View File

@ -12,7 +12,10 @@ if not Path("yolov8n-seg.onnx").is_file():
model.export(format="onnx", imgsz=[480, 640])
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
# TODO: move get example inputs to onnx
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_shapes = {
inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
for inp in onnx_model.graph.input
}
print(input_shapes)
run_onnx = get_run_onnx(onnx_model)
run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)

View File

@ -12,8 +12,11 @@ from tinygrad.nn.state import safe_load, load_state_dict
# Model architecture from https://github.com/ultralytics/ultralytics/issues/189
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinanet and this)
# Pre-processing image functions.
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinanet and this)
def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
def compute_transform(
image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
):
shape = image.shape[:2] # current shape [height, width]
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
@ -24,25 +27,39 @@ def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False,
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
dw /= 2
dh /= 2
image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image
image = (
cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
if shape[::-1] != new_unpad
else image
)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
image = cv2.copyMakeBorder(
image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
)
return image
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
same_shapes = all(x.shape == im[0].shape for x in im)
auto = same_shapes and model_pt
im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im])
im = Tensor(
[
compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride)
for x in im
]
)
im = Tensor.stack(im) if im.shape[0] > 1 else im
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im /= 255 # 0 - 255 to 0.0 - 1.0
return im
# Post-processing functions
def box_area(box):
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
def box_iou(box1, box2):
lt = np.maximum(box1[:, None, :2], box2[:, :2])
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
@ -53,6 +70,7 @@ def box_iou(box1, box2):
iou = inter / (area1 + area2 - inter)
return iou
def compute_nms(boxes, scores, iou_threshold):
order, keep = scores.argsort()[::-1], []
while order.size > 0:
@ -65,7 +83,16 @@ def compute_nms(boxes, scores, iou_threshold):
order = order[inds + 1]
return np.array(keep)
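# Hypothetical usage of compute_nms above: two near-duplicate boxes and one
# distinct box. The duplicate pair has IoU 0.81, so with iou_threshold=0.5 the
# lower-scoring one is suppressed and the original indices [0, 2] survive.
# example_boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [50, 50, 60, 60]], dtype=np.float32)
# example_scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
# compute_nms(example_boxes, example_scores, iou_threshold=0.5) -> array([0, 2])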
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
agnostic=False,
max_det=300,
nc=0,
max_wh=7680,
):
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres
@ -74,12 +101,16 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
for xi, x in enumerate(prediction):
x = x.swapaxes(0, -1)[xc[xi]]
if not x.shape[0]: continue
if not x.shape[0]:
continue
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(
cls, axis=1, keepdims=True
)
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
x = x[conf.ravel() > conf_thres]
if not x.shape[0]: continue
if not x.shape[0]:
continue
x = x[np.argsort(-x[:, 4])]
c = x[:, 5:6] * (0 if agnostic else max_wh)
boxes, scores = x[:, :4] + c, x[:, 4]
@ -87,12 +118,15 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
output[xi] = x[i]
return output
def postprocess(preds, img, orig_imgs):
print('copying to CPU now for post processing')
print("copying to CPU now for post processing")
# if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
# TODO: make non_max_suppression in tinygrad - to make this faster
preds = preds.numpy() if isinstance(preds, Tensor) else preds
preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
preds = non_max_suppression(
prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300
)
all_preds = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
@ -101,8 +135,16 @@ def postprocess(preds, img, orig_imgs):
all_preds.append(pred)
return all_preds
def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
def draw_bounding_boxes_and_save(
orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5
):
color_dict = {
label: tuple(
(((i + 1) * 50) % 256, ((i + 1) * 100) % 256, ((i + 1) * 150) % 256)
)
for i, label in enumerate(class_labels)
}
font = cv2.FONT_HERSHEY_SIMPLEX
def is_bright_color(color):
@ -110,9 +152,15 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
brightness = (r * 299 + g * 587 + b * 114) / 1000
return brightness > 127
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(
zip(orig_img_paths, output_img_paths, all_predictions)
):
predictions = np.array(predictions)
orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
orig_img = (
cv2.imread(orig_img_path)
if not isinstance(orig_img_path, np.ndarray)
else cv2.imdecode(orig_img_path, 1)
)
height, width, _ = orig_img.shape
box_thickness = int((height + width) / 400)
font_scale = (height + width) / 2500
@ -129,10 +177,29 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
label = f"{class_labels[class_id]} {conf:.2f}"
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
label_y, bg_y = (
(y1 - 4, y1 - text_size[1] - 4)
if y1 - text_size[1] - 4 > 0
else (y1 + text_size[1], y1)
)
cv2.rectangle(
orig_img,
(x1, bg_y),
(x1 + text_size[0], bg_y + text_size[1]),
color,
-1,
)
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
cv2.putText(
orig_img,
label,
(x1, label_y),
font,
font_scale,
font_color,
1,
cv2.LINE_AA,
)
for class_id, pred_list in grouped_preds.items():
pred_list = np.array(pred_list)
@ -155,7 +222,8 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
print(f"- {obj}: {count}")
cv2.imwrite(output_img_path, orig_img)
print(f'saved detections at {output_img_path}')
print(f"saved detections at {output_img_path}")
# utility functions for forward pass.
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
@ -168,6 +236,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
return c_xy.cat(wh, dim=1)
return x1y1.cat(x2y2, dim=1)
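# A worked example (assumed values) of the dist2bbox conversion above: an
# anchor point at (10, 10) with distances l=2, t=3, r=4, b=5.
_anchor = np.array([10.0, 10.0])
_lt, _rb = np.array([2.0, 3.0]), np.array([4.0, 5.0])
_x1y1 = _anchor - _lt                          # (8, 7)
_x2y2 = _anchor + _rb                          # (14, 15)
_c_xy, _wh = (_x1y1 + _x2y2) / 2, _x2y2 - _x1y1  # xywh mode: center (11, 11), size (6, 8)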
def make_anchors(feats, strides, grid_cell_offset=0.5):
anchor_points, stride_tensor = [], []
assert feats is not None
@ -183,25 +252,39 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
stride_tensor.append(Tensor.full((h * w), stride))
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
stride_tensor = (
stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
)
return anchor_points, stride_tensor
# this function is from the original implementation
def autopad(k, p=None, d=1): # kernel, padding, dilation
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
k = (
d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
) # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
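# Quick sanity checks for autopad above: a kernel keeps spatial size with
# k // 2 padding, and dilation first grows the effective kernel size.
assert autopad(3) == 1            # 3 // 2
assert autopad(5) == 2            # 5 // 2
assert autopad(3, d=2) == 2       # effective kernel 2*(3-1)+1 = 5, then 5 // 2
assert autopad((3, 5)) == [1, 2]  # per-dimension padding for a tuple kernel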
def clip_boxes(boxes, shape):
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
return boxes
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
gain = (
ratio_pad
if ratio_pad
else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
)
pad = (
(img1_shape[1] - img0_shape[1] * gain) / 2,
(img1_shape[0] - img0_shape[0] * gain) / 2,
)
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
boxes_np[..., [0, 2]] -= pad[0]
boxes_np[..., [1, 3]] -= pad[1]
@ -209,6 +292,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
boxes_np = clip_boxes(boxes_np, img0_shape)
return boxes_np
def xywh2xyxy(x):
xy = x[..., :2] # center x, y
wh = x[..., 2:4] # width, height
@ -217,8 +301,16 @@ def xywh2xyxy(x):
result = np.concatenate((xy1, xy2), axis=-1)
return Tensor(result) if isinstance(x, Tensor) else result
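# A worked example (assumed values) of xywh2xyxy above: a box centered at
# (10, 10) with width 4 and height 6 becomes corners (8, 7, 12, 13).
_example_xywh = np.array([[10.0, 10.0, 4.0, 6.0]])
assert (xywh2xyxy(_example_xywh) == np.array([[8.0, 7.0, 12.0, 13.0]])).all()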
def get_variant_multiples(variant):
return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
return {
"n": (0.33, 0.25, 2.0),
"s": (0.33, 0.50, 2.0),
"m": (0.67, 0.75, 1.5),
"l": (1.0, 1.0, 1.0),
"x": (1, 1.25, 1.0),
}.get(variant, None)
def label_predictions(all_predictions):
class_index_count = defaultdict(int)
@ -230,6 +322,7 @@ def label_predictions(all_predictions):
return dict(class_index_count)
# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
class Upsample:
def __init__(self, scale_factor: int, mode: str = "nearest") -> None:
@ -240,41 +333,86 @@ class Upsample:
def __call__(self, x: Tensor) -> Tensor:
assert len(x.shape) > 2 and len(x.shape) <= 5
(b, c), _lens = x.shape[:2], len(x.shape[2:])
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens)
return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(
*[1, 1, 1] + [self.scale_factor] * _lens
)
return (
tmp.reshape(list(x.shape) + [self.scale_factor] * _lens)
.permute(
[0, 1]
+ list(
chain.from_iterable([[y + 2, y + 2 + _lens] for y in range(_lens)])
)
)
.reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
)
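# The reshape/broadcast/permute dance above is nearest-neighbour upsampling.
# A sketch of the same effect in plain numpy for a 1x1x2x2 input and
# scale_factor=2 (illustrative only, not how the class above is implemented):
_small = np.arange(4).reshape(1, 1, 2, 2)
_nearest = _small.repeat(2, axis=2).repeat(2, axis=3)  # every pixel becomes a 2x2 block
assert _nearest.shape == (1, 1, 4, 4)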
class Conv_Block:
def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None):
self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation)
def __init__(
self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None
):
self.conv = Conv2d(
c1,
c2,
kernel_size,
stride,
padding=autopad(kernel_size, padding, dilation),
bias=False,
groups=groups,
dilation=dilation,
)
self.bn = BatchNorm2d(c2, eps=0.001)
def __call__(self, x):
return self.bn(self.conv(x)).silu()
class Bottleneck:
def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5):
def __init__(
self, c1, c2, shortcut: bool, g=1, kernels: list = (3, 3), channel_factor=0.5
):
c_ = int(c2 * channel_factor)
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g)
self.cv2 = Conv_Block(
c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g
)
self.residual = c1 == c2 and shortcut
def __call__(self, x):
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
class C2f:
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
self.c = int(c2 * e)
self.cv1 = Conv_Block(c1, 2 * self.c, 1,)
self.cv1 = Conv_Block(
c1,
2 * self.c,
1,
)
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)]
self.bottleneck = [
Bottleneck(
self.c,
self.c,
shortcut,
g,
kernels=[(3, 3), (3, 3)],
channel_factor=1.0,
)
for _ in range(n)
]
def __call__(self, x):
y = list(self.cv1(x).chunk(2, 1))
y.extend(m(y[-1]) for m in self.bottleneck)
z = y[0]
for i in y[1:]: z = z.cat(i, dim=1)
for i in y[1:]:
z = z.cat(i, dim=1)
return self.cv2(z)
class SPPF:
def __init__(self, c1, c2, k=5):
c_ = c1 // 2 # hidden channels
@ -282,7 +420,9 @@ class SPPF:
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1)
self.maxpool = lambda x: x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(
kernel_size=k, stride=1
)
def __call__(self, x):
x = self.cv1(x)
@ -291,6 +431,7 @@ class SPPF:
x4 = self.maxpool(x3)
return self.cv2(x.cat(x2, x3, x4, dim=1))
class DFL:
def __init__(self, c1=16):
self.conv = Conv2d(c1, 1, 1, bias=False)
@ -300,15 +441,33 @@ class DFL:
def __call__(self, x):
b, c, a = x.shape # batch, channels, anchors
return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a)
return self.conv(
x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)
).reshape(b, 4, a)
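# Per box coordinate, the DFL head above softmaxes c1 discrete bins and then
# applies a 1x1 conv; in the usual DFL formulation (assumed here) that conv
# holds the fixed weights arange(c1), so the output is the expected bin index.
_logits = np.array([0.0, 0.0, 4.0, 0.0])          # c1 = 4 bins for one coordinate
_probs = np.exp(_logits) / np.exp(_logits).sum()  # softmax over the bins
_expected_bin = (_probs * np.arange(4)).sum()     # ~1.95, close to the peaked bin 2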
# backbone
class Darknet:
def __init__(self, w, r, d):
self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)]
self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)]
self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)]
self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)]
self.b1 = [
Conv_Block(c1=3, c2=int(64 * w), kernel_size=3, stride=2, padding=1),
Conv_Block(int(64 * w), int(128 * w), kernel_size=3, stride=2, padding=1),
]
self.b2 = [
C2f(c1=int(128 * w), c2=int(128 * w), n=round(3 * d), shortcut=True),
Conv_Block(int(128 * w), int(256 * w), 3, 2, 1),
C2f(int(256 * w), int(256 * w), round(6 * d), True),
]
self.b3 = [
Conv_Block(int(256 * w), int(512 * w), kernel_size=3, stride=2, padding=1),
C2f(int(512 * w), int(512 * w), round(6 * d), True),
]
self.b4 = [
Conv_Block(
int(512 * w), int(512 * w * r), kernel_size=3, stride=2, padding=1
),
C2f(int(512 * w * r), int(512 * w * r), round(3 * d), True),
]
self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)]
def return_modules(self):
@ -322,16 +481,28 @@ class Darknet:
x5 = x4.sequential(self.b5)
return (x2, x3, x5)
# yolo fpn (neck)
class Yolov8NECK:
def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple
self.up = Upsample(2, mode='nearest')
self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False)
self.up = Upsample(2, mode="nearest")
self.n1 = C2f(
c1=int(512 * w * (1 + r)), c2=int(512 * w), n=round(3 * d), shortcut=False
)
self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False)
self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1)
self.n3 = Conv_Block(
c1=int(256 * w), c2=int(256 * w), kernel_size=3, stride=2, padding=1
)
self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False)
self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1)
self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False)
self.n5 = Conv_Block(
c1=int(512 * w), c2=int(512 * w), kernel_size=3, stride=2, padding=1
)
self.n6 = C2f(
c1=int(512 * w * (1 + r)),
c2=int(512 * w * r),
n=round(3 * d),
shortcut=False,
)
def return_modules(self):
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
@ -343,6 +514,7 @@ class Yolov8NECK:
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
return [head_1, head_2, head_3]
# task specific head.
class DetectionHead:
def __init__(self, nc=80, filters=()):
@ -354,25 +526,41 @@ class DetectionHead:
c1 = max(filters[0], self.nc)
c2 = max((filters[0] // 4, self.ch * 4))
self.dfl = DFL(self.ch)
self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters]
self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters]
self.cv3 = [
[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)]
for x in filters
]
self.cv2 = [
[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)]
for x in filters
]
def __call__(self, x):
for i in range(self.nl):
x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1))
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
x[i] = x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1)
self.anchors, self.strides = (
x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)
)
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
x_cat = y[0].cat(y[1], y[2], dim=2)
box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :]
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
dbox = (
dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1)
* self.strides
)
z = dbox.cat(cls.sigmoid(), dim=1)
return z
class YOLOv8:
def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple
def __init__(
self, w, r, d, num_classes
): # width_multiple, ratio_multiple, depth_multiple
self.net = Darknet(w, r, d)
self.fpn = Yolov8NECK(w, r, d)
self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r)))
self.head = DetectionHead(
num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r))
)
def __call__(self, x):
x = self.net(x)
@ -383,27 +571,44 @@ class YOLOv8:
backbone_modules = [*range(10)]
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
yolov8_head_weights = [(22, self.head)]
return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
return [
*zip(backbone_modules, self.net.return_modules()),
*zip(yolov8neck_modules, self.fpn.return_modules()),
*yolov8_head_weights,
]
if __name__ == '__main__':
if __name__ == "__main__":
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
if len(sys.argv) < 2:
print("Error: Image URL or path not provided.")
sys.exit(1)
img_path = sys.argv[1]
yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
print(f'running inference for YOLO version {yolo_variant}')
yolo_variant = (
sys.argv[2]
if len(sys.argv) >= 3
else (
print(
"No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']"
)
or "n"
)
)
print(f"running inference for YOLO version {yolo_variant}")
output_folder_path = Path('./outputs_yolov8')
output_folder_path = Path("./outputs_yolov8")
output_folder_path.mkdir(parents=True, exist_ok=True)
# absolute image path or URL
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
image = [cv2.imdecode(image_location[0], 1)]
out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()]
out_paths = [
(
output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}"
).as_posix()
]
if not isinstance(image[0], np.ndarray):
print('Error in image loading. Check your image file.')
print("Error in image loading. Check your image file.")
sys.exit(1)
pre_processed_image = preprocess(image)
@ -411,19 +616,36 @@ if __name__ == '__main__':
depth, width, ratio = get_variant_multiples(yolo_variant)
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
state_dict = safe_load(
fetch(
f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors"
)
)
load_state_dict(yolo_infer, state_dict)
st = time.time()
predictions = yolo_infer(pre_processed_image)
print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
print(f"did inference in {int(round(((time.time() - st) * 1000)))}ms")
post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
post_predictions = postprocess(
preds=predictions, img=pre_processed_image, orig_imgs=image
)
# v8 and v3 have same 80 class names for Object Detection
class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
class_labels = (
fetch(
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
)
.read_text()
.split("\n")
)
draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
draw_bounding_boxes_and_save(
orig_img_paths=image_location,
output_img_paths=out_paths,
all_predictions=post_predictions,
class_labels=class_labels,
)
# TODO for later:
# 1. Fix SPPF minor difference due to maxpool

View File

@ -6,9 +6,9 @@ from coremltools.models.neural_network import datatypes, NeuralNetworkBuilder
# KxK GEMM with bias
K = 64
input_features = [('image', datatypes.Array(K))]
input_features2 = [('image2', datatypes.Array(K))]
output_features = [('probs', datatypes.Array(K))]
input_features = [("image", datatypes.Array(K))]
input_features2 = [("image2", datatypes.Array(K))]
output_features = [("probs", datatypes.Array(K))]
weights = np.zeros((K, K)) + 3
bias = np.ones(K)
@ -17,7 +17,9 @@ builder = NeuralNetworkBuilder(input_features+input_features2, output_features)
# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
builder.add_elementwise(name='element', input_names=['image', 'image2'], output_name='probs', mode='ADD')
builder.add_elementwise(
name="element", input_names=["image", "image2"], output_name="probs", mode="ADD"
)
# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
@ -25,6 +27,11 @@ builder.add_elementwise(name='element', input_names=['image', 'image2'], output_
mlmodel = ct.models.MLModel(builder.spec)
# trigger the ANE!
out = mlmodel.predict({"image": np.zeros(K, dtype=np.float32)+1, "image2": np.zeros(K, dtype=np.float32)+2})
out = mlmodel.predict(
{
"image": np.zeros(K, dtype=np.float32) + 1,
"image2": np.zeros(K, dtype=np.float32) + 2,
}
)
print(out)
mlmodel.save('test.mlmodel')
mlmodel.save("test.mlmodel")

View File

@ -6,7 +6,7 @@ import pylab as plt
from networkx.drawing.nx_pydot import read_dot
ret = os.system("./a.out " + sys.argv[1] + " debug")
assert(ret == 0)
assert ret == 0
df = "debug/model.hwx.zinir_graph_after_reg_spill.dot"

View File

@ -3,17 +3,21 @@ import sys
from hexdump import hexdump
from macholib import MachO
from tinygrad.helpers import getenv
def get_macho(fn):
# mod to make the header okay
# MH_CIGAM_64 is good
dat = open(fn, "rb").read()
dat = b"\xcf\xfa\xed\xfe" + dat[4:]
from tempfile import NamedTemporaryFile
with NamedTemporaryFile(delete=False) as f:
f.write(dat)
f.close()
return MachO.MachO(f.name)
a = get_macho("model.hwx.golden")
# load commands
@ -23,12 +27,19 @@ for c in a.headers[0].commands:
hexdump(c[2])
pass
if c[0].cmd == 6:
print("name:", c[2].decode('utf-8'))
print("name:", c[2].decode("utf-8"))
if c[0].cmd == 8:
print(c[2].decode('utf-8'))
print(c[2].decode("utf-8"))
if c[0].cmd == 25:
for section in c[2]:
print(section.segname.strip(b'\0'), section.sectname.strip(b'\0'), hex(section.addr), hex(section.size), "@", hex(c[1].fileoff))
print(
section.segname.strip(b"\0"),
section.sectname.strip(b"\0"),
hex(section.addr),
hex(section.size),
"@",
hex(c[1].fileoff),
)
# print(dir(section))
if c[1].filesize > 0:
if len(section.section_data) < 0x100:
@ -38,6 +49,7 @@ for c in a.headers[0].commands:
# this parser is wrong (fixed with 64-bit one)
from macholib import SymbolTable
sym = SymbolTable.SymbolTable(a)
syms = {}
@ -52,6 +64,7 @@ for k,v in syms.items():
# **** document what we know ***
from ane import ANE_Struct, ANE
ane = ANE()
aneb = set()
@ -65,6 +78,8 @@ for l in range(0x34, 0xF4):
aneb.add(l)
from termcolor import colored
def compare(x, y):
ss = []
ln = []
@ -73,7 +88,7 @@ def compare(x, y):
ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10
highlight = False
next_highlight = 0x2b
next_highlight = 0x2B
for i in range(ll + 1):
if i == next_highlight:
highlight = True
@ -83,35 +98,37 @@ def compare(x, y):
next_highlight = None
else:
highlight = False
a = "%02X" % x[i] if i < len(x) else "--", \
"%02X" % y[i] if i < len(y) else "--"
a = "%02X" % x[i] if i < len(x) else "--", "%02X" % y[i] if i < len(y) else "--"
def fj(x):
ss = []
for i in range(0, 0x10, 4):
ss.append(' '.join(x[i:i+4]))
return ' '.join(ss)
ss.append(" ".join(x[i : i + 4]))
return " ".join(ss)
if i != 0 and i % 0x10 == 0:
ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n")
ln = []
ln2 = []
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
ln.append(colored(a[0], 'green'))
ln2.append(colored(a[1], 'red'))
ln.append(colored(a[0], "green"))
ln2.append(colored(a[1], "red"))
else:
if highlight:
ln.append(colored(a[0], 'yellow'))
ln2.append(colored(a[1], 'yellow'))
ln.append(colored(a[0], "yellow"))
ln2.append(colored(a[1], "yellow"))
else:
if i in aneb:
ln.append(colored(a[0], 'white'))
ln2.append(colored(a[1], 'white'))
ln.append(colored(a[0], "white"))
ln2.append(colored(a[1], "white"))
else:
ln.append(a[0])
ln2.append(a[1])
return ''.join(ss)
return "".join(ss)
import json
aneregs = dict(json.load(open("aneregs.json")))
g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1])
f1 = g.headers[0].commands[1][2][0].section_data

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
from ane import ANE
ane = ANE()
lens = {}
@ -30,7 +31,7 @@ for i in range(0x300):
pos.append((k, (i, j, lens[k])))
import json
jpos = json.dumps(pos, indent=2)
with open("aneregs.json", "w") as f:
f.write(jpos)

View File

@ -2,6 +2,7 @@ import ctypes
from subprocess import check_output
from hexdump import hexdump
def get_pid(name):
try:
output = check_output(["pgrep", name])
@ -9,8 +10,10 @@ def get_pid(name):
except:
return None
from ctypes.util import find_library
libc = ctypes.CDLL(find_library('c'))
libc = ctypes.CDLL(find_library("c"))
amfid_pid = get_pid("amfid")
@ -21,6 +24,7 @@ print(amfid_pid, ret, task, mytask)
# myport = libc.mach_task_self()
class vm_region_submap_short_info_data_64(ctypes.Structure):
_pack_ = 1
_fields_ = [
@ -38,6 +42,8 @@ class vm_region_submap_short_info_data_64(ctypes.Structure):
("object_id", ctypes.c_uint32),
("user_wired_count", ctypes.c_uint32),
]
submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4
address = ctypes.c_ulong(0)
@ -48,21 +54,31 @@ depth = 0
c_depth = ctypes.c_uint32(depth)
for i in range(1):
ret = libc.mach_vm_region_recurse(task,
ctypes.pointer(address), ctypes.pointer(mapsize),
ctypes.pointer(c_depth), ctypes.pointer(sub_info),
ctypes.pointer(count))
ret = libc.mach_vm_region_recurse(
task,
ctypes.pointer(address),
ctypes.pointer(mapsize),
ctypes.pointer(c_depth),
ctypes.pointer(sub_info),
ctypes.pointer(count),
)
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
# address.value += mapsize.value
# exit(0)
patch_address = address.value + 0x8e38
patch_address = address.value + 0x8E38
patch = b"\x00\x00\x80\xd2"
pdata = ctypes.c_void_p(0)
data_cnt = ctypes.c_uint32(0)
ret = libc.mach_vm_read(task, ctypes.c_ulong(patch_address), 4, ctypes.pointer(pdata), ctypes.pointer(data_cnt))
ret = libc.mach_vm_read(
task,
ctypes.c_ulong(patch_address),
4,
ctypes.pointer(pdata),
ctypes.pointer(data_cnt),
)
buf = ctypes.string_at(pdata.value, data_cnt.value)
hexdump(buf)

View File

@ -6,12 +6,15 @@ import collections
import numpy as np
import faulthandler
import struct
faulthandler.enable()
basedir = Path(__file__).resolve().parent
libane = None
aneregs = None
def init_libane():
global libane, aneregs
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
@ -32,71 +35,56 @@ def init_libane():
with open(basedir / "aneregs.json") as f:
aneregs = json.load(f)
ANE_Struct = [
# aneTD.Header
("u32", 0x1C, "NextCommandOffset"),
# KernelDMASrc @ section @ 0x2C len 0xF4
# reloc 0x2c-0x34?? = weights
# u32[16] 0x34-0x74 = 0x80 | 1 if used
# u32[16] 0x74-0xB4 = <channel data offset>
# u32[16] 0xB4-0xF4 = <channel data length>
# Common @ section @ 0x128 len 0x3C (conv)
("u16", 0x128, "InputWidth"),
("u16", 0x12A, "InputHeight"),
("u16", 0x12C, "InputDepth"),
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
# UInt8 = 0, Int8 = 1, Float16 = 2
("u32", 0x134, "InputChannels"),
("u32", 0x138, "OutputChannels"),
("u16", 0x13C, "OutputWidth"),
("u16", 0x13E, "OutputHeight"),
("u16", 0x140, "OutputDepth"),
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
("u16", 0x14C, "BatchSize"),
# TileDMASrc @ section @ 0x16C len 0x6C (input)
# reloc 0x16c-0x174 = image
("u32", 0x178, "InputRowStride"),
("u32", 0x17C, "InputPlaneStride"),
("u32", 0x180, "InputDepthStride"),
("u32", 0x184, "InputBatchStride"),
("u8", 0x1A7, "InputInterleave"),
# L2 @ section @ 0x1E0 len 0x44
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
# [0x21c, 0x220, 0x224] = engines for outconv?
# NE @ section @ 0x22c len 0xC (scaling)
("u16", 0x230, "BiasScalar"),
("u16", 0x232, "ScaleScalar"),
# section @ 0x240 len 0x10
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
("u32", 0x250, "PostScale"),
# TileDMADst @ section @ 0x258 len 0x18
# HandleTileDmaDstConfig
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
# (GetCacheHintRegisterValue & 0xf) << 6;
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
("u32", 0x260, "OutputRowStride"),
("u32", 0x264, "OutputPlaneStride"),
("u32", 0x268, "OutputDepthStride"),
("u32", 0x26C, "OutputBatchStride"),
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
@ -108,24 +96,26 @@ for typ, num, nam in ANE_Struct:
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
ANE_Struct_Dict[nam] = (styp, num)
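# How a named field maps onto a raw task descriptor through the
# (struct format, offset) pairs built above, e.g. InputWidth is a u16 at
# 0x128. td here is a hypothetical bytearray holding one descriptor, and
# little-endian layout is assumed for the sketch.
td = bytearray(0x300)
fmt, off = ANE_Struct_Dict["InputWidth"]            # ("H", 0x128)
struct.pack_into("<" + fmt, td, off, 512)           # set InputWidth = 512
assert struct.unpack_from("<" + fmt, td, off)[0] == 512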
class ANETensor:
def __init__(self, *shape):
self.shape = shape
self.dtype = np.float16
self.sz = int(np.prod(shape))
assert(self.sz <= 0x4000)
assert self.sz <= 0x4000
self.tt = libane.ANE_TensorCreate(self.sz, 1)
assert(self.tt is not None)
assert self.tt is not None
def data(self):
data = libane.ANE_TensorData(self.tt)
assert(data is not None)
assert data is not None
# print(hex(addressof(data.contents)))
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
ret = np.frombuffer(buf, dtype=self.dtype)
# print(ret.data)
return ret
class ANE:
def __init__(self):
init_libane()
@ -133,11 +123,13 @@ class ANE:
def compile(self, dat):
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
assert(ret is not None)
assert ret is not None
return ret
def run(self, prog, tin, tout, tweights=None):
libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0)
libane.ANE_Run(
prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0
)
def tensor(self, shape):
return ANETensor(shape)
@ -165,9 +157,9 @@ class ANE:
return dat
def debug(self, dat, mems=0):
add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c]
add = [0x30, 0x1D4, 0x220, 0x29C, 0x2F0, 0x30C, 0x32C]
lens = [244, 60, 108, 68, 12, 16, 24]
ptr = 0x2b
ptr = 0x2B
ddat = dat[0:0x28]
for a, pm in zip(add, lens):
# assert pm == dat[ptr]
@ -176,7 +168,12 @@ class ANE:
ptr += pm + 8
ddat += b"\x00" * 0x100
ret = collections.OrderedDict()
for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"):
for ln in (
libane.ANE_RegDebug(0, create_string_buffer(ddat), mems)
.decode("utf-8")
.strip()
.split("\n")
):
lnn = ln.split(" = ")
if len(lnn) == 2:
ret[lnn[0]] = int(lnn[1])
@ -194,6 +191,7 @@ class ANE:
dat[base + a : base + a + len(x)] = x
return dat
if __name__ == "__main__":
ane = ANE()
@ -212,11 +210,10 @@ if __name__ == "__main__":
md = dat[0x4000:0x4300]
dd = ane.unpack(md)
mdf = ane.pack(dd, md)
assert(md == mdf)
assert md == mdf
comp = ane.compile(dat)
ret = ane.run(comp, tin, tout)
print("** after **")
print(tind)
print(toutd)

View File

@ -2,6 +2,7 @@
import time
from ane import ANE, ANETensor
def benchmark(ane):
tin = ANETensor(512 * 0x20)
tout = ANETensor(512 * 0x20)
@ -14,7 +15,7 @@ def benchmark(ane):
for i in range(1000):
ret = ane.run(comp, tin, tout)
et = time.time()
ts = (et-st)
ts = et - st
ops = 1000 * 512 * 512 * 2
print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts))
@ -72,7 +73,7 @@ if __name__ == "__main__":
dd = ane.unpack(dat[0x4000:0x4300])
# use the 3rd arg as the weights
dd["aneTD.Header[9].KBase0"] = 6
dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00
dd["aneRegs.NE.PostScale.PostScale"] = 0x3C00
# dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
# dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
# dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0

View File

@ -1,13 +1,16 @@
from functools import lru_cache
from .tensor import Device, Function, register
@lru_cache
def compile_wrapper(ane, dat):
return ane.compile(dat)
def roundup(x, v):
return x + (v - x) % v
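# roundup above rounds x up to the next multiple of v; it is used below to pad
# the ReLU strides to 0x40-byte boundaries.
assert roundup(100, 0x40) == 128
assert roundup(128, 0x40) == 128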
@lru_cache
def compile_relu(ane, sz):
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
@ -17,16 +20,25 @@ def compile_relu(ane, sz):
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
# 0x1f4, 0x1f8?
# 0x214 = L2.ResultBase.Addr
dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride)
dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride)
stride = roundup(sz * 2, 0x40)
dat = ane.filln(dat, {
dat = ane.filln(
dat,
{
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
"InputWidth": sz, "OutputWidth": sz,
"InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride,
"OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride,
})
"InputWidth": sz,
"OutputWidth": sz,
"InputRowStride": stride,
"InputPlaneStride": stride,
"InputDepthStride": stride,
"OutputRowStride": stride,
"OutputPlaneStride": stride,
"OutputDepthStride": stride,
},
)
return compile_wrapper(ane, bytes(dat))
class ReLU(Function):
def forward(ctx, input):
ret = ctx.ane.tensor(input.shape)
@ -36,4 +48,5 @@ class ReLU(Function):
def backward(ctx, grad_output):
return 0
register('relu', ReLU, device=Device.ANE)
register("relu", ReLU, device=Device.ANE)

View File

@ -31,13 +31,14 @@ for x in out.values(): x.realize()
"""
from openvino.runtime import Core
core = Core()
devices = core.available_devices
for device in devices:
device_name = core.get_property(device, "FULL_DEVICE_NAME")
print(f"{device}: {device_name}")
model = core.read_model(onnx_path)
compiled_model = core.compile_model(model, device_name='GPU.0')
compiled_model = core.compile_model(model, device_name="GPU.0")
print(compiled_model)
ireq = compiled_model.create_infer_request()
for model_input in compiled_model.inputs:
@ -51,7 +52,7 @@ print("did one")
REPS = 20
st = time.perf_counter()
for i in range(REPS): ireq.infer()
for i in range(REPS):
ireq.infer()
et = time.perf_counter() - st
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")

View File

@ -7,9 +7,12 @@ from tqdm import trange, tqdm
from matplotlib import pyplot as plt
tests = {}
def register_test(fxn):
tests[fxn.__name__] = fxn
def warp_size2(nthread):
prg = """__kernel void warp_size2(
__global float* src,
@ -27,16 +30,36 @@ def warp_size2(nthread):
src_buf = CLBuffer(1, dtypes.float32)
dst_buf = CLBuffer(1, dtypes.int32)
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9
return (
min(
[
cl(
[nthread, 1024, 1],
[nthread, 1, 1],
src_buf,
dst_buf,
10,
3,
wait=True,
)
for _ in range(5)
]
)
* 1e9
)
@register_test
def test_warp_size():
return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)]
def reg_count(nthread, ngrp, nreg):
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
reg_declr = "".join(
[f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]
)
reg_comp = "".join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
reg_reduce = "".join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
prg = f"""__kernel void reg_count(
__global float* out_buf,
__private const int niter
@ -51,12 +74,25 @@ def reg_count(nthread, ngrp, nreg):
}}"""
out_buf = CLBuffer(1, dtypes.float32)
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9
return (
min(
[
cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True)
for _ in range(10)
]
)
* 1e9
)
@register_test
def test_reg_count(nthread=1, ngrp=1):
base = reg_count(nthread, ngrp, 1)
return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)]
return [
(nreg, (reg_count(nthread, ngrp, nreg) - base) / nreg)
for nreg in trange(4, 513, 4)
]
def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
ndata //= NCOMP * 4 # ptr size
@ -72,22 +108,40 @@ def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
*dst = idx;
}}"""
idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32)
for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata
for i in range(ndata):
idx_buf[i * NCOMP] = (i + stride) % ndata
in_buf = CLBuffer.fromCPU(idx_buf)
out_buf = CLBuffer(1, dtypes.int32)
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9
return (
min(
[
cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True) / steps
for _ in range(5)
]
)
* 1e9
)
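# The kernel above measures memory latency by pointer chasing: every load's
# address depends on the previous load's value, so latencies cannot be
# overlapped. A small CPU-side sketch of the same access pattern (assumed
# sizes, illustration only):
_chase_len, _chase_stride = 1024, 7
_nxt = np.array([(i + _chase_stride) % _chase_len for i in range(_chase_len)], dtype=np.int32)
_idx = 0
for _ in range(4096):
    _idx = _nxt[_idx]   # serial dependency chain, one memory access per step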
@register_test
def test_memory_latency():
# requires cacheline < 16
szs = [int(1.3**x) for x in range(20, 70)]
return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)]
return [
(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128 * 1024))
for ndata in tqdm(szs)
]
@register_test
def test_cacheline_size():
# TODO: this buffer must be at least 2x the L1 cache for this test to work
return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)]
return [
(stride, buf_cache_hierarchy_pchase(4 * 65536, stride, steps=65536))
for stride in trange(1, 64)
]
def cl_read(sz, niter=1):
prg = f"""__kernel void copy(
@ -101,7 +155,16 @@ def cl_read(sz, niter=1):
out_buf = CLBuffer(1, dtypes.float32)
cl = CLProgram("copy", prg)
# NOTE: if any of the niters form a local group, this is wrong
return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9
return (
min(
[
cl([sz // 16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True)
for _ in range(10)
]
)
* 1e9
)
@register_test
def test_read_bandwidth():
@ -129,12 +192,17 @@ def gflops(niter=4, nroll=4, ngroups=4096):
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32
# NOTE: if any of the niters form a local group, this is wrong
return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9)
return FLOPS / (
min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])
* 1e9
)
@register_test
def test_gflops():
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
if __name__ == "__main__":
cache = {}
# cache = pickle.load(open("/tmp/cache.pkl", "rb"))
@ -144,8 +212,10 @@ if __name__ == "__main__":
print(f"running {k}")
plt.subplot(2, (len(tests) + 1) // 2, i + 1)
plt.title(k)
if k == "test_memory_latency": plt.xscale('log')
if k not in cache: cache[k] = test()
if k == "test_memory_latency":
plt.xscale("log")
if k not in cache:
cache[k] = test()
plt.plot(*zip(*cache[k]))
# pickle.dump(cache, open("/tmp/cache.pkl", "wb"))

View File

@ -1,32 +1,69 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from typing import (
Tuple,
List,
NamedTuple,
Any,
Dict,
Optional,
Union,
DefaultDict,
cast,
)
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.helpers import DType, dtypes, DEBUG
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
from tinygrad.shape.symbolic import (
Variable,
NumNode,
MulNode,
DivNode,
ModNode,
LtNode,
SumNode,
AndNode,
)
import functools
import math
from collections import defaultdict
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
_type_to_letter = {
dtypes.float32: "f",
dtypes.bool: "p",
dtypes.int32: "i",
dtypes.int64: "a",
dtypes.uint32: "u",
dtypes.uint64: "b",
dtypes.float.vec(4): "x",
dtypes.uint8: "uc",
dtypes.float16: "h",
dtypes.int8: "c",
dtypes.uint16: "us",
dtypes.float64: "d",
}
class Register(NamedTuple):
nm: str
dtype: DType
scalar: bool
off: Optional[int] = None
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
def __repr__(self):
return self.nm if self.off is None else f"{self.nm}:{self.off}"
def subregs(self):
if self.dtype == dtypes.float.vec(4):
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
return []
class AssemblyInstruction(NamedTuple):
op: UOps
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
class AssemblyLanguage:
supports_load3: bool = False
@ -37,9 +74,15 @@ class AssemblyLanguage:
tor: Dict[Any, Register] = {}
ins: List[AssemblyInstruction] = []
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def type_to_letter(self, x):
return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
self.tor[tok] = ret = Register(
f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}",
dtype,
scalar,
)
if dtype == dtypes.float.vec(4):
for off in range(4):
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
@ -48,30 +91,72 @@ class AssemblyLanguage:
def render_numnode(self, b) -> Register:
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
if key not in self.tor:
self.ins.append(
AssemblyInstruction(
UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b
)
)
return self.tor[key]
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
def render_alu(
self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32
) -> Register:
key = (op, a, b)
if key not in self.tor:
# if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
self.ins.append(
AssemblyInstruction(
UOps.ALU,
self.newreg(
key,
dtype=dtype,
scalar=a.scalar and (not isinstance(b, Register) or b.scalar),
),
[a, b],
op,
)
)
return self.tor[key]
def render_cast(self, a: Register, new_dtype: DType) -> Register:
if a.dtype == new_dtype: return a
if a.dtype == new_dtype:
return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
self.ins.append(
AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])
)
return self.tor[key]
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
render_ops: Any = {
Variable: lambda self, ops, ctx: ctx.tor[self],
NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
MulNode: lambda self, ops, ctx: ctx.render_alu(
BinaryOps.MUL, self.a.render(ops, ctx), self.b
),
DivNode: lambda self, ops, ctx: ctx.render_alu(
BinaryOps.DIV, self.a.render(ops, ctx), self.b
),
ModNode: lambda self, ops, ctx: ctx.render_alu(
BinaryOps.MOD, self.a.render(ops, ctx), self.b
),
LtNode: lambda self, ops, ctx: ctx.render_alu(
BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool
),
SumNode: lambda self, ops, ctx: functools.reduce(
lambda a, b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops, ctx)),
self.nodes[1:],
self.nodes[0].render(ops, ctx),
),
AndNode: lambda self, ops, ctx: functools.reduce(
lambda a, b: ctx.render_alu(
BinaryOps.MUL, a, b.render(ops, ctx), dtype=dtypes.bool
),
self.nodes[1:],
self.nodes[0].render(ops, ctx),
),
}
def addr_w_offset(self, args):
assert isinstance(args, MemOp)
@ -79,110 +164,264 @@ class AssemblyLanguage:
off = 0 # TODO: should this be None?
if isinstance(idx, SumNode):
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
if (
nums and nums[0] < 4096 and (idx - nums[0]).min >= 0
): # TODO: different for each GPU?
idx -= nums[0]
off = cast(int, nums[0])
reg = idx.render(self.render_ops, self)
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
new_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype)
self.ins.append(
AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)
)
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
reg = self.render_alu(
BinaryOps.ADD,
self.render_cast(reg, dtypes.uint64),
self.tor[args.name],
dtype=dtypes.uint64,
)
return reg, None, off
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
# TODO: Do not use clear()
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {
args[0]: args[1] for uop, _, _, args, _ in uops if uop == UOps.DEFINE_GLOBAL
}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
lang.ins += [
AssemblyInstruction(
UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf
)
for buf in buf_to_dtype
]
for u in uops:
uop, dtype, vin, args, _ = u
if uop == UOps.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
lang.ins.append(
AssemblyInstruction(
UOps.ALU,
lang.newreg(args[0], dtype=dtypes.uint64),
[args[0]],
UnaryOps.NOOP,
)
)
elif uop == UOps.LOOP:
if args[1] == "global":
for i, var in enumerate(args[0]):
global_size.append(var.max + 1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
lang.ins.append(
AssemblyInstruction(
UOps.SPECIAL,
lang.newreg(var, dtype=dtypes.int32),
[],
f"gid{len(args[0])-1-i}",
)
)
elif args[1] == "local":
for i, var in enumerate(args[0]):
local_size.append(var.max + 1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
lang.ins.append(
AssemblyInstruction(
UOps.SPECIAL,
lang.newreg(var, dtype=dtypes.int32),
[],
f"lid{len(args[0])-1-i}",
)
)
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
if not isinstance(
var, NumNode
): # TODO: why is this coming through?
lang.ins.append(
AssemblyInstruction(
UOps.LOAD,
lang.newreg(var, dtype=dtypes.int32, scalar=True),
[],
0,
)
)
lang.ins.append(
AssemblyInstruction(
UOps.LABEL, None, [], "$loop_" + var.expr
)
)
elif uop == UOps.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
if not isinstance(
var, NumNode
): # TODO: why is this coming through?
lang.ins.append(
AssemblyInstruction(
UOps.ALU,
lang.tor[var],
[lang.tor[var], 1],
BinaryOps.ADD,
)
)
pred = lang.render_alu(
BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool
)
lang.ins.append(
AssemblyInstruction(
UOps.COND_BRANCH,
None,
[pred],
("$loop_" + var.expr, True),
)
)
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif args[1] == 'local':
lang.ins.append(
AssemblyInstruction(
UOps.ENDLOOP,
None,
[lang.tor[var]],
(var.max + 1, f"gid{i}"),
)
)
elif args[1] == "local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
lang.ins.append(
AssemblyInstruction(
UOps.ENDLOOP,
None,
[lang.tor[var]],
(var.max + 1, f"lid{i}"),
)
)
elif uop == UOps.CAST:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(u, dtype)
for i, sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
lang.ins.append(
AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)
)
elif uop == UOps.ALU:
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool)
lang.ins.append(
AssemblyInstruction(
UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args
)
)
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((u, "rcp"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
lang.ins.append(
AssemblyInstruction(
UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP
)
)
lang.ins.append(
AssemblyInstruction(
UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL
)
)
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((u, "2pi"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(
AssemblyInstruction(
UOps.ALU,
tmp,
[lang.tor[vin[0]], 1 / (math.pi * 2)],
BinaryOps.MUL,
)
)
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
lang.ins.append(
AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)
)
elif uop == UOps.DEFINE_ACC:
reg = lang.newreg(u, dtype=dtype)
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
elif uop == UOps.SPECIAL:
lang.tor[u] = lang.tor[args]
elif uop == UOps.CONST:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
lang.ins.append(
AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)
)
elif uop == UOps.LOAD:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
reg = lang.newreg(
u,
dtype=dtype,
scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)),
)
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
lang.ins.append(
AssemblyInstruction(
UOps.COND_BRANCH,
None,
[pred],
(f"$skipload_{skipload_branch}", False),
)
)
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
lang.ins.append(
AssemblyInstruction(
UOps.LOAD,
reg,
[idx] + ([treg] if treg is not None else []),
(
off,
"global" if not args.local else "shared",
args.memory_dtype
if args.memory_dtype != dtypes.float
else None,
),
)
)
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
lang.ins.append(
AssemblyInstruction(
UOps.LABEL, None, [], f"$skipload_{skipload_branch}"
)
)
skipload_branch += 1
elif uop == UOps.STORE:
if args is None:
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
lang.ins.append(
AssemblyInstruction(
UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP
)
)
else:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
lang.ins.append(
AssemblyInstruction(
UOps.STORE,
None,
[idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []),
(
off,
"global" if not args.local else "shared",
args.memory_dtype
if args.memory_dtype != dtypes.float
else None,
),
)
)
if DEBUG >= 4:
for tins in lang.ins: print(tins)
for tins in lang.ins:
print(tins)
return global_size, local_size

View File

@ -6,28 +6,60 @@ from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.helpers import dtypes, CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def float_to_hex(x):
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
def compute_offsets(total):
quotient, remainder = divmod(total, 4096)
return [4096] * quotient + [remainder] if remainder else [4096] * quotient
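# A quick sanity sketch of the two helpers above (illustration only, not part of
# this commit): float_to_hex reverses the little-endian float bytes into big-endian
# hex, and compute_offsets splits a stack size into sp adjustments of at most 4096.
assert float_to_hex(1.0) == "3F800000" and float_to_hex(2.0) == "40000000"
assert compute_offsets(10000) == [4096, 4096, 1808]  # divmod(10000, 4096) == (2, 1808)
assert compute_offsets(8192) == [4096, 4096]  # no remainder, so no trailing chunk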
#NOTE: Darwin needs names to start with a "_"
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
class ARM64Language(AssemblyLanguage): pass
# NOTE: Darwin needs names to start with a "_"
def get_name(name):
return ("_" if system() == "Darwin" else "") + name
class ARM64Language(AssemblyLanguage):
pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop: Optional[UOps] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(12))]
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
x_regs = ["x" + str(i) for i in reversed(range(12))]
s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
type_to_reg = {
dtypes.double: "d",
dtypes.half: "h",
dtypes.float32: "s",
dtypes.bool: "w",
dtypes.int8: "w",
dtypes.int32: "w",
dtypes.int64: "x",
dtypes.uint8: "w",
dtypes.uint32: "w",
dtypes.uint64: "x",
}
alu = {
BinaryOps.ADD: "add",
BinaryOps.SUB: "sub",
BinaryOps.MUL: "mul",
BinaryOps.DIV: "div",
BinaryOps.MAX: "max",
BinaryOps.MOD: "",
BinaryOps.CMPLT: "subs",
UnaryOps.NOOP: "mov",
UnaryOps.NEG: "neg",
UnaryOps.SIN: "bl " + get_name("sinf"),
UnaryOps.LOG2: "bl " + get_name("log2f"),
UnaryOps.EXP2: "bl " + get_name("exp2f"),
UnaryOps.SQRT: "bl " + get_name("sqrtf"),
TernaryOps.MULACC: "madd",
TernaryOps.WHERE: "fcsel",
}
def mov_imm(value, reg):
# Manually move value into reg if value can't fit
@ -35,7 +67,7 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"movz w15, #{value & 0xffff}")
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
ins.append(f"sxtw {reg}, w15")
elif reg[0] == 's':
elif reg[0] == "s":
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
ins.append("str x15, [sp, 16]")
@ -46,42 +78,51 @@ def specialize_to_arm64(fn_nm, asm):
# Get variables intervals
live_range: Dict[str, List[int]] = {}
for i, (uop, out, vin, arg) in enumerate(asm):
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]:
live_range[var.nm] = (
[i, i] if var.nm not in live_range else [live_range[var.nm][0], i]
)
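# Illustration (not part of this commit): for a toy instruction stream
#   0: (LOAD,  a, [],     ...)
#   1: (ALU,   b, [a, a], ...)
#   2: (STORE, None, [b], ...)
# the loop above records live_range == {"a": [0, 1], "b": [1, 2]}, i.e. each
# variable name maps to [first instruction index seen, last index seen], which the
# register allocation and spill logic below use to recycle registers.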
mem_vars: Dict[str, int] = {}
rtor: Dict[str, str] = {}
def allocate_regs(mvars):
nonlocal var_size
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
for v in [
v
for v in mvars
if v is not None and v.__class__ is not int and v.nm not in rtor
]:
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
# NOTE: Very simple spill, everything that doesn't fit in regs goes to mem

if not available_regs:
# ARM needs the stack 16-byte aligned
var_size += 16
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
available_regs.append("s0" if dtypes.is_float(out[1]) else "x12")
mem_vars[v.nm] = var_size
rtor[v.nm] = available_regs.pop()
temp_floats = ['s0', 's1', 's2']
temp_ints = ['x12', 'x13', 'x16']
temp_floats = ["s0", "s1", "s2"]
temp_ints = ["x12", "x13", "x16"]
for i, (uop, out, vin, arg) in enumerate(asm):
# Clear regs out of interval
for var, reg in list(rtor.items()):
available_regs = s_regs if reg[0] == 's' else x_regs
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
available_regs = s_regs if reg[0] == "s" else x_regs
if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]:
available_regs.append(rtor.pop(var))
# Assign registers to the variables using live ranges.
allocate_regs([out] + vin)
# Assign temp regs to vin and load them before direct use
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
for i, v in enumerate(
[v for v in vin if v.__class__ is not int and v.nm in mem_vars]
):
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == UOps.SPECIAL:
if arg.startswith('data'):
if arg.startswith("data"):
# data 8 to n into the stack
if int(arg[4:]) >= 8:
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
@ -91,28 +132,40 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"loop_{arg}:")
elif uop == UOps.CAST:
if arg == BinaryOps.CMPLT:
if rtor[out.nm][0] == 's':
mov_imm(0.0, 's0')
mov_imm(1.0, 's1')
if rtor[out.nm][0] == "s":
mov_imm(0.0, "s0")
mov_imm(1.0, "s1")
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
if rtor[out.nm][0] == 'x':
mov_imm(0, 'x14')
mov_imm(1, 'x15')
if rtor[out.nm][0] == "x":
mov_imm(0, "x14")
mov_imm(1, "x15")
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == UOps.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if len(vin) == 2 and vin[1].__class__ is int:
mov_imm(vin[1], "x15")
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
ins.append(
f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
)
elif arg == TernaryOps.WHERE:
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
ins.append(
f"fcmp {rtor[vin[0].nm]}, #0.0"
if rtor[vin[0].nm][0] == "s"
else f"cmp {rtor[vin[0].nm]}, #0"
)
ins.append(
f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne"
)
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
# NOTE: Not a real instruction, used to emulate an ext call in unicorn
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
if CI:
ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
else:
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
save_regs = [
k for k in rtor.keys() if k != out.nm and k not in mem_vars
]
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
# Save the registers before they are cleared by func call
for i, k in enumerate(save_regs, 1):
@ -128,27 +181,49 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
ins.append(f"add sp, sp, #{len(save_regs)*16}")
elif arg == BinaryOps.CMPLT:
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
ins.append(
f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
if not dtypes.is_float(vin[0][1])
else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}"
)
elif arg == BinaryOps.MOD:
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm]
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
ins.append(
f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
)
elif uop == UOps.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
# NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
reg_in = (
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
if arg[2] is not None
else rtor[out.nm]
)
mov_imm(arg[0], "x15")
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
ins.append(
f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]"
)
if arg[2] is not None:
ins.append(
f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}"
)
elif uop == UOps.STORE:
# NOTE: if casting is needed, load the var into the s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
reg_out = (
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
if arg[2] is not None
else rtor[vin[1].nm]
)
if arg[2] is not None:
ins.append(
f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}"
)
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
elif uop == UOps.COND_BRANCH:
@ -168,9 +243,31 @@ def specialize_to_arm64(fn_nm, asm):
if out is not None and out.nm in mem_vars:
ins.append(f"mov x15, {mem_vars[out.nm]}")
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
return "\n".join(
[
f"//varsize {var_size}",
".arch armv8-a",
".text",
f".global {get_name(fn_nm)}",
".p2align 2",
f"{get_name(fn_nm)}:",
"mov x17, sp",
]
+ [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]
+ ins
+ [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)]
+ ["ret", "\n"]
)
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
def uops_to_arm64_asm(
fn_nm: str, uops: List[UOp]
) -> Tuple[str, List[int], List[int], bool]:
lang = ARM64Language()
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
return (
specialize_to_arm64(fn_nm, lang.ins),
global_size[::-1],
local_size[::-1],
True,
)

View File

@ -6,50 +6,113 @@ from tinygrad.helpers import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
dtype_to_nvtype = {
dtypes.float32: "f32",
dtypes.float16: "f16",
dtypes.int64: "s64",
dtypes.int32: "s32",
dtypes.int8: "s8",
dtypes.bool: "pred",
dtypes.uint64: "u64",
dtypes.uint32: "u32",
dtypes.uint16: "u16",
dtypes.uint8: "u8",
"bits16": "b16",
dtypes.float64: "f64",
}
def float_to_hex(x):
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
def ptx_needs_cast(dest_dtype, src_dtype):
return (
dtypes.is_float(dest_dtype)
and dtypes.is_int(src_dtype)
or dtypes.is_int(dest_dtype)
and dtypes.is_float(src_dtype)
or (
dtypes.is_float(src_dtype)
and dtypes.is_float(dest_dtype)
and dest_dtype.itemsize != src_dtype.itemsize
)
)
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
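# A few concrete cases of the predicate above (illustration only, not part of this
# commit); a cast is needed when crossing float/int or when the float width changes:
assert ptx_needs_cast(dtypes.float32, dtypes.int32)  # int source into float dest
assert ptx_needs_cast(dtypes.int32, dtypes.float32)  # float source into int dest
assert ptx_needs_cast(dtypes.float32, dtypes.float16)  # float itemsizes differ
assert not ptx_needs_cast(dtypes.float32, dtypes.float32)  # same type, no cast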
def render_cast(ins, inp, out):
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
if inp.dtype == dtypes.bool and (
dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)
):
ins.append(
f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};"
)
elif out.dtype == dtypes.bool:
if inp.dtype == dtypes.bool:
ins.append(f"mov.pred {out}, {inp};")
else:
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
ins.append(
f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};"
)
else:
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
round_mod = (
".rzi"
if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype)
else ".rz"
if dtypes.is_float(out.dtype)
and (
dtypes.is_int(inp.dtype)
or dtypes.is_float(inp.dtype)
and inp.dtype.itemsize > out.dtype.itemsize
)
else ""
)
ins.append(
f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};"
)
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
class PTXLanguage(AssemblyLanguage):
supports_constant_folding: bool = True
def specialize_to_ptx(lang, function_name):
param_cnt = 0
ins = []
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
alu = {
BinaryOps.ADD: "add",
BinaryOps.SUB: "sub",
BinaryOps.MUL: "mul",
BinaryOps.DIV: "div",
BinaryOps.MAX: "max",
BinaryOps.MOD: "rem",
BinaryOps.CMPLT: "setp.lt",
UnaryOps.SQRT: "sqrt.approx",
UnaryOps.NOOP: "mov",
UnaryOps.NEG: "neg",
UnaryOps.SIN: "sin.approx",
UnaryOps.LOG2: "lg2.approx",
UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn",
TernaryOps.WHERE: "selp",
}
for uop, out, vin, arg in lang.ins:
if uop == UOps.ENDLOOP:
ins.append("bar.sync 0;")
elif uop == UOps.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == UOps.SPECIAL:
if arg.startswith('data'):
if arg.startswith("data"):
param_cnt += 1
ins.append(f"ld.param.u64 {out}, [{arg}];")
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
# ins.append(f"cvta.to.global.u64 {out}, {out};")
elif arg.startswith('gid'):
elif arg.startswith("gid"):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
elif arg.startswith("lid"):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == UOps.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
@ -60,31 +123,64 @@ def specialize_to_ptx(lang, function_name):
if vin[0].dtype == dtypes.bool:
reg = vin[0]
else:
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
reg = lang.newreg((vin[0], "bool"), dtypes.bool)
ins.append(
f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};"
)
vin = vin[1:] + [reg]
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
ins.append(
f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};"
)
elif uop == UOps.LOAD:
if arg.__class__ in (int, float):
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
ins.append(
f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};"
)
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
dt = (
("u16", dtypes.uint16)
if arg[2] == dtypes.bool == out.dtype
else ("u8", dtypes.uint8)
if arg[2] == dtypes.bool
else ("b16", dtypes.float16)
if arg[2] == dtypes.half
else (dtype_to_nvtype[arg[2]], arg[2])
)
reg = lang.newreg((out, dt[0]), dtype=dt[1])
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
ins.append(
f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
)
render_cast(ins, reg, out)
else:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
ins.append(
f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
)
elif uop == UOps.STORE:
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
if (
ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype)
or arg[2] == dtypes.bool
):
if arg[2] == dtypes.bool != vin[1].dtype:
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool)
render_cast(ins, vin[1], prereg)
else: prereg = vin[1]
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
render_cast(ins, prereg, reg)
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
else:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
prereg = vin[1]
reg = lang.newreg(
(prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
dtype=dtypes.uint16
if arg[2] == dtypes.bool
else dtypes.float
if arg[2] is None
else arg[2],
)
render_cast(ins, prereg, reg)
ins.append(
f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};"
)
else:
ins.append(
f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};"
)
elif uop == UOps.CAST:
render_cast(ins, vin[0], out)
elif uop == UOps.LABEL:
@ -92,14 +188,29 @@ def specialize_to_ptx(lang, function_name):
elif uop == UOps.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
ins_prefix = [
".version 7.8",
".target " + arch(),
".address_size 64",
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{",
]
for arg in [
(dtype, lang.type_to_letter(dtype), c) for dtype, c in lang.cnts.items()
]:
ins_prefix.append(
f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",
)
ins = ins_prefix + ins
ins += ["ret;", "}"]
return '\n'.join(ins)
return "\n".join(ins)
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
lang = PTXLanguage()
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
return (
specialize_to_ptx(lang, function_name),
global_size[::-1],
local_size[::-1],
True,
)

View File

@ -8,6 +8,7 @@ from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
# ugh, is this really needed?
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
boilerplate_start = """
@ -24,6 +25,7 @@ code_start = """.end_amdhsa_kernel
code:
"""
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
# RDNA3 is actually a SIMD machine!
@ -36,107 +38,202 @@ class RDNACodegen(AssemblyCodegen):
def specialize(self, asm) -> Tuple[str, str]:
args = []
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
for i, b in enumerate(self.bufs):
args.append(
{
".address_space": "global",
".name": f"buf_{i}",
".offset": i * 8,
".size": 8,
".type_name": b.dtype.name + "*",
".value_kind": "global_buffer",
}
)
ins = []
v_cnt = 3 # v[0:2] is local_xyz
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
BinaryOps.CMPLT: "cmp_lt"}
dtype_to_rdnatype = {
dtypes.float32: "f32",
dtypes.int64: "i64",
dtypes.int32: "i32",
dtypes.uint64: "u64",
dtypes.bool: "i32",
}
alu = {
BinaryOps.ADD: "add",
BinaryOps.SUB: "sub",
BinaryOps.MUL: "mul",
TernaryOps.MULACC: "fma",
BinaryOps.MAX: "max",
UnaryOps.RECIP: "rcp",
UnaryOps.NOOP: "mov",
UnaryOps.SIN: "sin",
UnaryOps.LOG2: "log",
UnaryOps.EXP2: "exp",
BinaryOps.CMPLT: "cmp_lt",
}
pend_regs: Set[Register] = set()
rtor: Dict[Register, str] = {}
def reg_in(x):
nonlocal pend_regs
# print("reg_in", x, rtor[x], pend_regs)
if x in pend_regs:
# print("clear")
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)")
pend_regs.clear()
return rtor[x]
def reg_out(x):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
if arg[0][0] in [
dtypes.uint32,
dtypes.uint64,
dtypes.int64,
dtypes.int32,
dtypes.float32,
dtypes.float.vec(4),
]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
align = int(arg[0][0].itemsize / 4)
if arg[0][1]:
s_cnt += s_cnt % align
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
reg_name = (
f"s[{s_cnt}:{s_cnt + align - 1}]"
if align > 1
else f"s{s_cnt}"
)
s_cnt += align
else:
v_cnt += v_cnt % align
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
reg_name = (
f"v[{v_cnt}:{v_cnt + align - 1}]"
if align > 1
else f"v{v_cnt}"
)
v_cnt += align
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
if arg[0][0] == dtypes.float.vec(4):
for off in range(4):
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
reg_name = (
f"s{s_cnt-align+off}"
if arg[0][1]
else f"v{v_cnt-align+off}"
)
rtor[
Register(
f"%{arg[1]}{i}", dtypes.float, False, off=off
)
] = reg_name
elif arg[0][0] == dtypes.bool:
for i in range(arg[2]):
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
reg_name = (
"scc" if arg[0][1] else "vcc_lo"
) # `_lo` suffix since we're running wavefront_size=32
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
else:
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
raise NotImplementedError(
"DEFINE_REGISTER not implemented for arg: ", arg
)
elif uop == UOps.SPECIAL:
if arg.startswith('buf'):
if arg.startswith("buf"):
i = int(arg[3:])
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
ins.append(f"s_load_b64 {reg_out(out)}, s[0:1], {i*8}")
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif arg.startswith('gid'):
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
for r in out.subregs():
pend_regs.add(r)
elif arg.startswith("gid"):
ins.append(f"v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}")
# the docs lied, this is actually y
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
if int(arg[3]) == 2:
ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
if int(arg[3]) == 1:
ins.append("v_bfe_u32 v1, v0, 10, 10")
elif int(arg[3]) == 0:
ins.append("v_and_b32_e32 v0, 0x3ff, v0")
# get local size
offset = len(args) * 8
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
args.append(
{
".offset": offset,
".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}",
".size": 8,
}
)
ins.append(f"s_load_b32 s{2+int(arg[3])}, s[0:1], {offset}")
ins.append("s_waitcnt vmcnt(0) lgkmcnt(0)")
pend_regs.clear()
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
ins.append(
f"v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}"
)
ins.append(
f"v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}"
)
elif uop == UOps.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if arg == float("inf"):
arg = "0x7f800000"
elif arg == float("-inf"):
arg = "0xff800000"
if out.dtype == dtypes.float.vec(4):
for off in range(4):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
ins.append(
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}"
)
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
ins.append(
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}"
)
elif uop == UOps.ALU:
if arg in [BinaryOps.CMPLT]:
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
ins.append(
f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
)
else:
alu_arg = alu[arg]
if arg == TernaryOps.MULACC and out == vin[2]:
alu_arg = "fmac"
vin = vin[0:2]
if out.dtype == dtypes.float.vec(4):
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
for rr in zip(
*[
x.subregs()
if x.dtype == dtypes.float.vec(4)
else [x, x, x, x]
for x in [out] + vin
]
):
ins.append(
f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}"
)
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
ins.append(
f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
)
elif uop == UOps.LOAD:
if out.scalar:
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
ins.append(
f"s_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}"
)
else:
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
ins.append(
f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
)
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
for r in out.subregs():
pend_regs.add(r)
elif uop == UOps.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
ins.append(
f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
)
elif uop == UOps.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
@ -144,29 +241,40 @@ class RDNACodegen(AssemblyCodegen):
elif uop == UOps.CAST:
if vin[0].dtype == dtypes.bool:
if out.dtype == dtypes.float32:
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
ins.append(
f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}"
)
else:
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
else:
raise NotImplementedError(uop)
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
ins += ["s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", "s_code_end"]
# dual alu group
seen = set()
new_ins = []
for i, tins in enumerate(ins):
if tins in seen: continue
if tins in seen:
continue
if tins.startswith("v_fmac_f32"):
for gins in reversed(ins[i + 1 :]):
if gins in seen: continue
if gins in seen:
continue
if gins.startswith("v_fmac_f32"):
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
if r0[0]%2 == r1[0]%2: continue
if r0[1]%2 == r1[1]%2: continue
if r0[2]%2 == r1[2]%2: continue
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
r0 = [int(x[1:].strip(",")) for x in tins.split(" ")[1:]]
r1 = [int(x[1:].strip(",")) for x in gins.split(" ")[1:]]
if r0[0] % 2 == r1[0] % 2:
continue
if r0[1] % 2 == r1[1] % 2:
continue
if r0[2] % 2 == r1[2] % 2:
continue
new_ins.append(
tins.replace("v_", "v_dual_")
+ " :: "
+ gins.replace("v_", "v_dual_")
)
seen.add(tins)
seen.add(gins)
break
@ -174,30 +282,102 @@ class RDNACodegen(AssemblyCodegen):
new_ins.append(tins)
ins = new_ins
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
return "code", self.assemble(args, ins, v_cnt, s_cnt)
def assemble(self, args, ins, v_cnt, s_cnt):
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
'.amdhsa_next_free_sgpr': s_cnt,
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
kernel_desc = {
".amdhsa_group_segment_fixed_size": 0,
".amdhsa_private_segment_fixed_size": 0,
".amdhsa_kernarg_size": 0,
".amdhsa_next_free_vgpr": v_cnt, # this matters!
".amdhsa_reserve_vcc": 0,
".amdhsa_reserve_xnack_mask": 0,
".amdhsa_next_free_sgpr": s_cnt,
".amdhsa_float_round_mode_32": 0,
".amdhsa_float_round_mode_16_64": 0,
".amdhsa_float_denorm_mode_32": 3,
".amdhsa_float_denorm_mode_16_64": 3,
".amdhsa_dx10_clamp": 1,
".amdhsa_ieee_mode": 1,
".amdhsa_fp16_overflow": 0,
".amdhsa_workgroup_processor_mode": 1,
".amdhsa_memory_ordered": 1,
".amdhsa_forward_progress": 0,
".amdhsa_enable_private_segment": 0,
".amdhsa_system_sgpr_workgroup_id_x": 1,
".amdhsa_system_sgpr_workgroup_id_y": 1,
".amdhsa_system_sgpr_workgroup_id_z": 1,
".amdhsa_system_sgpr_workgroup_info": 0,
".amdhsa_system_vgpr_workitem_id": 2, # is amdhsa_system_vgpr_workitem_id real?
".amdhsa_exception_fp_ieee_invalid_op": 0,
".amdhsa_exception_fp_denorm_src": 0,
".amdhsa_exception_fp_ieee_div_zero": 0,
".amdhsa_exception_fp_ieee_overflow": 0,
".amdhsa_exception_fp_ieee_underflow": 0,
".amdhsa_exception_fp_ieee_inexact": 0,
".amdhsa_exception_int_div_zero": 0,
".amdhsa_user_sgpr_dispatch_ptr": 0,
".amdhsa_user_sgpr_queue_ptr": 0,
".amdhsa_user_sgpr_kernarg_segment_ptr": 1,
".amdhsa_user_sgpr_dispatch_id": 0,
".amdhsa_user_sgpr_private_segment_size": 0,
".amdhsa_wavefront_size32": 1,
".amdhsa_uses_dynamic_stack": 0,
}
metadata = {'amdhsa.kernels': [{'.args': args,
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
'.wavefront_size': 32}],
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
metadata = {
"amdhsa.kernels": [
{
".args": args,
".group_segment_fixed_size": 0,
".kernarg_segment_align": 8,
".kernarg_segment_size": args[-1][".offset"] + args[-1][".size"],
".language": "OpenCL C",
".language_version": [1, 2],
".max_flat_workgroup_size": 256,
".name": "code",
".private_segment_fixed_size": 0,
".sgpr_count": s_cnt,
".sgpr_spill_count": 0,
".symbol": "code.kd",
".uses_dynamic_stack": False,
".vgpr_count": v_cnt,
".vgpr_spill_count": 0,
".wavefront_size": 32,
}
],
"amdhsa.target": "amdgcn-amd-amdhsa--gfx1100",
"amdhsa.version": [1, 2],
}
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
code = (
boilerplate_start
+ "\n"
+ "\n".join("%s %d" % x for x in kernel_desc.items())
+ "\n"
+ code_start
+ "\n".join(ins)
+ "\n.amdgpu_metadata\n"
+ yaml.dump(metadata)
+ ".end_amdgpu_metadata"
)
obj = early_exec(
(
[
ROCM_LLVM_PATH / "llvm-mc",
"--arch=amdgcn",
"--mcpu=gfx1100",
"--triple=amdgcn-amd-amdhsa",
"--filetype=obj",
"-",
],
code.encode("utf-8"),
)
)
asm = early_exec(
(
[ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"],
obj,
)
)
return asm

View File

@ -4,7 +4,9 @@ from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
if __name__ == "__main__":
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
prg = CUDAProgram("test", """
prg = CUDAProgram(
"test",
"""
.version 7.8
.target sm_86
.address_size 64
@ -17,7 +19,8 @@ if __name__ == "__main__":
mov.u32 %r1, 0x40000000; // 2.0 in float
st.global.u32 [%rd2], %r1;
ret;
}""", binary=True)
}""",
binary=True,
)
prg([1], [1], test)
print(test.toCPU())

View File

@ -3,6 +3,7 @@ import pathlib
from hexdump import hexdump
from tinygrad.helpers import colored
from extra.helpers import enable_early_exec
early_exec = enable_early_exec()
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
@ -37,29 +38,49 @@ for j in range(1):
c = (y * KX + x) * 8
a = (KY * KX * 8) + y * 8
b = (KY * KX * 8) + (KY * 8) + x * 8
gen.append(f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]")
gen.append(
f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]"
)
FLOPS += 16 * 8 * 2
else:
for i in range(0, MAX_REG, 6):
if DUAL_ALU:
if F32:
gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
gen.append(
f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}"
)
FLOPS += 4
else:
gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
gen.append(
f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}"
)
FLOPS += 8
else:
assert F32
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
code = code.replace("// FLOPS", '\n'.join(gen))
code = code.replace("// FLOPS", "\n".join(gen))
print(code)
# fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa'
object = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object))
object = early_exec(
(
[
ROCM_LLVM_PATH / "llvm-mc",
"--arch=amdgcn",
"--mcpu=gfx1100",
"--triple=amdgcn-amd-amdhsa",
"--filetype=obj",
"-",
],
code.encode("utf-8"),
)
)
asm = early_exec(
([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object)
)
with open("/tmp/cc2.o", "wb") as f:
f.write(object)

View File

@ -2,12 +2,14 @@ import numpy as np
from PIL import Image
from pathlib import Path
import sys
cwd = Path.cwd()
sys.path.append(cwd.as_posix())
sys.path.append((cwd / 'test').as_posix())
sys.path.append((cwd / "test").as_posix())
from extra.datasets import fetch_mnist
from tqdm import trange
def augment_img(X, rotate=10, px=3):
Xaug = np.zeros_like(X)
for i in trange(len(X)):
@ -20,8 +22,10 @@ def augment_img(X, rotate=10, px=3):
Xaug[i] = im
return Xaug
if __name__ == "__main__":
import matplotlib.pyplot as plt
X_train, Y_train, X_test, Y_test = fetch_mnist()
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
@ -29,14 +33,18 @@ if __name__ == "__main__":
fig, a = plt.subplots(2, len(X))
Xaug = augment_img(X)
for i in range(len(X)):
a[0][i].imshow(X[i], cmap='gray')
a[1][i].imshow(Xaug[i],cmap='gray')
a[0][i].axis('off')
a[1][i].axis('off')
a[0][i].imshow(X[i], cmap="gray")
a[1][i].imshow(Xaug[i], cmap="gray")
a[0][i].axis("off")
a[1][i].axis("off")
plt.show()
# create some nice gifs for doc?!
for i in range(10):
im = Image.fromarray(X_train[7353 + i])
im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
im_aug = [
Image.fromarray(x) for x in augment_img(np.array([X_train[7353 + i]] * 100))
]
im.save(
f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0
)

View File

@ -3,30 +3,53 @@ import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, fetch
def fetch_mnist(tensors=False):
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
X_train = (
parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:]
.reshape((-1, 28 * 28))
.astype(np.float32)
)
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
X_test = (
parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:]
.reshape((-1, 28 * 28))
.astype(np.float32)
)
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
else: return X_train, Y_train, X_test, Y_test
if tensors:
return (
Tensor(X_train).reshape(-1, 1, 28, 28),
Tensor(Y_train),
Tensor(X_test).reshape(-1, 1, 28, 28),
Tensor(Y_test),
)
else:
return X_train, Y_train, X_test, Y_test
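# Usage sketch (illustration only, not part of this commit):
#   X_train, Y_train, X_test, Y_test = fetch_mnist()
#   X_train.shape == (60000, 784) and X_test.shape == (10000, 784)   # float32 pixels, 0..255
#   Y_train.shape == (60000,) and Y_test.shape == (10000,)           # uint8 labels, 0..9
# With tensors=True the images come back reshaped to (N, 1, 28, 28) Tensors instead.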
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
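# These are per-channel RGB mean/std statistics for the CIFAR-10 training set, with
# pixels scaled to 0..1. A minimal sketch of how such statistics are commonly applied
# (this diff does not show how callers actually use them):
#   mean = np.array(cifar_mean).reshape(1, 3, 1, 1)
#   std = np.array(cifar_std).reshape(1, 3, 1, 1)
#   X_norm = (X.reshape(-1, 3, 32, 32) / 255.0 - mean) / std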
def fetch_cifar():
X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)
X_train = Tensor.empty(
50000, 3 * 32 * 32, device=f"disk:/tmp/cifar_train_x", dtype=dtypes.uint8
)
Y_train = Tensor.empty(50000, device=f"disk:/tmp/cifar_train_y", dtype=dtypes.int64)
X_test = Tensor.empty(
10000, 3 * 32 * 32, device=f"disk:/tmp/cifar_test_x", dtype=dtypes.uint8
)
Y_test = Tensor.empty(10000, device=f"disk:/tmp/cifar_test_y", dtype=dtypes.int64)
if not os.path.isfile("/tmp/cifar_extracted"):
def _load_disk_tensor(X, Y, db_list):
idx = 0
for db in db_list:
x, y = db[b'data'], np.array(db[b'labels'])
x, y = db[b"data"], np.array(db[b"labels"])
assert x.shape[0] == y.shape[0]
X[idx : idx + x.shape[0]].assign(x)
Y[idx : idx + x.shape[0]].assign(y)
@ -34,10 +57,28 @@ def fetch_cifar():
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
print("downloading and extracting CIFAR...")
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
tt = tarfile.open(fn, mode='r:gz')
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
fn = fetch("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
tt = tarfile.open(fn, mode="r:gz")
_load_disk_tensor(
X_train,
Y_train,
[
pickle.load(
tt.extractfile(f"cifar-10-batches-py/data_batch_{i}"),
encoding="bytes",
)
for i in range(1, 6)
],
)
_load_disk_tensor(
X_test,
Y_test,
[
pickle.load(
tt.extractfile("cifar-10-batches-py/test_batch"), encoding="bytes"
)
],
)
open("/tmp/cifar_extracted", "wb").close()
return X_train, Y_train, X_test, Y_test

View File

@ -15,32 +15,36 @@ frPyObjects = _mask.frPyObjects
BASEDIR = pathlib.Path(__file__).parent / "COCO"
BASEDIR.mkdir(exist_ok=True)
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
def create_dict(key_row, val_row, rows):
return {row[key_row]: row[val_row] for row in rows}
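# Worked example (illustration only, not part of this commit):
assert create_dict(
    "file_name", "id", [{"file_name": "a.jpg", "id": 1}, {"file_name": "b.jpg", "id": 2}]
) == {"a.jpg": 1, "b.jpg": 2}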
if not pathlib.Path(BASEDIR/'val2017').is_dir():
fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
with zipfile.ZipFile(fn, 'r') as zip_ref:
if not pathlib.Path(BASEDIR / "val2017").is_dir():
fn = fetch("http://images.cocodataset.org/zips/val2017.zip")
with zipfile.ZipFile(fn, "r") as zip_ref:
zip_ref.extractall(BASEDIR)
fn.unlink()
if not pathlib.Path(BASEDIR/'annotations').is_dir():
fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
with zipfile.ZipFile(fn, 'r') as zip_ref:
if not pathlib.Path(BASEDIR / "annotations").is_dir():
fn = fetch("http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
with zipfile.ZipFile(fn, "r") as zip_ref:
zip_ref.extractall(BASEDIR)
fn.unlink()
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
with open(BASEDIR / "annotations/instances_val2017.json", "r") as f:
annotations_raw = json.loads(f.read())
images = annotations_raw['images']
categories = annotations_raw['categories']
annotations = annotations_raw['annotations']
file_name_to_id = create_dict('file_name', 'id', images)
id_to_width = create_dict('id', 'width', images)
id_to_height = create_dict('id', 'height', images)
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
images = annotations_raw["images"]
categories = annotations_raw["categories"]
annotations = annotations_raw["annotations"]
file_name_to_id = create_dict("file_name", "id", images)
id_to_width = create_dict("id", "width", images)
id_to_height = create_dict("id", "height", images)
json_category_id_to_contiguous_id = {v["id"]: i + 1 for i, v in enumerate(categories)}
contiguous_category_id_to_json_id = {
v: k for k, v in json_category_id_to_contiguous_id.items()
}
def encode(bimask):
@ -48,7 +52,8 @@ def encode(bimask):
return _mask.encode(bimask)
elif len(bimask.shape) == 2:
h, w = bimask.shape
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
return _mask.encode(bimask.reshape((h, w, 1), order="F"))[0]
def decode(rleObjs):
if type(rleObjs) == list:
@ -56,12 +61,14 @@ def decode(rleObjs):
else:
return _mask.decode([rleObjs])[:, :, 0]
def area(rleObjs):
if type(rleObjs) == list:
return _mask.area(rleObjs)
else:
return _mask.area([rleObjs])[0]
def toBbox(rleObjs):
if type(rleObjs) == list:
return _mask.toBbox(rleObjs)
@ -102,8 +109,10 @@ def convert_prediction_to_coco_bbox(file_name, prediction):
print(file_name, e)
return coco_results
masker = Masker(threshold=0.5, padding=1)
def convert_prediction_to_coco_mask(file_name, prediction):
coco_results = []
try:
@ -122,8 +131,7 @@ def convert_prediction_to_coco_mask(file_name, prediction):
masks = masker([masks], [prediction])[0].numpy()
rles = [
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
for mask in masks
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
@ -146,20 +154,22 @@ def convert_prediction_to_coco_mask(file_name, prediction):
return coco_results
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
path = pathlib.Path(json_result_file)
if rm and path.exists(): path.unlink()
if rm and path.exists():
path.unlink()
with open(path, "a") as f:
for s in coco_results:
f.write(json.dumps(s))
f.write('\n')
f.write("\n")
def remove_dup(l):
seen = set()
seen_add = seen.add
return [x for x in l if not (x in seen or seen_add(x))]
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
@ -177,23 +187,28 @@ def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
for line in f:
coco_results.append(json.loads(line))
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
coco_gt = COCO(str(BASEDIR / "annotations/instances_val2017.json"))
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
unique_list = [json.loads(s) for s in set_of_json]
with open(f'{json_result_file}.flattend', "w") as f:
with open(f"{json_result_file}.flattend", "w") as f:
json.dump(unique_list, f)
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
coco_dt = coco_gt.loadRes(str(f"{json_result_file}.flattend"))
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval
def iterate(files, bs=1):
batch = []
for file in files:
batch.append(file)
if len(batch) >= bs: yield batch; batch = []
if len(batch) > 0: yield batch; batch = []
if len(batch) >= bs:
yield batch
batch = []
if len(batch) > 0:
yield batch
batch = []
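# Illustration of the batching behaviour above (not part of this commit): a short
# final batch is still yielded.
assert list(iterate([10, 11, 12, 13, 14], bs=2)) == [[10, 11], [12, 13], [14]]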

View File

@ -9,36 +9,45 @@ BASEDIR = pathlib.Path(__file__).parent / "imagenet"
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
cir = {v[0]: int(k) for k, v in ci.items()}
@functools.lru_cache(None)
def get_train_files():
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
return [(BASEDIR / "train" / x) for x in train_files]
@functools.lru_cache(None)
def get_val_files():
val_files = glob.glob(str(BASEDIR / "val/*/*"))
return val_files
# rrc = transforms.RandomResizedCrop(224)
import torchvision.transforms.functional as F
def image_load(fn):
img = Image.open(fn).convert('RGB')
img = Image.open(fn).convert("RGB")
img = F.resize(img, 256, Image.BILINEAR)
img = F.center_crop(img, 224)
ret = np.array(img)
return ret
def iterate(bs=32, val=True, shuffle=True):
files = get_val_files() if val else get_train_files()
order = list(range(0, len(files)))
if shuffle: random.shuffle(order)
if shuffle:
random.shuffle(order)
from multiprocessing import Pool
p = Pool(16)
for i in range(0, len(files), bs):
X = p.map(image_load, [files[i] for i in order[i : i + bs]])
Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]]
yield (np.array(X), np.array(Y))
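# Usage sketch (illustration only, not part of this commit):
#   for X, Y in iterate(bs=32, val=True):
#       # X: uint8 array of shape (bs, 224, 224, 3) after resize(256) + center_crop(224)
#       # Y: int labels of shape (bs,), looked up via cir from the directory names
#       break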
def fetch_batch(bs, val=False):
files = get_val_files() if val else get_train_files()
samp = np.random.randint(0, len(files), size=(bs))
@ -47,7 +56,7 @@ def fetch_batch(bs, val=False):
Y = [cir[x.split("/")[0]] for x in files]
return np.array(X), np.array(Y)
if __name__ == "__main__":
X, Y = fetch_batch(64)
print(X.shape, Y)
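image_load returns raw 224x224x3 uint8 arrays (resize to 256, center-crop to 224); most models want them normalized and laid out NCHW. A minimal sketch of that step, assuming the standard ImageNet mean/std (an assumption, nothing in this file sets them):

import numpy as np

def to_model_input(batch):   # batch: (N, 224, 224, 3) uint8 from image_load
    x = batch.astype(np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)   # assumed ImageNet statistics
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    return ((x - mean) / std).transpose(0, 3, 1, 2)            # NHWC -> NCHW

X, Y = fetch_batch(8)
print(to_model_input(X).shape)   # (8, 3, 224, 224)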


@ -4,17 +4,26 @@ from pathlib import Path
from tqdm import tqdm
import tarfile, os
def imagenet_extract(file, path, small=False):
with tarfile.open(name=file) as tar:
if small: # Show progressbar only for big files
for member in tar.getmembers(): tar.extract(path=path, member=member)
for member in tar.getmembers():
tar.extract(path=path, member=member)
else:
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())):
tar.extract(path=path, member=member)
tar.close()
def imagenet_prepare_val():
# Read in the labels file
with open(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt", 'r') as f:
with open(
Path(__file__).parent
/ "imagenet"
/ "imagenet_2012_validation_synset_labels.txt",
"r",
) as f:
labels = f.read().splitlines()
f.close()
# Get a list of images
@ -23,8 +32,16 @@ def imagenet_prepare_val():
# Create folders and move files into those
for co, dir in enumerate(labels):
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
os.replace(Path(__file__).parent / "imagenet" / "val" / images[co], Path(__file__).parent / "imagenet" / "val" / dir / images[co])
os.remove(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt")
os.replace(
Path(__file__).parent / "imagenet" / "val" / images[co],
Path(__file__).parent / "imagenet" / "val" / dir / images[co],
)
os.remove(
Path(__file__).parent
/ "imagenet"
/ "imagenet_2012_validation_synset_labels.txt"
)
def imagenet_prepare_train():
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
@ -32,20 +49,47 @@ def imagenet_prepare_train():
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
images[co] = tarf[:-4] # remove .tar from extracted tar files
os.makedirs(Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True)
imagenet_extract(Path(__file__).parent / "imagenet" / "train" / tarf, Path(__file__).parent/ "imagenet" / "train" / images[co], small=True)
os.makedirs(
Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True
)
imagenet_extract(
Path(__file__).parent / "imagenet" / "train" / tarf,
Path(__file__).parent / "imagenet" / "train" / images[co],
small=True,
)
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
if __name__ == "__main__":
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json")
fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt")
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val")
fetch(
"https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json",
Path(__file__).parent / "imagenet" / "imagenet_class_index.json",
)
fetch(
"https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt",
Path(__file__).parent
/ "imagenet"
/ "imagenet_2012_validation_synset_labels.txt",
)
fetch(
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar",
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
) # 7GB
imagenet_extract(
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
Path(__file__).parent / "imagenet" / "val",
)
imagenet_prepare_val()
if os.getenv('IMGNET_TRAIN', None) is not None:
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB!
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train")
if os.getenv("IMGNET_TRAIN", None) is not None:
fetch(
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar",
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
) # 138GB!
imagenet_extract(
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
Path(__file__).parent / "imagenet" / "train",
)
imagenet_prepare_train()


@ -23,41 +23,70 @@ mv kits extra/datasets
```
"""
@functools.lru_cache(None)
def get_val_files():
data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
data = fetch(
"https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt"
).read_text()
return sorted(
[x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")]
)
def load_pair(file_path):
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(
file_path / "segmentation.nii.gz"
)
image_spacings = image.header["pixdim"][1:4].tolist()
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(
np.uint8
)
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
return image, label, image_spacings
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
if image_spacings != target_spacing:
spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
spc_arr, targ_arr, shp_arr = (
np.array(image_spacings),
np.array(target_spacing),
np.array(image.shape[1:]),
)
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
image = F.interpolate(
torch.from_numpy(np.expand_dims(image, axis=0)),
size=new_shape,
mode="trilinear",
align_corners=True,
)
label = F.interpolate(
torch.from_numpy(np.expand_dims(label, axis=0)),
size=new_shape,
mode="nearest",
)
image = np.squeeze(image.numpy(), axis=0)
label = np.squeeze(label.numpy(), axis=0)
return image, label
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
image = np.clip(image, min_clip, max_clip)
image = (image - mean) / std
return image
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
current_shape = image.shape[1:]
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
paddings = [(0, 0)] + [
(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)
]
image = np.pad(image, paddings, mode="edge")
label = np.pad(label, paddings, mode="edge")
return image, label
def preprocess(file_path):
image, label, image_spacings = load_pair(file_path)
image, label = resample3d(image, label, image_spacings)
@ -65,16 +94,20 @@ def preprocess(file_path):
image, label = pad_to_min_shape(image, label)
return image, label
def iterate(val=True, shuffle=False):
if not val: raise NotImplementedError
if not val:
raise NotImplementedError
files = get_val_files()
order = list(range(0, len(files)))
if shuffle: random.shuffle(order)
if shuffle:
random.shuffle(order)
for file in files:
X, Y = preprocess(file)
X = np.expand_dims(X, axis=0)
yield (X, Y)
def gaussian_kernel(n, std):
gaussian_1d = signal.gaussian(n, std)
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
@ -84,14 +117,44 @@ def gaussian_kernel(n, std):
gaussian_3d /= gaussian_3d.max()
return gaussian_3d
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
def pad_input(
volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3
):
bounds = [
(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)
]
bounds = [
bounds[i]
if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i]
else bounds[i] + strides[i]
for i in range(dim)
]
paddings = [
bounds[2] // 2,
bounds[2] - bounds[2] // 2,
bounds[1] // 2,
bounds[1] - bounds[1] // 2,
bounds[0] // 2,
bounds[0] - bounds[0] // 2,
0,
0,
0,
0,
]
return (
F.pad(
torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val
).numpy(),
paddings,
)
def sliding_window_inference(
model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5
):
from tinygrad.jit import TinyJit
mdl_run = TinyJit(lambda x: model(x).realize())
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
@ -119,13 +182,40 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
for i in range(0, strides[0] * size[0], strides[0]):
for j in range(0, strides[1] * size[1], strides[1]):
for k in range(0, strides[2] * size[2], strides[2]):
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
out = mdl_run(
Tensor(
inputs[
...,
i : roi_shape[0] + i,
j : roi_shape[1] + j,
k : roi_shape[2] + k,
]
)
).numpy()
result[
...,
i : roi_shape[0] + i,
j : roi_shape[1] + j,
k : roi_shape[2] + k,
] += (
out * norm_patch
)
norm_map[
...,
i : roi_shape[0] + i,
j : roi_shape[1] + j,
k : roi_shape[2] + k,
] += norm_patch
result /= norm_map
result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
result = result[
...,
paddings[4] : image_shape[0] + paddings[4],
paddings[2] : image_shape[1] + paddings[2],
paddings[0] : image_shape[2] + paddings[0],
]
return result, labels
if __name__ == "__main__":
for X, Y in iterate():
print(X.shape, Y.shape)
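sliding_window_inference above is the standard overlap-and-renormalize scheme: every ROI-sized patch prediction is weighted by the Gaussian norm_patch and accumulated into result, the weights are accumulated into norm_map, and the two are divided at the end. A minimal 1-D numpy sketch of the same bookkeeping, with a flat weight and an identity "model":

import numpy as np

def sliding_window_1d(model, x, roi=8, overlap=0.5):
    stride = int(roi * (1 - overlap))
    result = np.zeros_like(x, dtype=np.float32)
    norm_map = np.zeros_like(x, dtype=np.float32)
    norm_patch = np.ones(roi, dtype=np.float32)   # the real code uses gaussian_kernel here
    for i in range(0, x.shape[0] - roi + 1, stride):
        result[i:i + roi] += model(x[i:i + roi]) * norm_patch
        norm_map[i:i + roi] += norm_patch
    return result / np.maximum(norm_map, 1e-8)

x = np.arange(16, dtype=np.float32)
print(sliding_window_1d(lambda patch: patch, x))   # recovers x once the overlaps are divided out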


@ -19,17 +19,30 @@ BASEDIR = pathlib.Path(__file__).parent / "librispeech"
with open(BASEDIR / "dev-clean-wav.json") as f:
ci = json.load(f)
FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
FILTER_BANK = np.expand_dims(
librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0
)
WINDOW = librosa.filters.get_window("hann", 320)
def feature_extract(x, x_lens):
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
# pre-emphasis
x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1)
x = np.concatenate(
(np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1
)
# stft
x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")
x = librosa.stft(
x,
n_fft=512,
window=WINDOW,
hop_length=160,
win_length=320,
center=True,
pad_mode="reflect",
)
x = np.stack((x.real, x.imag), axis=-1)
# power spectrum
@ -56,18 +69,24 @@ def feature_extract(x, x_lens):
features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1)
features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1)
features_std += 1e-5
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2)
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(
features_std, 2
)
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
def load_wav(file):
sample = soundfile.read(file)[0].astype(np.float32)
return sample, sample.shape[0]
def iterate(bs=1, start=0):
print(f"there are {len(ci)} samples in the dataset")
for i in range(start, len(ci), bs):
samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]])
samples, sample_lens = zip(
*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]]
)
samples = list(samples)
# pad to same length
max_len = max(sample_lens)
@ -75,7 +94,10 @@ def iterate(bs=1, start=0):
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
samples, sample_lens = np.array(samples), np.array(sample_lens)
yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]])
yield feature_extract(samples, sample_lens), np.array(
[v["transcript"] for v in ci[i : i + bs]]
)
if __name__ == "__main__":
X, Y = next(iterate())
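The length formula in feature_extract encodes the 160-sample STFT hop (10 ms at 16 kHz) plus a further reduction by 3, presumably frame stacking in the part of the function this hunk elides. A quick check of the arithmetic with illustrative lengths:

import numpy as np

x_lens = np.array([16000, 23500])                     # 1.00 s and ~1.47 s at 16 kHz
print(np.ceil((x_lens / 160) / 3).astype(np.int32))   # [34 49] output frames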


@ -12,133 +12,441 @@ import concurrent.futures
BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf"
BUCKET_NAME = "open-images-dataset"
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
'Zebra', 'Zucchini',
BBOX_ANNOTATIONS_URL = (
"https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
)
MAP_CLASSES_URL = (
"https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
)
MLPERF_CLASSES = [
"Airplane",
"Antelope",
"Apple",
"Backpack",
"Balloon",
"Banana",
"Barrel",
"Baseball bat",
"Baseball glove",
"Bee",
"Beer",
"Bench",
"Bicycle",
"Bicycle helmet",
"Bicycle wheel",
"Billboard",
"Book",
"Bookcase",
"Boot",
"Bottle",
"Bowl",
"Bowling equipment",
"Box",
"Boy",
"Brassiere",
"Bread",
"Broccoli",
"Bronze sculpture",
"Bull",
"Bus",
"Bust",
"Butterfly",
"Cabinetry",
"Cake",
"Camel",
"Camera",
"Candle",
"Candy",
"Cannon",
"Canoe",
"Carrot",
"Cart",
"Castle",
"Cat",
"Cattle",
"Cello",
"Chair",
"Cheese",
"Chest of drawers",
"Chicken",
"Christmas tree",
"Coat",
"Cocktail",
"Coffee",
"Coffee cup",
"Coffee table",
"Coin",
"Common sunflower",
"Computer keyboard",
"Computer monitor",
"Convenience store",
"Cookie",
"Countertop",
"Cowboy hat",
"Crab",
"Crocodile",
"Cucumber",
"Cupboard",
"Curtain",
"Deer",
"Desk",
"Dinosaur",
"Dog",
"Doll",
"Dolphin",
"Door",
"Dragonfly",
"Drawer",
"Dress",
"Drum",
"Duck",
"Eagle",
"Earrings",
"Egg (Food)",
"Elephant",
"Falcon",
"Fedora",
"Flag",
"Flowerpot",
"Football",
"Football helmet",
"Fork",
"Fountain",
"French fries",
"French horn",
"Frog",
"Giraffe",
"Girl",
"Glasses",
"Goat",
"Goggles",
"Goldfish",
"Gondola",
"Goose",
"Grape",
"Grapefruit",
"Guitar",
"Hamburger",
"Handbag",
"Harbor seal",
"Headphones",
"Helicopter",
"High heels",
"Hiking equipment",
"Horse",
"House",
"Houseplant",
"Human arm",
"Human beard",
"Human body",
"Human ear",
"Human eye",
"Human face",
"Human foot",
"Human hair",
"Human hand",
"Human head",
"Human leg",
"Human mouth",
"Human nose",
"Ice cream",
"Jacket",
"Jeans",
"Jellyfish",
"Juice",
"Kitchen & dining room table",
"Kite",
"Lamp",
"Lantern",
"Laptop",
"Lavender (Plant)",
"Lemon",
"Light bulb",
"Lighthouse",
"Lily",
"Lion",
"Lipstick",
"Lizard",
"Man",
"Maple",
"Microphone",
"Mirror",
"Mixing bowl",
"Mobile phone",
"Monkey",
"Motorcycle",
"Muffin",
"Mug",
"Mule",
"Mushroom",
"Musical keyboard",
"Necklace",
"Nightstand",
"Office building",
"Orange",
"Owl",
"Oyster",
"Paddle",
"Palm tree",
"Parachute",
"Parrot",
"Pen",
"Penguin",
"Personal flotation device",
"Piano",
"Picture frame",
"Pig",
"Pillow",
"Pizza",
"Plate",
"Platter",
"Porch",
"Poster",
"Pumpkin",
"Rabbit",
"Rifle",
"Roller skates",
"Rose",
"Salad",
"Sandal",
"Saucer",
"Saxophone",
"Scarf",
"Sea lion",
"Sea turtle",
"Sheep",
"Shelf",
"Shirt",
"Shorts",
"Shrimp",
"Sink",
"Skateboard",
"Ski",
"Skull",
"Skyscraper",
"Snake",
"Sock",
"Sofa bed",
"Sparrow",
"Spider",
"Spoon",
"Sports uniform",
"Squirrel",
"Stairs",
"Stool",
"Strawberry",
"Street light",
"Studio couch",
"Suit",
"Sun hat",
"Sunglasses",
"Surfboard",
"Sushi",
"Swan",
"Swimming pool",
"Swimwear",
"Tank",
"Tap",
"Taxi",
"Tea",
"Teddy bear",
"Television",
"Tent",
"Tie",
"Tiger",
"Tin can",
"Tire",
"Toilet",
"Tomato",
"Tortoise",
"Tower",
"Traffic light",
"Train",
"Tripod",
"Truck",
"Trumpet",
"Umbrella",
"Van",
"Vase",
"Vehicle registration plate",
"Violin",
"Wall clock",
"Waste container",
"Watch",
"Whale",
"Wheel",
"Wheelchair",
"Whiteboard",
"Window",
"Wine",
"Wine glass",
"Woman",
"Zebra",
"Zucchini",
]
def openimages():
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
if not ann_file.is_file():
fetch_openimages(ann_file)
return ann_file
# this slows down the conversion a lot!
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
def extract_dims(path): return Image.open(path).size[::-1]
def extract_dims(path):
return Image.open(path).size[::-1]
def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
def export_to_coco(
class_map,
annotations,
image_list,
dataset_path,
output_path,
classes=MLPERF_CLASSES,
):
output_path.parent.mkdir(parents=True, exist_ok=True)
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
categories_map = pd.DataFrame(
[(i, c) for i, c in enumerate(classes)],
columns=["category_id", "category_name"],
)
class_map = class_map.merge(
categories_map, left_on="DisplayName", right_on="category_name", how="inner"
)
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
annotations = annotations.merge(class_map, on="LabelName", how="inner")
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")
annotations[["height", "width"]] = annotations.apply(
lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"),
axis=1,
result_type="expand",
)
# Images
imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows())
imgs = [
{
"id": int(id + 1),
"file_name": f"{image_id}.jpg",
"height": row["height"],
"width": row["width"],
"license": None,
"coco_url": None,
}
for (id, image_id), row in (
annotations.groupby(["image_id", "ImageID"]).first().iterrows()
)
]
# Annotations
annots = []
for i, row in annotations.iterrows():
xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
xmin, ymin, xmax, ymax, img_w, img_h = [
row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]
]
x, y, w, h = (
xmin * img_w,
ymin * img_h,
(xmax - xmin) * img_w,
(ymax - ymin) * img_h,
)
coco_annot = {
"id": int(i) + 1,
"image_id": int(row["image_id"] + 1),
"category_id": int(row["category_id"]),
"bbox": [x, y, w, h],
"area": w * h,
}
coco_annot.update(
{
k: row[k]
for k in [
"IsOccluded",
"IsInside",
"IsDepiction",
"IsTruncated",
"IsGroupOf",
]
}
)
coco_annot["iscrowd"] = int(row["IsGroupOf"])
annots.append(coco_annot)
info = {"dataset": "openimages_mlperf", "version": "v6"}
coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
coco_annotations = {
"info": info,
"licenses": [],
"categories": cats,
"images": imgs,
"annotations": annots,
}
with open(output_path, "w") as fp:
json.dump(coco_annotations, fp)
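The box conversion above maps Open Images' normalized corners (XMin/YMin/XMax/YMax in [0, 1]) to COCO's absolute [x, y, width, height]. A one-line check of that arithmetic with made-up numbers:

img_w, img_h = 1024, 768
xmin, ymin, xmax, ymax = 0.25, 0.50, 0.75, 1.00
x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
print([x, y, w, h], "area:", w * h)   # [256.0, 384.0, 512.0, 384.0] area: 196608.0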
def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
image_ids = annotations[np.isin(annotations["LabelName"], labels)][
"ImageID"
].unique()
return image_ids
def download_image(bucket, image_id, data_dir):
try:
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
except botocore.exceptions.ClientError as exception:
sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")
sys.exit(
f"ERROR when downloading image `validation/{image_id}`: {str(exception)}"
)
def fetch_openimages(output_fn):
bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)
bucket = boto3.resource(
"s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)
).Bucket(BUCKET_NAME)
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
annotations_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split("/")[-1]
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
annotations = pd.read_csv(annotations_fn)
classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
classmap_fn = annotations_dir / MAP_CLASSES_URL.split("/")[-1]
fetch(MAP_CLASSES_URL, classmap_fn)
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
image_list = get_image_list(class_map, annotations)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
futures = [
executor.submit(download_image, bucket, image_id, data_dir)
for image_id in image_list
]
for future in (
t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))
):
t.set_description(f"Downloading images")
future.result()
print("Converting annotations to COCO format...")
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
def image_load(fn):
img_folder = BASEDIR / "validation/data"
img = Image.open(img_folder / fn).convert('RGB')
img = Image.open(img_folder / fn).convert("RGB")
import torchvision.transforms.functional as F
ret = F.resize(img, size=(800, 800))
ret = np.array(ret)
return ret, img.size[::-1]
def prepare_target(annotations, img_id, img_size):
boxes = [annot["bbox"] for annot in annotations]
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
@ -150,7 +458,13 @@ def prepare_target(annotations, img_id, img_size):
classes = [annot["category_id"] for annot in annotations]
classes = np.array(classes, dtype=np.int64)
classes = classes[keep]
return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}
return {
"boxes": boxes,
"labels": classes,
"image_id": img_id,
"image_size": img_size,
}
def iterate(coco, bs=8):
image_ids = sorted(coco.imgs.keys())


@ -13,10 +13,15 @@ if __name__ == "__main__":
assert x.shape[0] == y.shape[0]
bs = x.shape[0]
if X is None:
X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
X = Tensor.empty(
sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8
)
Y = Tensor.empty(
sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64
)
print(X.shape, Y.shape)
X[idx : idx + bs].assign(x)
Y[idx : idx + bs].assign(y)
idx += bs
if idx >= sz: break
if idx >= sz:
break


@ -6,9 +6,14 @@ import numpy as np
from tinygrad.helpers import fetch
BASEDIR = Path(__file__).parent / "squad"
def init_dataset():
os.makedirs(BASEDIR, exist_ok=True)
fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
fetch(
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
BASEDIR / "dev-v1.1.json",
)
with open(BASEDIR / "dev-v1.1.json") as f:
data = json.load(f)["data"]
@ -32,14 +37,17 @@ def init_dataset():
qa_id = qa["id"]
q_text = qa["question"]
examples.append({
examples.append(
{
"id": qa_id,
"question": q_text,
"context": doc_tokens,
"answers": list(map(lambda x: x["text"], qa["answers"]))
})
"answers": list(map(lambda x: x["text"], qa["answers"])),
}
)
return examples
def _check_is_max_context(doc_spans, cur_span_index, position):
best_score, best_span_index = None, None
for di, (doc_start, doc_length) in enumerate(doc_spans):
@ -56,6 +64,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
best_span_index = di
return cur_span_index == best_span_index
def convert_example_to_features(example, tokenizer):
query_tokens = tokenizer.tokenize(example["question"])
@ -101,7 +110,9 @@ def convert_example_to_features(example, tokenizer):
for i in range(doc_length):
split_token_index = doc_start + i
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
token_is_max_context[len(tokens)] = _check_is_max_context(
doc_spans, di, split_token_index
)
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(1)
tokens.append("[SEP]")
@ -119,17 +130,24 @@ def convert_example_to_features(example, tokenizer):
assert len(input_mask) == 384
assert len(segment_ids) == 384
outputs.append({
outputs.append(
{
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(
np.float32
),
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(
np.float32
),
"token_to_orig_map": token_to_orig_map,
"token_is_max_context": token_is_max_context,
"tokens": tokens,
})
}
)
return outputs
def iterate(tokenizer, start=0):
examples = init_dataset()
print(f"there are {len(examples)} pairs in the dataset")
@ -140,8 +158,11 @@ def iterate(tokenizer, start=0):
# we need to yield all features here as the f1 score is the maximum over all features
yield features, example
if __name__ == "__main__":
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))
tokenizer = BertTokenizer(
str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt")
)
X, Y = next(iterate(tokenizer))
print(" ".join(X[0]["tokens"]))


@ -5,11 +5,13 @@ from tinygrad.helpers import DEBUG, getenv
import multiprocessing as mp
import os
# this needs to be called before everything else if you are using distributed
def preinit():
os.environ["DELAYED_RUNTIME_INIT"] = "1"
mp.set_start_method("spawn")
# out-of-band communication/synchronization
class _OOB:
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
@ -22,14 +24,18 @@ class _OOB:
# receive some data from a target rank, blocks until data is received
def recv(self, target_rank: int) -> Any:
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
OOB: Optional[_OOB] = None
def init_oob(world_size: int):
os.environ["WORLD_SIZE"] = str(world_size)
global OOB
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
# this runs in the spawned process so we can do all the delayed runtime initialization
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
# setup the rank
@ -41,19 +47,27 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
# do specific runtime initialization for distributed
from tinygrad import Device
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(
device.split(":")[-1]
)
if "GPU" in device:
from tinygrad.runtime.ops_gpu import CL
CL.post_init(device_num)
elif "HIP" in device:
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(
device_num
)
if DEBUG >= 1:
print(f"distributed process {rank} initialized runtime for device {device}")
# convert device to be process specific
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
fn(*args)
# wrapper around mp.Process that initializes the runtime
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()


@ -3,6 +3,7 @@ from tinygrad.helpers import getenv
from extra.dist import world
def allreduce(t: Tensor) -> Tensor:
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
@ -11,7 +12,9 @@ def allreduce(t:Tensor) -> Tensor:
# pad to evenly divide
if flattened.shape[0] % WORLD_SIZE != 0:
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
flattened = Tensor.cat(
flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))
)
# chunk
chunks = flattened.chunk(WORLD_SIZE, dim=0)
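This hunk only shows the setup: the flattened tensor is padded so it splits evenly into WORLD_SIZE chunks, one per rank; the exchange of those chunks is outside the hunk. The semantics the collective has to reproduce are simply that every rank ends up with the element-wise sum, e.g. as a numpy reference:

import numpy as np

def allreduce_reference(tensors):                 # one entry per rank
    world = len(tensors)
    pad = (-tensors[0].size) % world              # pad to evenly divide
    chunks = [np.concatenate([t.ravel(), np.zeros(pad)]).reshape(world, -1) for t in tensors]
    for c in range(world):                        # reduce chunk c across ranks...
        total = sum(chunks[r][c] for r in range(world))
        for r in range(world):                    # ...then hand the result back to every rank
            chunks[r][c] = total
    return [c.ravel()[:tensors[0].size].reshape(tensors[0].shape) for c in chunks]

print(allreduce_reference([np.full(3, float(r)) for r in range(4)])[0])   # [6. 6. 6.]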

extra/dist/world.py

@ -4,15 +4,18 @@ from multiprocessing import shared_memory
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
try:
import gpuctypes.hip as hip
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
except: RawHIPBuffer = None
except:
RawHIPBuffer = None
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.jit import CacheCollector
from tinygrad.tensor import Tensor, Function
import numpy as np
# match the function signature of JITRunner so we can put it in the cache
def __send_rb(args, variables=None, wait=False, jit=False):
x, target_rank, y = args[:3]
@ -20,19 +23,31 @@ def __send_rb(args, variables=None, wait=False, jit=False):
check(hip.hipSetDevice(x._device))
check(hip.hipDeviceSynchronize())
else:
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else: y.fromCPU(x.toCPU())
if isinstance(x, RawBufferCopyInOut):
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else:
y.fromCPU(x.toCPU())
dist.OOB.send(None, target_rank)
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
if DEBUG >= 2:
print(
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}"
)
def __recv_rb(args, variables=None, wait=False, jit=False):
x, target_rank, y = args[:3]
dist.OOB.recv(target_rank)
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
x._transfer(y)
elif isinstance(x, RawBuffer): x._copyin(y.toCPU())
else: x.fromCPU(y.toCPU())
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")
elif isinstance(x, RawBuffer):
x._copyin(y.toCPU())
else:
x.fromCPU(y.toCPU())
if DEBUG >= 2:
print(
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}"
)
# send a rawbuffer from out rank to the target rank
def _send_rb(x: RawBuffer, target_rank: int):
@ -40,7 +55,11 @@ def _send_rb(x:RawBuffer, target_rank:int):
# send ipc handle
check(hip.hipSetDevice(x._device))
check(hip.hipDeviceSynchronize())
check(hip.hipIpcGetMemHandle(ctypes.byref(handle := hip.hipIpcMemHandle_t()), x._buf))
check(
hip.hipIpcGetMemHandle(
ctypes.byref(handle := hip.hipIpcMemHandle_t()), x._buf
)
)
dist.OOB.send((handle, x._device), target_rank)
# jit support
@ -48,20 +67,26 @@ def _send_rb(x:RawBuffer, target_rank:int):
CacheCollector.add(__send_rb, [x, target_rank, None], {})
else:
# create shared memory
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
shm_name = (
s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)
).name
s.close()
# copy the buffer into shared memory
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
# fast path when we can directly copyout
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else: y.fromCPU(x.toCPU())
if isinstance(x, RawBufferCopyInOut):
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
else:
y.fromCPU(x.toCPU())
dist.OOB.send(shm_name, target_rank)
# jit support
CacheCollector.add(__send_rb, [x, target_rank, y], {})
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
if DEBUG >= 2:
print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
# receive a rawbuffer from the target rank
def _recv_rb(x: RawBuffer, target_rank: int):
@ -69,7 +94,9 @@ def _recv_rb(x:RawBuffer, target_rank:int):
# open ipc handle
handle, y_device = dist.OOB.recv(target_rank)
check(hip.hipSetDevice(y_device))
check(hip.hipIpcOpenMemHandle(ctypes.byref(ptr := ctypes.c_void_p()), handle, 0))
check(
hip.hipIpcOpenMemHandle(ctypes.byref(ptr := ctypes.c_void_p()), handle, 0)
)
# build a new buffer
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
@ -81,34 +108,50 @@ def _recv_rb(x:RawBuffer, target_rank:int):
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
# fast path when we can directly copyin
if isinstance(x, RawBuffer): x._copyin(y.toCPU())
else: x.fromCPU(y.toCPU())
if isinstance(x, RawBuffer):
x._copyin(y.toCPU())
else:
x.fromCPU(y.toCPU())
# jit support
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
if DEBUG >= 2:
print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
# sends a lazybuffer from our rank to the target rank
def _send_lb(x: LazyBuffer, target_rank: int) -> None:
assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
assert (
x.st.contiguous and x.realized
), "sending buffer must be contiguous and realized"
_send_rb(x.realized, target_rank)
# receive a lazybuffer from the target rank
def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer:
assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
assert (
x.st.contiguous and x.realized
), "receiving buffer must be contiguous and realized"
_recv_rb(x.realized, target_rank)
return x
class Send(Function):
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
_send_lb(x, target_rank)
return x
class Recv(Function):
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
self.target_rank = target_rank
return _recv_lb(x, target_rank)
def send(x:Tensor, target_rank:int) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank)
def recv(x:Tensor, target_rank:int) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
def send(x: Tensor, target_rank: int) -> Tensor:
return Send.apply(x.contiguous().realize(), target_rank=target_rank)
def recv(x: Tensor, target_rank: int) -> Tensor:
return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
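send and recv above are the Tensor-level wrappers over the raw-buffer transfer paths (HIP IPC handles when both sides are on HIP, a shared-memory file otherwise). A minimal usage sketch, assuming two processes were launched through extra.dist.spawn so RANK and the OOB pipes exist:

from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from extra.dist.world import send, recv

def step():
    if getenv("RANK") == 0:
        send(Tensor.ones(4, 4), target_rank=1)           # blocks until rank 1 has the data
    else:
        out = recv(Tensor.empty(4, 4), target_rank=0)    # filled in place and returned
        print(out.numpy())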


@ -17,5 +17,10 @@ if __name__ == "__main__":
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
for f in cur3.fetchall():
v = pickle.loads(f[-1])
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
print(
" ",
len(f[0]) if isinstance(f[0], str) else f[0],
f[1:-1],
str(v)[0:50],
)
# print(f"{len(k):10d}, {sk} -> {v}")


@ -7,77 +7,190 @@ import json
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
def compile_net(
run: TinyJit, special_names: Dict[int, str]
) -> Tuple[
Dict[str, str],
List[Tuple[str, List[str], List[int]]],
Dict[str, Tuple[int, DType, int]],
Dict[str, Tensor],
]:
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for ji in run.jit_cache:
fxn = ji.prg
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
functions[
fxn.name
] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
for i, arg in enumerate(ji.rawbufs):
key = id(arg)
if key not in bufs:
if key in special_names:
bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
bufs[key] = (
special_names[key],
arg.size * arg.dtype.itemsize,
arg.dtype,
key,
)
else:
bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
bufs[key] = (
f"buf_{bufnum}",
arg.size * arg.dtype.itemsize,
arg.dtype,
key,
)
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
if i > 0:
bufs_to_save[
bufs[key][0]
] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
return (
functions,
statements,
{name: (size, dtype, key) for (name, size, dtype, key) in bufs.values()},
bufs_to_save,
)
def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]:
assert hasattr(model, "forward") or callable(model), "model needs a forward function"
assert hasattr(model, "forward") or callable(
model
), "model needs a forward function"
@TinyJit
def run(*x):
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
assert (
isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor)
), "model output must be a Tensor, tuple, or a list of Tensors for export"
out = [out] if isinstance(out, Tensor) else out
return [o.realize() for o in out]
# twice to run the JIT
for _ in range(2): the_output = run(*args)
for _ in range(2):
the_output = run(*args)
special_names = {}
# hack to put the inputs back
for (j, i), idx in run.input_replace.items():
realized_input = args[idx].lazydata.realized
run.jit_cache[j].rawbufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx}'
special_names[id(realized_input)] = f"input{idx}"
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
for i, output in enumerate(the_output):
special_names[id(output.lazydata.realized)] = f'output{i}'
special_names[id(output.lazydata.realized)] = f"output{i}"
return run, special_names
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
def export_model_clang(
functions: Dict[str, str],
statements: Dict[str, Tuple[str, int, int]],
bufs: Dict[str, Tuple[str, int, int]],
bufs_to_save: Dict[str, Tensor],
input_names: List[str],
output_names: List[str],
) -> str:
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
cprog = [CLANG_PROGRAM_HEADER]
for name, cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
weight = "".join(["\\x%02X" % x for x in bytes(cl._buf)])
cprog.append(f'unsigned char {name}_data[] = "{weight}";')
inputs = ", ".join([f'float* {input}' for input in input_names])
outputs = ", ".join([f'float* {output}' for output in output_names])
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
inputs = ", ".join([f"float* {input}" for input in input_names])
outputs = ", ".join([f"float* {output}" for output in output_names])
cprog += [
f"float {name}[{len}];"
if name not in bufs_to_save
else f"float *{name} = (float *){name}_data;"
for name, (len, dtype, _key) in bufs.items()
if name not in ["input", "outputs"]
]
cprog += list(functions.values())
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(cprog)
cprog += (
[f"void net({inputs}, {outputs}) {{"]
+ [
f"{name}({', '.join(args)});"
for (name, args, _global_size, _local_size) in statements
]
+ ["}"]
)
return "\n".join(cprog)
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
def export_model_webgpu(
functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names
) -> Tuple[str, int, int]:
kernel_code = "\n\n".join(
[
f"const {key} = `{code.replace(key, 'main')}`;"
for key, code in functions.items()
]
)
kernel_names = ", ".join(
[name for (name, _args, _global_size, _local_size) in statements]
)
kernel_calls = "\n ".join(
[
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
for i, (_name, args, global_size, _local_size) in enumerate(statements)
]
)
_bufs = "\n ".join(
[
f"const {name} = "
+ (
f"createEmptyBuf(device, {size});"
if _key not in weight_names
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))"
)
+ ";"
for name, (size, dtype, _key) in bufs.items()
]
)
gpu_write_bufs = "\n ".join(
[
f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
for i, input_name in enumerate(input_names)
]
)
input_writers = "\n ".join(
[
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
+ f"_{inp_name});"
+ f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);"
for i, inp_name in enumerate(input_names)
]
)
gpu_read_bufs = "\n ".join(
[
f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});"
for i, output_name in enumerate(output_names)
]
)
outbuf_copies = "\n ".join(
[
f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);"
for i, output_name in enumerate(output_names)
]
)
output_readers = "\n ".join(
[
f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();"
for i in range(len(output_names))
]
)
output_return = "[{}]".format(
",".join([f"resultBuffer{i}" for i in range(len(output_names))])
)
return (
f"""
const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
@ -134,10 +247,15 @@ const setupNet = async (device, safetensor) => {{
return {output_return};
}}
}}
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
"""
+ f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
)
def export_model(model, target: str, *inputs):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
assert (
Device.DEFAULT in EXPORT_SUPPORTED_DEVICE
), "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
run, special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)
@ -146,34 +264,56 @@ def export_model(model, target:str, *inputs):
output_names = [name for _, name in special_names.items() if "output" in name]
prg = ""
if target == "clang":
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
prg = export_model_clang(
functions, statements, bufs, bufs_to_save, input_names, output_names
)
elif target == "webgpu":
prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
prg = export_model_webgpu(
functions,
statements,
bufs,
bufs_to_save,
weight_names,
input_names,
output_names,
)
else:
prg = json.dumps({
prg = json.dumps(
{
"backend": Device.DEFAULT,
"inputs": [{
"size": bufs[name][0],
"dtype": bufs[name][1].name
} for name in input_names],
"outputs": [{
"size": bufs[name][0],
"dtype": bufs[name][1].name
} for name in output_names],
"inputs": [
{"size": bufs[name][0], "dtype": bufs[name][1].name}
for name in input_names
],
"outputs": [
{"size": bufs[name][0], "dtype": bufs[name][1].name}
for name in output_names
],
"functions": functions,
"statements": [{
"statements": [
{
"kernel": kernel,
"args": args,
"global_size": global_size,
"local_size": local_size
} for (kernel, args, global_size, local_size) in statements],
"local_size": local_size,
}
for (kernel, args, global_size, local_size) in statements
],
"buffers": {
name: {
"size": size,
"dtype": dtype.name,
"id": weight_names[_key] if _key in weight_names else ""
} for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
"id": weight_names[_key] if _key in weight_names else "",
}
})
for name, (size, dtype, _key) in bufs.items()
if name not in ["input", "outputs"]
},
}
)
return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
return (
prg,
{input: bufs[input][0] for input in input_names},
{output: bufs[output][0] for output in output_names},
state,
)
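export_model ties the pieces together: jit_model traces the forward pass twice to fill the JIT cache, compile_net flattens that cache into functions, statements and buffers, and the target-specific exporter renders them. A usage sketch with a placeholder model (the module path is an assumption):

from tinygrad import Device
from tinygrad.tensor import Tensor
from extra.export_model import export_model   # assumed location of this file

Device.DEFAULT = "CLANG"                      # export target must match the active backend

class TinyModel:
    def forward(self, x):
        return x.relu()

prg, inp_sizes, out_sizes, state = export_model(TinyModel(), "clang", Tensor.randn(1, 4))
print(inp_sizes, out_sizes)                   # e.g. {'input0': 16} {'output0': 16}, sizes in bytes
with open("net.c", "w") as f:                 # the "clang" target returns one C source string
    f.write(prg)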


@ -2,6 +2,7 @@
import numpy as np
import time
import sys
np.set_printoptions(linewidth=160)
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
@ -11,18 +12,61 @@ from llvmlite import ir # type: ignore
# https://github.com/corsix/amx/blob/main/Instructions.md
# 12 lines for AMX support
from functools import partialmethod
class AMX:
@staticmethod
def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)
def nop_op_imm5(op, imm5, builder):
builder.asm(
ir.FunctionType(ir.VoidType(), []),
f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}",
"",
tuple(),
True,
)
@staticmethod
def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)
def op_gpr(op, builder, gpr):
builder.asm(
ir.FunctionType(ir.VoidType(), [ir.IntType(64)]),
f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0",
"r",
(gpr,),
True,
)
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
ldx, ldy, stx, sty = (
partialmethod(op_gpr, 0),
partialmethod(op_gpr, 1),
partialmethod(op_gpr, 2),
partialmethod(op_gpr, 3),
)
ldz, stz, ldzi, stzi = (
partialmethod(op_gpr, 4),
partialmethod(op_gpr, 5),
partialmethod(op_gpr, 6),
partialmethod(op_gpr, 7),
)
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
fma64, fms64, fma32, fms32 = (
partialmethod(op_gpr, 10),
partialmethod(op_gpr, 11),
partialmethod(op_gpr, 12),
partialmethod(op_gpr, 13),
)
mac16, fma16, fms16 = (
partialmethod(op_gpr, 14),
partialmethod(op_gpr, 15),
partialmethod(op_gpr, 16),
)
vecint, vecfp, matint, matfp, genlut = (
partialmethod(op_gpr, 18),
partialmethod(op_gpr, 19),
partialmethod(op_gpr, 20),
partialmethod(op_gpr, 21),
partialmethod(op_gpr, 22),
)
N = 4096
@ -54,7 +98,11 @@ c = LLVMBuffer.fromCPU(np.zeros(256))
bufs = [c, a, b]
module = ir.Module(name=__file__)
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
func = ir.Function(
module,
ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()] * 3),
name="exec",
)
# load all
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
@ -69,7 +117,19 @@ y.add_incoming(int_const(0), entry._block)
yp = loop_1_exit.add(y, int_const(32 * 2))
y.add_incoming(yp, loop_1_exit._block)
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
prefetch_function = ir.Function(
module,
ir.FunctionType(
ir.VoidType(),
[
ir.PointerType(ir.FloatType()),
ir.IntType(32),
ir.IntType(32),
ir.IntType(32),
],
),
name="llvm.prefetch",
)
xptr = y
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
@ -79,7 +139,12 @@ addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr))
xptr = loop_1_exit.add(xptr, int_const(32))
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
AMX.ldy(
loop_1_exit,
loop_1_exit.add(
int_const(1 << 62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
),
)
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10))
@ -93,7 +158,9 @@ AMX.clr(exit)
entry.branch(loop_1._block)
loop_1.branch(loop_1_exit._block)
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
loop_1_exit.cbranch(
loop_1_exit.icmp_unsigned("==", yp, int_const(N * N)), exit._block, loop_1._block
)
exit.ret(int_const(0))
cfunc = LLVM().exec(module, bufs, N**2)
@ -185,4 +252,3 @@ np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4)
print(cn.astype(np.int64))
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
"""


@ -1,5 +1,6 @@
import os
import numpy as np
os.environ["CUDA"] = "1"
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
@ -21,7 +22,10 @@ c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
FLOPS = N * N * N * 2
BW = N * N * 3 * 4
prog = CUDAProgram("wmma_example", compile_cuda(f"""
prog = CUDAProgram(
"wmma_example",
compile_cuda(
f"""
#include <mma.h>
using namespace nvcuda;
@ -88,10 +92,23 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
}}
}}
}}
"""))
"""
),
)
global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1]
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
tm = min(
[
prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)
for _ in range(20)
]
)
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
)
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
np.testing.assert_allclose(
na.T.astype(np.float32) @ nb.T.astype(np.float32),
c.toCPU().reshape((N, N)).T,
atol=1e-2,
)

View File

@ -15,6 +15,7 @@ from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, d
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
from tinygrad.ops import LoadOps, ReduceOps
def single_kernel():
# single kernel
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
@ -22,11 +23,17 @@ def single_kernel():
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read())
old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)]
old = CLProgram(
"r_32_16_16_64_4_4_4",
open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read(),
)
old_tms = [
old([1, 1, 32], [16, 16, 1], out, x, w, wait=True) * 1e6 for _ in range(5)
]
print(old_tms, 67.107 / min(old_tms) * 1e3)
exit(0)
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
if __name__ == "__main__":
# single_kernel()
@ -43,7 +50,11 @@ if __name__ == "__main__":
out = x.sequential([c1, c2, c3, c4, c5])
schedule = out.lazydata.schedule()
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.get_lazyops()))
schedule, schedule_input = partition(
schedule,
lambda x: x.ast.op not in LoadOps
and any(y.op in ReduceOps for y in x.ast.get_lazyops()),
)
run_schedule(schedule_input)
run_schedule(schedule[: getenv("CONV")])
print("*** init done ***")

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
import os
# os.environ['OMP_NUM_THREADS'] = '1'
import time
import numpy as np

View File

@ -78,5 +78,3 @@ if __name__ == "__main__":
new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True))
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")

View File

@ -30,12 +30,21 @@ a = hipallocator.alloc(N*N*4)
b = hipallocator.alloc(N * N * 2)
c = hipallocator.alloc(N * N * 2)
na = np.empty(N * N, np.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
nb = (
np.random.default_rng()
.standard_normal(size=(N, N), dtype=np.float32)
.astype(np.float16)
)
nc = (
np.random.default_rng()
.standard_normal(size=(N, N), dtype=np.float32)
.astype(np.float16)
)
hipallocator.copyin(b, bytearray(nb))
hipallocator.copyin(c, bytearray(nc))
lib = compile_hip(f"""
lib = compile_hip(
f"""
#define F32
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
@ -92,10 +101,12 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a,
}}
}}
}}
}}""")
}}"""
)
prog = HIPProgram(device, "test", lib)
def timeit(fxn):
st = time.perf_counter()
et = fxn()
@ -103,11 +114,28 @@ def timeit(fxn):
# print(f"{ret*1e6:.2f} us")
return et
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)])
print(
"global/local size",
global_size,
local_size,
f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}",
)
tm = min(
[
timeit(
lambda: prog(
a, b, c, global_size=global_size, local_size=local_size, wait=True
)
)
for _ in range(1000)
]
)
hipallocator.copyout(flat_mv(na.data), a)
na = na.reshape(N, N)
comp = nb.astype(np.float32) @ nc.astype(np.float32)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
)
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)

View File

@ -14,7 +14,12 @@ A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
OPS = DEVICES * BS * N * N * N * 2
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
def matmul(A, B):
return jnp.matmul(A, B, preferred_element_type=jnp.float32)
pmatmul = jax.pmap(matmul)
MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
@ -23,5 +28,6 @@ for i in range(10):
C = pmatmul(A, B).block_until_ready()
et = time.perf_counter() - st
tflops = (OPS * 1e-12) / et
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
print(
f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}"
)

View File

@ -1,4 +1,5 @@
import os
# os.environ["METAL"] = "1"
import numpy as np
@ -18,14 +19,16 @@ nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float
try:
import time, torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
b = torch.from_numpy(nb).to("mps")
c = torch.from_numpy(nc).to("mps")
def torch_prog(b, c):
st = time.perf_counter()
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
torch.mps.synchronize()
return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
except RuntimeError:
@ -34,16 +37,23 @@ except RuntimeError:
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad import Device
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
return b.conv2d(c, padding=PADDING).realize()
def tiny_prog(b, c):
st = time.perf_counter()
a = tiny_jit(b, c)
Device[a.device].synchronize()
return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(5)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")

View File

@ -1,4 +1,5 @@
import os
os.environ["METAL"] = "1"
import time
import numpy as np
@ -10,15 +11,22 @@ LID = 2
a = RawMetalBuffer(N * N, dtypes.float32)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nb = np.random.default_rng().standard_normal(
size=(N, N), dtype=np.float32
) # .astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(
size=(N, N), dtype=np.float32
) # .astype(np.int32).astype(np.float32)
b = RawMetalBuffer.fromCPU(nb)
c = RawMetalBuffer.fromCPU(nc)
FLOPS = N * N * N * 2
BW = N * N * 3 * 4
prog = MetalProgram("test", compile_metal(f"""
prog = MetalProgram(
"test",
compile_metal(
f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
using namespace metal;
@ -80,46 +88,83 @@ kernel void test(device float *a, device const float *data1, device const float
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}"""))
}}"""
),
)
def timeit(fxn):
st = time.perf_counter()
et = fxn()
# NOTE: et doesn't contain the launch overhead
return time.perf_counter() - st
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
tm = min(
[
timeit(
lambda: prog(
a,
b,
c,
global_size=[N // (8 * 4), N // (8 * 4 * LID), 1],
local_size=[32, LID, 1],
wait=True,
)
)
for _ in range(20)
]
)
na = a.toCPU().reshape(N, N)
comp = nb @ nc
if N <= 32:
print(na)
print(comp)
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
)
np.testing.assert_allclose(na, comp, atol=1e-3)
import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
b = torch.from_numpy(nb).to("mps")
c = torch.from_numpy(nc).to("mps")
def torch_prog(b, c):
st = time.perf_counter()
a = b @ c
torch.mps.synchronize()
return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch"
)
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.runtime.ops_metal import METAL
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
return (b @ c).realize()
def tiny_prog(b, c):
st = time.perf_counter()
a = tiny_jit(b, c)
METAL.synchronize()
return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad"
)

View File

@ -1,4 +1,5 @@
import os
# os.environ["METAL"] = "1"
import numpy as np
import time, torch, torch.mps
@ -10,6 +11,7 @@ from tinygrad import Device
from tinygrad.helpers import colored, getenv, CI
import os
os.environ["METAL"] = "1"
import time
import numpy as np
@ -20,27 +22,38 @@ N = 16384
M = 4096
FLOPS = N * M * 2
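# Worked FLOP count (illustrative): each of the M = 4096 matvec outputs needs
# N = 16384 multiplies and N adds, so FLOPS = 16384 * 4096 * 2 = 134,217,728 per call.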
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nb = np.random.default_rng().standard_normal(
size=(N), dtype=np.float32
) # .astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(
size=(N, M), dtype=np.float32
) # .astype(np.int32).astype(np.float32)
import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
b = torch.from_numpy(nb).to("mps")
c = torch.from_numpy(nc).to("mps")
def torch_prog(b, c):
st = time.perf_counter()
a = b @ c
torch.mps.synchronize()
return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
print(
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch"
)
torch_a = (b @ c).cpu()
WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1]
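# Worked launch geometry (illustrative sketch): LOCAL_SIZE is [32, 1, 16] and
# GLOBAL_SIZE is [4096 // (32 * 1 * 4), 1, 1] = [32, 1, 1], so each threadgroup is
# responsible for a 4096 // 32 = 128-wide slice of the output and 32 groups cover
# all M outputs.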
prog = compile_metal(f"""
prog = compile_metal(
f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
@ -86,41 +99,59 @@ kernel void test(device float* data0, const device float* data1, const device fl
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
}}
}}
""")
"""
)
prog = MetalProgram("test", prog)
# print(prog_string)
na = np.zeros(M, dtype=np.float32)
b = RawMetalBuffer.fromCPU(nb)
c = RawMetalBuffer.fromCPU(nc)
def metalrun():
a = RawMetalBuffer.fromCPU(na)
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
return a
def timeit(fxn):
st = time.perf_counter()
et = fxn()
# NOTE: et doesn't contain the launch overhead
return time.perf_counter() - st
tm = min([timeit(metalrun) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
print(
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal"
)
metal_a = metalrun().toCPU().reshape(M)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.runtime.ops_metal import METAL
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
return (b @ c).realize()
def tiny_prog(b, c):
st = time.perf_counter()
a = tiny_jit(b, c)
METAL.synchronize()
return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
print(
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad"
)
tiny_a = tiny_jit(b, c).numpy()
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)

View File

@ -2,14 +2,28 @@ import numpy as np
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
N = getenv("N", 4096)
CNT = getenv("CNT", 10)
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
a, b = (
Tensor.rand(N, N, dtype=dtype_in).realize(),
Tensor.rand(N, N, dtype=dtype_in).realize(),
)
for i in range(CNT):
if i > 0 and getenv("RAND", 0) != 0:
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
a, b = (
Tensor.rand(N, N, dtype=dtype_in).realize(),
Tensor.rand(N, N, dtype=dtype_in).realize(),
)
c = (
(a.reshape(N, 1, N) * b.permute(1, 0).reshape(1, N, N))
.float()
.sum(axis=2)
.realize()
if getenv("ACCUM_FP32")
else (a @ b).realize()
)
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
nc = c.numpy()
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)

View File

@ -1,13 +1,13 @@
import time
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
gpus = tf.config.list_physical_devices("GPU")
if gpus:
try:
# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.list_logical_devices('GPU')
logical_gpus = tf.config.list_logical_devices("GPU")
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
# Memory growth must be set before GPUs have been initialized
@ -26,8 +26,12 @@ for dtype in [tf.float16, tf.float32]:
def tf_prog(b, c):
st = time.perf_counter()
a = tf.matmul(b, c)
tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done.
tf.debugging.check_numerics(
a, "Nan or Inf in result"
) # Ensures that the calculation is done.
return time.perf_counter() - st
tm = min([tf_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
)

View File

@ -13,5 +13,8 @@ for dtype in [torch.float16, torch.float32]:
a = b @ c
torch.cuda.synchronize()
return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
print(
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
)

View File

@ -5,6 +5,7 @@ M, N, K = 1024, 1024, 1024
try:
import tvm
from tvm import te
# print(tvm.target.Target.list_kinds())
# c, opencl
@ -39,9 +40,13 @@ C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
sched = C.lazydata.schedule()
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
lin = Linearizer(
sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False)
)
# lin.hand_coded_optimizations()
lin.linearize()
from tinygrad.runtime.ops_clang import renderer
src = renderer("mmult", lin.uops)
print(src)

View File

@ -1,11 +1,13 @@
import numpy as np
from tinygrad.tensor import Tensor
def mask_like(like, mask_inx, mask_value=1.0):
mask = np.zeros_like(like).reshape(-1)
mask[mask_inx] = mask_value
return mask.reshape(like.shape)
def jacobian(func, input):
output = func(input)
@ -19,13 +21,14 @@ def jacobian(func, input):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar and backpropagate only through it
o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
J[o, i] = grad
return J
def numerical_jacobian(func, input, eps=1e-3):
output = func(input)
@ -36,14 +39,19 @@ def numerical_jacobian(func, input, eps = 1e-3):
for i in range(ji):
eps_perturb = mask_like(input.numpy(), i, mask_value=eps)
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
output_perturb_add = (
func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
)
output_perturb_sub = (
func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
)
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps)
NJ[:, i] = grad_approx
return NJ
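# Note on the central difference used above (illustrative): for a scalar function,
#   f'(x) ~= (f(x + eps) - f(x - eps)) / (2 * eps)
# with O(eps**2) error; e.g. f(x) = x**2 at x = 3, eps = 1e-3 gives
#   (3.001**2 - 2.999**2) / 0.002 = 6.0 exactly, since the quadratic term cancels.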
def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3):
NJ = numerical_jacobian(func, input, eps)
J = jacobian(func, input)

View File

@ -2,6 +2,7 @@ import multiprocessing, subprocess
import cloudpickle
from typing import Any
def _early_exec_process(qin, qout):
while True:
path, inp = qin.get()
@ -10,41 +11,62 @@ def _early_exec_process(qin, qout):
except Exception as e:
qout.put(e)
def enable_early_exec():
qin: multiprocessing.Queue = multiprocessing.Queue()
qout: multiprocessing.Queue = multiprocessing.Queue()
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
p.daemon = True
p.start()
def early_exec(x):
qin.put(x)
ret = qout.get()
if isinstance(ret, Exception): raise ret
else: return ret
if isinstance(ret, Exception):
raise ret
else:
return ret
return early_exec
def proc(itermaker, q) -> None:
try:
for x in itermaker(): q.put(x)
for x in itermaker():
q.put(x)
except Exception as e:
q.put(e)
finally:
q.put(None)
q.close()
class _CloudpickleFunctionWrapper:
def __init__(self, fn): self.fn = fn
def __getstate__(self): return cloudpickle.dumps(self.fn)
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
def __init__(self, fn):
self.fn = fn
def __getstate__(self):
return cloudpickle.dumps(self.fn)
def __setstate__(self, pfn):
self.fn = cloudpickle.loads(pfn)
def __call__(self, *args, **kwargs) -> Any:
return self.fn(*args, **kwargs)
def cross_process(itermaker, maxsize=16):
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
p = multiprocessing.Process(
target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q)
)
p.start()
while True:
ret = q.get()
if isinstance(ret, Exception): raise ret
elif ret is None: break
else: yield ret
if isinstance(ret, Exception):
raise ret
elif ret is None:
break
else:
yield ret
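# Minimal usage sketch for cross_process (assumes a zero-argument callable that
# yields the items; the names below are made up for illustration):
#
#   def make_batches():
#       for i in range(4):
#           yield {"batch": i}
#
#   for item in cross_process(make_batches):
#       print(item)  # produced in the worker process, consumed in the parent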

View File

@ -6,6 +6,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.ops_gpu import CLBuffer
from tinygrad.helpers import GlobalCounters
def print_objects():
# gc.collect()
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
@ -15,8 +16,12 @@ def print_objects():
realized_buffers = [x.realized for x in lazybuffers if x.realized]
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
print(f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB")
print(f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers")
print(
f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB"
)
print(
f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers"
)
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
cnt = 0
@ -33,11 +38,14 @@ def print_objects():
cnt += 1
for x in gpubuffers_orphaned:
if getattr(x, '_buf', None): del x._buf
if getattr(x, '_image', None): del x._image
if getattr(x, "_buf", None):
del x._buf
if getattr(x, "_image", None):
del x._image
return len(gpubuffers_orphaned)
"""
import gc

View File

@ -7,39 +7,44 @@ from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'H\003'
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
_globals['_TRAINERSPEC']._serialized_start=45
_globals['_TRAINERSPEC']._serialized_end=1581
_globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
_globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
_globals['_NORMALIZERSPEC']._serialized_start=1584
_globals['_NORMALIZERSPEC']._serialized_end=1793
_globals['_SELFTESTDATA']._serialized_start=1795
_globals['_SELFTESTDATA']._serialized_end=1916
_globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
_globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
_globals['_MODELPROTO']._serialized_start=1919
_globals['_MODELPROTO']._serialized_end=2429
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
_globals["DESCRIPTOR"]._options = None
_globals["DESCRIPTOR"]._serialized_options = b"H\003"
_globals["_TRAINERSPEC"].fields_by_name["mining_sentence_size"]._options = None
_globals["_TRAINERSPEC"].fields_by_name[
"mining_sentence_size"
]._serialized_options = b"\030\001"
_globals["_TRAINERSPEC"].fields_by_name["training_sentence_size"]._options = None
_globals["_TRAINERSPEC"].fields_by_name[
"training_sentence_size"
]._serialized_options = b"\030\001"
_globals["_TRAINERSPEC"]._serialized_start = 45
_globals["_TRAINERSPEC"]._serialized_end = 1581
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
_globals["_NORMALIZERSPEC"]._serialized_start = 1584
_globals["_NORMALIZERSPEC"]._serialized_end = 1793
_globals["_SELFTESTDATA"]._serialized_start = 1795
_globals["_SELFTESTDATA"]._serialized_end = 1916
_globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
_globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
_globals["_MODELPROTO"]._serialized_start = 1919
_globals["_MODELPROTO"]._serialized_end = 2429
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
# @@protoc_insertion_point(module_scope)

View File

@ -3,17 +3,22 @@ from typing import List
from tinygrad.nn.optim import Optimizer
from tinygrad.tensor import Tensor
class LR_Scheduler:
def __init__(self, optimizer: Optimizer):
self.optimizer = optimizer
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
self.epoch_counter = Tensor(
[0], requires_grad=False, device=self.optimizer.device
)
def get_lr(self): pass
def get_lr(self):
pass
def step(self) -> None:
self.epoch_counter.assign(self.epoch_counter + 1).realize()
self.optimizer.lr.assign(self.get_lr()).realize()
class MultiStepLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
super().__init__(optimizer)
@ -25,18 +30,38 @@ class MultiStepLR(LR_Scheduler):
return self.optimizer.lr
return self.optimizer.lr * self.gamma
class ReduceLROnPlateau(LR_Scheduler):
def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"):
def __init__(
self,
optimizer: Optimizer,
mode="min",
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode="rel",
):
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
super().__init__(optimizer)
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode
self.best = float('inf') if mode == "min" else float('-inf')
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = (
mode,
factor,
patience,
threshold,
threshold_mode,
)
self.best = float("inf") if mode == "min" else float("-inf")
self.bad_epoch = 0
if mode == "min": self.threshold *= -1
if mode == "min":
self.threshold *= -1
def is_better(self, current: float) -> bool:
dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold
dynamic_threshold = (
self.best * (1 + self.threshold)
if self.threshold_mode == "rel"
else self.best + self.threshold
)
if self.mode == "min":
return current < dynamic_threshold
return current > dynamic_threshold
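# Worked example (illustrative): with mode="min", threshold=1e-4 and
# threshold_mode="rel", __init__ flips threshold to -1e-4, so for best = 1.0
#   dynamic_threshold = 1.0 * (1 + (-1e-4)) = 0.9999
# and a new metric only counts as an improvement if it falls below 0.9999.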
@ -53,6 +78,7 @@ class ReduceLROnPlateau(LR_Scheduler):
self.optimizer.lr *= self.factor
self.bad_epoch = 0
class CosineAnnealingLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
super().__init__(optimizer)
@ -61,26 +87,54 @@ class CosineAnnealingLR(LR_Scheduler):
self.eta_max = optimizer.lr.numpy()[0]
def get_lr(self) -> Tensor:
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
return Tensor(
[
self.eta_min
+ 0.5
* (self.eta_max - self.eta_min)
* (1 + math.cos((self.epoch_counter.numpy()[0] / self.T_max) * math.pi))
],
device=self.optimizer.device,
)
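# The return above is the standard cosine annealing curve (illustrative):
#   lr(t) = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T_max))
# so lr(0) = eta_max, lr(T_max) = eta_min, and with eta_max = 1e-3, eta_min = 0
# the midpoint t = T_max / 2 gives lr = 5e-4.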
class OneCycleLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
anneal_strategy: str = 'linear', cycle_momentum: bool = False):
def __init__(
self,
optimizer: Optimizer,
max_lr: float,
div_factor: float,
final_div_factor: float,
total_steps: int,
pct_start: float,
anneal_strategy: str = "linear",
cycle_momentum: bool = False,
):
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
self.max_lr = Tensor([max_lr]).contiguous()
self.min_lr = self.initial_lr / final_div_factor
super().__init__(optimizer)
self.total_steps = total_steps
self.pct_start = pct_start
assert anneal_strategy == 'linear', 'only linear annealing supported'
assert not cycle_momentum, 'cycle momentum not supported'
assert anneal_strategy == "linear", "only linear annealing supported"
assert not cycle_momentum, "cycle momentum not supported"
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
@staticmethod
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: return ((end - start) * pct + start)
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor:
return (end - start) * pct + start
def get_lr(self) -> Tensor:
return (self.epoch_counter < self.total_steps * self.pct_start).where(
self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
self._annealing_linear(
self.initial_lr,
self.max_lr,
self.epoch_counter / (self.total_steps * self.pct_start),
),
self._annealing_linear(
self.max_lr,
self.min_lr,
(self.epoch_counter - (self.total_steps * self.pct_start))
/ (self.total_steps * (1 - self.pct_start)),
),
)
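# Shape of the schedule above (illustrative numbers): with total_steps=100,
# pct_start=0.3, max_lr=1e-3, div_factor=10, final_div_factor=100 the lr ramps
# linearly from 1e-4 to 1e-3 over the first 30 steps, then decays linearly to
# 1e-6 over the remaining 70.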

View File

@ -5,8 +5,29 @@ from pathlib import Path
class BertForQuestionAnswering:
def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
def __init__(
self,
hidden_size=1024,
intermediate_size=4096,
max_position_embeddings=512,
num_attention_heads=16,
num_hidden_layers=24,
type_vocab_size=2,
vocab_size=30522,
attention_probs_dropout_prob=0.1,
hidden_dropout_prob=0.1,
):
self.bert = Bert(
hidden_size,
intermediate_size,
max_position_embeddings,
num_attention_heads,
num_hidden_layers,
type_vocab_size,
vocab_size,
attention_probs_dropout_prob,
hidden_dropout_prob,
)
self.qa_outputs = Linear(hidden_size, 2)
def load_from_pretrained(self):
@ -16,15 +37,20 @@ class BertForQuestionAnswering:
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
import torch
with open(fn, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
for k, v in state_dict.items():
if "dropout" in k: continue # skip dropout
if "pooler" in k: continue # skip pooler
if "dropout" in k:
continue # skip dropout
if "pooler" in k:
continue # skip pooler
get_child(self, k).assign(v.numpy()).realize()
def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
def __call__(
self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor
):
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.chunk(2, dim=-1)
@ -33,10 +59,35 @@ class BertForQuestionAnswering:
return Tensor.stack([start_logits, end_logits])
class Bert:
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)
def __init__(
self,
hidden_size,
intermediate_size,
max_position_embeddings,
num_attention_heads,
num_hidden_layers,
type_vocab_size,
vocab_size,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
self.embeddings = BertEmbeddings(
hidden_size,
max_position_embeddings,
type_vocab_size,
vocab_size,
hidden_dropout_prob,
)
self.encoder = BertEncoder(
hidden_size,
intermediate_size,
num_attention_heads,
num_hidden_layers,
attention_probs_dropout_prob,
hidden_dropout_prob,
)
def __call__(self, input_ids, attention_mask, token_type_ids):
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
@ -47,8 +98,16 @@ class Bert:
return encoder_outputs
class BertEmbeddings:
def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
def __init__(
self,
hidden_size,
max_position_embeddings,
type_vocab_size,
vocab_size,
hidden_dropout_prob,
):
self.word_embeddings = Embedding(vocab_size, hidden_size)
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
@ -59,7 +118,11 @@ class BertEmbeddings:
input_shape = input_ids.shape
seq_length = input_shape[1]
position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
position_ids = (
Tensor.arange(seq_length, requires_grad=False)
.unsqueeze(0)
.expand(*input_shape)
)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
@ -69,18 +132,49 @@ class BertEmbeddings:
embeddings = embeddings.dropout(self.dropout)
return embeddings
class BertEncoder:
def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]
def __init__(
self,
hidden_size,
intermediate_size,
num_attention_heads,
num_hidden_layers,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
self.layer = [
BertLayer(
hidden_size,
intermediate_size,
num_attention_heads,
attention_probs_dropout_prob,
hidden_dropout_prob,
)
for _ in range(num_hidden_layers)
]
def __call__(self, hidden_states, attention_mask):
for layer in self.layer:
hidden_states = layer(hidden_states, attention_mask)
return hidden_states
class BertLayer:
def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
def __init__(
self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
self.attention = BertAttention(
hidden_size,
num_attention_heads,
attention_probs_dropout_prob,
hidden_dropout_prob,
)
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
@ -90,6 +184,7 @@ class BertLayer:
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertOutput:
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
self.dense = Linear(intermediate_size, hidden_size)
@ -102,10 +197,21 @@ class BertOutput:
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# approximation of the error function
def erf(x):
t = (1 + 0.3275911 * x.abs()).reciprocal()
return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
return x.sign() * (
1
- (
(((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736)
* t
+ 0.254829592
)
* t
* (-(x.square())).exp()
)
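# The polynomial above is the Abramowitz & Stegun 7.1.26 rational approximation of
# erf (illustrative check): its absolute error stays below ~1.5e-7, e.g. erf(1)
# comes out as ~0.8427007 versus the exact 0.84270079..., comfortably enough for
# the BERT gelu below.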
class BertIntermediate:
def __init__(self, hidden_size, intermediate_size):
@ -116,9 +222,18 @@ class BertIntermediate:
# tinygrad gelu is openai gelu but we need the original bert gelu
return x * 0.5 * (1.0 + erf(x / 1.41421))
class BertAttention:
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
def __init__(
self,
hidden_size,
num_attention_heads,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
self.self = BertSelfAttention(
hidden_size, num_attention_heads, attention_probs_dropout_prob
)
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
def __call__(self, hidden_states, attention_mask):
@ -126,6 +241,7 @@ class BertAttention:
attention_output = self.output(self_output, hidden_states)
return attention_output
class BertSelfAttention:
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
self.num_attention_heads = num_attention_heads
@ -147,17 +263,24 @@ class BertSelfAttention:
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)
context_layer = Tensor.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask, self.dropout
)
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
context_layer = context_layer.reshape(
context_layer.shape[0], context_layer.shape[1], self.all_head_size
)
return context_layer
def transpose_for_scores(self, x):
x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
x = x.reshape(
x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size
)
return x.transpose(1, 2)
class BertSelfOutput:
def __init__(self, hidden_size, hidden_dropout_prob):
self.dense = Linear(hidden_size, hidden_size)

View File

@ -2,6 +2,7 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
from tinygrad.helpers import fetch, get_child
class Block:
def __init__(self, dim):
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
@ -11,18 +12,43 @@ class Block:
self.gamma = Tensor.ones(dim)
def __call__(self, x: Tensor):
return x + x.sequential([
self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
])
return x + x.sequential(
[
self.dwconv,
lambda x: x.permute(0, 2, 3, 1),
self.norm,
self.pwconv1,
Tensor.gelu,
self.pwconv2,
lambda x: (self.gamma * x).permute(0, 3, 1, 2),
]
)
class ConvNeXt:
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
def __init__(
self,
in_chans=3,
num_classes=1000,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
):
self.downsample_layers = [
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
[
Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm2d(dims[0], eps=1e-6),
],
*[
[
LayerNorm2d(dims[i], eps=1e-6),
Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
]
for i in range(len(dims) - 1)
],
]
self.stages = [
[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))
]
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
self.norm = LayerNorm(dims[-1])
self.head = Linear(dims[-1], num_classes)
@ -31,6 +57,7 @@ class ConvNeXt:
x = x.sequential(downsample).sequential(stage)
return x.mean([-2, -1]).sequential([self.norm, self.head])
# *** model definition is done ***
versions = {
@ -38,24 +65,32 @@ versions = {
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]},
}
def get_model(version, load_weights=False):
model = ConvNeXt(**versions[version])
if load_weights:
from tinygrad.nn.state import torch_load
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
weights = torch_load(
fetch(
f"https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth"
)
)["model"]
for k, v in weights.items():
mv = get_child(model, k)
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
return model
if __name__ == "__main__":
model = get_model("tiny", True)
# load image
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
img = Tensor(preprocess(chicken_img))
Tensor.training = False

View File

@ -4,8 +4,19 @@ from tinygrad.nn import BatchNorm2d
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
def __init__(
self,
kernel_size,
strides,
expand_ratio,
input_filters,
output_filters,
se_ratio,
has_se,
track_running_stats=True,
):
oup = expand_ratio * input_filters
if expand_ratio != 1:
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
@ -37,12 +48,19 @@ class MBConvBlock:
x = inputs
if self._expand_conv:
x = self._bn0(x.conv2d(self._expand_conv)).swish()
x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
x = x.conv2d(
self._depthwise_conv,
padding=self.pad,
stride=self.strides,
groups=self._depthwise_conv.shape[0],
)
x = self._bn1(x).swish()
if self.has_se:
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
x_squeezed = x_squeezed.conv2d(
self._se_reduce, self._se_reduce_bias
).swish()
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
x = x.mul(x_squeezed.sigmoid())
@ -51,8 +69,17 @@ class MBConvBlock:
x = x.add(inputs)
return x
class EfficientNet:
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
def __init__(
self,
number=0,
classes=1000,
has_se=True,
track_running_stats=True,
input_channels=3,
has_fc_output=True,
):
self.number = number
global_params = [
# width, depth
@ -106,10 +133,31 @@ class EfficientNet:
]
self._blocks = []
for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
for (
num_repeats,
kernel_size,
strides,
expand_ratio,
input_filters,
output_filters,
se_ratio,
) in blocks_args:
input_filters, output_filters = round_filters(input_filters), round_filters(
output_filters
)
for n in range(round_repeats(num_repeats)):
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
self._blocks.append(
MBConvBlock(
kernel_size,
strides,
expand_ratio,
input_filters,
output_filters,
se_ratio,
has_se=has_se,
track_running_stats=track_running_stats,
)
)
input_filters = output_filters
strides = (1, 1)
@ -140,25 +188,34 @@ class EfficientNet:
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
}
b0 = torch_load(fetch(model_urls[self.number]))
for k, v in b0.items():
if k.endswith("num_batches_tracked"): continue
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
if k.endswith("num_batches_tracked"):
continue
for cat in [
"_conv_head",
"_conv_stem",
"_depthwise_conv",
"_expand_conv",
"_fc",
"_project_conv",
"_se_reduce",
"_se_expand",
]:
if cat in k:
k = k.replace('.bias', '_bias')
k = k.replace('.weight', '')
k = k.replace(".bias", "_bias")
k = k.replace(".weight", "")
# print(k, v.shape)
mv = get_child(self, k)
vnp = v # .astype(np.float32)
vnp = vnp if k != '_fc' else vnp.cpu().T
vnp = vnp if k != "_fc" else vnp.cpu().T
# vnp = vnp if vnp.shape != () else np.array([vnp])
if mv.shape == vnp.shape:
mv.assign(vnp.to(mv.device))
else:
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))

View File

@ -2,11 +2,15 @@ from typing import Tuple, Union, Optional, Dict
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
from tinygrad.helpers import getenv
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(
1, end, 1, dim // 2, 2
)
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
@ -15,20 +19,33 @@ def complex_mult(A, c, d):
co = a * d + b * c
return ro.cat(co, dim=-1)
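# Sanity check of the complex_mult identity above (illustrative, plain Python):
#   a, b, c, d = 1, 2, 0.5, -0.25
#   real: a*c - b*d = 0.5 + 0.5 = 1.0
#   imag: a*d + b*c = -0.25 + 1.0 = 0.75
# which matches complex(1, 2) * complex(0.5, -0.25) == (1+0.75j).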
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
assert (
freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1]
), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
c, d = (
freqs_cis[:, : xq.shape[1], :, :, 0:1],
freqs_cis[:, : xq.shape[1], :, :, 1:2],
)
xq_out = complex_mult(xq, c, d)
xk_out = complex_mult(xk, c, d)
return xq_out.flatten(3), xk_out.flatten(3)
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
bs, seqlen, n_kv_heads, head_dim = x.shape
if n_rep == 1: return x
return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
if n_rep == 1:
return x
return (
x.reshape(bs, seqlen, n_kv_heads, 1, head_dim)
.expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
)
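# Shape walk-through for repeat_kv (illustrative): with bs=1, seqlen=8, n_kv_heads=2,
# head_dim=4 and n_rep=4 the tensor goes
#   (1, 8, 2, 4) -> (1, 8, 2, 1, 4) -> (1, 8, 2, 4, 4) -> (1, 8, 8, 4)
# so 2 kv heads are shared by 8 query heads (grouped/multi-query attention).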
class RMSNorm:
def __init__(self, dim, eps=1e-6):
@ -39,10 +56,13 @@ class RMSNorm:
# TODO: convert to float?
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
class Attention:
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.n_kv_heads = (
n_kv_heads if n_kv_heads is not None else n_heads
) # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.max_context = max_context
@ -52,7 +72,13 @@ class Attention:
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
def __call__(
self,
x: Tensor,
start_pos: Union[Variable, int],
freqs_cis: Tensor,
mask: Optional[Tensor],
) -> Tensor:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
@ -62,21 +88,40 @@ class Attention:
# create kv cache
if not hasattr(self, "cache_k"):
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
self.cache_k, self.cache_v = Tensor.zeros(
bsz, self.max_context, self.n_kv_heads, self.head_dim
), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
# update the cache
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
self.cache_k.assign(
keys.pad(
(None, (0, self.max_context - start_pos - seqlen), None, None)
).contiguous()
).realize()
self.cache_v.assign(
values.pad(
(None, (0, self.max_context - start_pos - seqlen), None, None)
).contiguous()
).realize()
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
xq, keys, values = (
xq.transpose(1, 2),
keys.transpose(1, 2),
values.transpose(1, 2),
)
attn = (
xq.scaled_dot_product_attention(keys, values, mask)
.transpose(1, 2)
.reshape(bsz, seqlen, -1)
)
return self.wo(attn)
class FeedForward:
def __init__(self, dim, hidden_dim, linear=nn.Linear):
self.w1 = linear(dim, hidden_dim, bias=False)
@ -84,36 +129,88 @@ class FeedForward:
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
def __call__(self, x: Tensor) -> Tensor:
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
return self.w2(
self.w1(x).silu() * self.w3(x)
) # SwiGLU [arxiv/2002.05202, eq (5)]
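# A numpy sketch (illustrative, random weights) of the SwiGLU gating above:
# out = W2 @ (silu(W1 @ x) * (W3 @ x)), with silu(z) = z * sigmoid(z).
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

dim, hidden = 4, 8
w1, w3 = np.random.randn(hidden, dim), np.random.randn(hidden, dim)
w2 = np.random.randn(dim, hidden)
x = np.random.randn(dim)
out = w2 @ (silu(w1 @ x) * (w3 @ x))
assert out.shape == (dim,)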
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
def __init__(
self,
dim: int,
hidden_dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
max_context: int,
linear=nn.Linear,
):
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
self.feed_forward = FeedForward(dim, hidden_dim, linear)
self.attention_norm = RMSNorm(dim, norm_eps)
self.ffn_norm = RMSNorm(dim, norm_eps)
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
def __call__(
self,
x: Tensor,
start_pos: Union[Variable, int],
freqs_cis: Tensor,
mask: Optional[Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
return (h + self.feed_forward(self.ffn_norm(h))).realize()
class Transformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
def __init__(
self,
dim: int,
hidden_dim: int,
n_heads: int,
n_layers: int,
norm_eps: float,
vocab_size,
linear=nn.Linear,
n_kv_heads=None,
rope_theta=10000,
max_context=1024,
jit=True,
):
self.layers = [
TransformerBlock(
dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear
)
for _ in range(n_layers)
]
self.norm = RMSNorm(dim, norm_eps)
self.tok_embeddings = nn.Embedding(vocab_size, dim)
self.output = linear(dim, vocab_size, bias=False)
self.max_context = max_context
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
self.freqs_cis = precompute_freqs_cis(
dim // n_heads, self.max_context * 2, rope_theta
)
self.forward_jit = TinyJit(self.forward) if jit else None
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
def forward(
self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float = 0.0
):
_bsz, seqlen = tokens.shape
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
freqs_cis = self.freqs_cis.shrink(
(None, (start_pos, start_pos + seqlen), None, None, None)
)
mask = (
Tensor.full(
(1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32
)
.triu(start_pos + 1)
.realize()
if seqlen > 1
else None
)
h = self.tok_embeddings(tokens)
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
logits = self.output(self.norm(h))
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize()
@ -121,27 +218,54 @@ class Transformer:
# TODO: better way to handle the first call v.s. the rest?
if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1):
assert start_pos > 0
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
return self.forward_jit(
tokens,
Variable("start_pos", 1, self.max_context).bind(start_pos),
temperature,
)
return self.forward(tokens, start_pos, temperature)
# *** helpers ***
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
def convert_from_huggingface(
weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int
):
def permute(v: Tensor, n_heads: int):
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
return (
v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1])
.transpose(1, 2)
.reshape(*v.shape[:2])
)
keymap = {
"model.embed_tokens.weight": "tok_embeddings.weight",
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
**{
f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
for l in range(len(model.layers))
},
**{
f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
for x in ["q", "k", "v", "o"]
for l in range(len(model.layers))
},
**{
f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
for l in range(len(model.layers))
},
**{
f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
for l in range(len(model.layers))
},
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
sd = {}
for k, v in weights.items():
if ".rotary_emb." in k: continue
if ".rotary_emb." in k:
continue
v = v.to(Device.DEFAULT)
if "model.layers" in k:
if "q_proj" in k:


@ -10,64 +10,95 @@ from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
from extra.models.retinanet import nms as _box_nms
USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'
USE_NP_GATHER = os.getenv("FULL_TINYGRAD", "0") == "0"
def rint(tensor):
x = (tensor * 2).cast(dtypes.int32).contiguous().cast(dtypes.float32) / 2
return (x < 0).where(x.floor(), x.ceil())
def nearest_interpolate(tensor, scale_factor):
bs, c, py, px = tensor.shape
return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
return (
tensor.reshape(bs, c, py, 1, px, 1)
.expand(bs, c, py, scale_factor, px, scale_factor)
.reshape(bs, c, py * scale_factor, px * scale_factor)
)
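# A numpy sketch (illustrative) of the reshape/expand trick above: every pixel is
# duplicated scale_factor times along both spatial axes (nearest-neighbour upsampling).
import numpy as np

bs, c, py, px, s = 1, 1, 2, 2, 3
t = np.arange(bs * c * py * px).reshape(bs, c, py, px)
up = np.broadcast_to(
    t.reshape(bs, c, py, 1, px, 1), (bs, c, py, s, px, s)
).reshape(bs, c, py * s, px * s)
assert up.shape == (bs, c, py * s, px * s) and up[0, 0, 0, 0] == up[0, 0, 2, 2]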
def meshgrid(x, y):
grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
grid_x = Tensor.cat(
*[x[idx : idx + 1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]
)
grid_y = Tensor.cat(*[y.unsqueeze(0)] * x.shape[0])
return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
def topk(input_, k, dim=-1, largest=True, sorted=False):
k = min(k, input_.shape[dim] - 1)
input_ = input_.numpy()
if largest: input_ *= -1
if largest:
input_ *= -1
ind = np.argpartition(input_, k, axis=dim)
if largest: input_ *= -1
if largest:
input_ *= -1
ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
if not sorted: return Tensor(input_), ind
if largest: input_ *= -1
if not sorted:
return Tensor(input_), ind
if largest:
input_ *= -1
ind_part = np.argsort(input_, axis=dim)
ind = np.take_along_axis(ind, ind_part, axis=dim)
if largest: input_ *= -1
if largest:
input_ *= -1
val = np.take_along_axis(input_, ind_part, axis=dim)
return Tensor(val), ind
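# A numpy-only sketch (illustrative values) of the argpartition trick topk uses: negate so
# the k largest originals land in the first k slots, then sort only those k.
import numpy as np

scores = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
k = 2
idx = np.argpartition(-scores, k)[:k]  # indices of the 2 largest, in no particular order
idx = idx[np.argsort(-scores[idx])]  # order them descending -> [4, 2]
assert scores[idx].tolist() == [5.0, 4.0]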
# This is very slow for large arrays, or indices
def _gather(array, indices):
indices = indices.float().to(array.device)
reshape_arg = [1] * array.ndim + [array.shape[-1]]
return Tensor.where(
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
array, 0,
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1])
== Tensor.arange(array.shape[-1])
.reshape(*reshape_arg)
.expand(*indices.shape, array.shape[-1]),
array,
0,
).sum(indices.ndim)
# TODO: replace npgather with a faster gather using tinygrad only
# NOTE: this blocks the gradient
def npgather(array, indices):
if isinstance(array, Tensor): array = array.numpy()
if isinstance(indices, Tensor): indices = indices.numpy()
if isinstance(indices, list): indices = np.asarray(indices)
if isinstance(array, Tensor):
array = array.numpy()
if isinstance(indices, Tensor):
indices = indices.numpy()
if isinstance(indices, list):
indices = np.asarray(indices)
return Tensor(array[indices.astype(int)])
def get_strides(shape):
prod = [1]
for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
for idx in range(len(shape) - 1, -1, -1):
prod.append(prod[-1] * shape[idx])
# something about ints is broken with gpu, cuda
return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu()
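# A worked numpy example (illustrative) of the row-major strides computed above: for shape
# (2, 3, 4) the strides are (12, 4, 1), so flat_index = 12*i + 4*j + 1*k.
import numpy as np

shape = (2, 3, 4)
strides = np.cumprod((1,) + shape[::-1])[::-1][1:]  # -> array([12, 4, 1])
i, j, k = 1, 2, 3
assert (strides * (i, j, k)).sum() == np.ravel_multi_index((i, j, k), shape)  # 23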
# with keys as integer array for all axes
def tensor_getitem(tensor, *keys):
# something about ints is broken with gpu, cuda
flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32)
flat_keys = (
Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1)
.cpu()
.cast(dtypes.int32)
)
strides = get_strides(tensor.shape)
idxs = (flat_keys * strides).sum(1)
gatherer = npgather if USE_NP_GATHER else _gather
@ -97,7 +128,8 @@ def tensor_gather(tensor, indices):
class LastLevelMaxPool:
def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]
def __call__(self, x):
return [Tensor.max_pool2d(x, 1, 2)]
# transpose
@ -117,9 +149,7 @@ class BoxList:
if not isinstance(bbox, Tensor):
bbox = Tensor(bbox)
if bbox.ndim != 2:
raise ValueError(
"bbox should have 2 dimensions, got {}".format(bbox.ndim)
)
raise ValueError("bbox should have 2 dimensions, got {}".format(bbox.ndim))
if bbox.shape[-1] != 4:
raise ValueError(
"last dimenion of bbox should have a "
@ -145,7 +175,9 @@ class BoxList:
box = self.bbox
if self.mode == "xyxy":
TO_REMOVE = 1
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (
box[:, 3] - box[:, 1] + TO_REMOVE
)
elif self.mode == "xywh":
area = box[:, 2] * box[:, 3]
return area
@ -241,7 +273,8 @@ class BoxList:
transposed_ymax = image_height - ymin
transposed_boxes = Tensor.cat(
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax),
dim=-1,
)
bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
for k, v in self.extra_fields.items():
@ -289,7 +322,11 @@ def cat_boxlist(bboxes):
else:
cat_boxes = BoxList(bboxes[0].bbox, size, mode)
for field in fields:
cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
cat_field_list = [
bbox.get_field(field)
for bbox in bboxes
if bbox.get_field(field).shape[0] > 0
]
if len(cat_field_list) > 0:
data = Tensor.cat(*cat_field_list, dim=0)
@ -305,8 +342,12 @@ class FPN:
def __init__(self, in_channels_list, out_channels):
self.inner_blocks, self.layer_blocks = [], []
for in_channels in in_channels_list:
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
self.inner_blocks.append(
nn.Conv2d(in_channels, out_channels, kernel_size=1)
)
self.layer_blocks.append(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
)
self.top_block = LastLevelMaxPool()
def __call__(self, x: Tensor):
@ -357,9 +398,7 @@ class AnchorGenerator:
):
if len(anchor_strides) == 1:
anchor_stride = anchor_strides[0]
cell_anchors = [
generate_anchors(anchor_stride, sizes, aspect_ratios)
]
cell_anchors = [generate_anchors(anchor_stride, sizes, aspect_ratios)]
else:
if len(anchor_strides) != len(sizes):
raise RuntimeError("FPN should have #anchor_strides == #sizes")
@ -368,7 +407,7 @@ class AnchorGenerator:
generate_anchors(
anchor_stride,
size if isinstance(size, (tuple, list)) else (size,),
aspect_ratios
aspect_ratios,
)
for anchor_stride, size in zip(anchor_strides, sizes)
]
@ -387,10 +426,18 @@ class AnchorGenerator:
grid_height, grid_width = size
device = base_anchors.device
shifts_x = Tensor.arange(
start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
start=0,
stop=grid_width * stride,
step=stride,
dtype=dtypes.float32,
device=device,
)
shifts_y = Tensor.arange(
start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
start=0,
stop=grid_height * stride,
step=stride,
dtype=dtypes.float32,
device=device,
)
shift_y, shift_x = meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
@ -398,7 +445,9 @@ class AnchorGenerator:
shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
anchors.append(
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(
-1, 4
)
)
return anchors
@ -415,14 +464,16 @@ class AnchorGenerator:
)
else:
device = anchors.device
inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
inds_inside = Tensor.ones(
anchors.shape[0], dtype=dtypes.uint8, device=device
)
boxlist.add_field("visibility", inds_inside)
def __call__(self, image_list, feature_maps):
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
anchors = []
for (image_height, image_width) in image_list.image_sizes:
for image_height, image_width in image_list.image_sizes:
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
boxlist = BoxList(
@ -437,14 +488,19 @@ class AnchorGenerator:
def generate_anchors(
stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
):
return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
return _generate_anchors(
stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))
)
def _generate_anchors(base_size, scales, aspect_ratios):
anchor = Tensor([1, 1, base_size, base_size]) - 1
anchors = _ratio_enum(anchor, aspect_ratios)
anchors = Tensor.cat(
*[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
*[
_scale_enum(anchors[i, :], scales).reshape(-1, 4)
for i in range(anchors.shape[0])
]
)
return anchors
@ -460,12 +516,15 @@ def _whctrs(anchor):
def _mkanchors(ws, hs, x_ctr, y_ctr):
ws = ws[:, None]
hs = hs[:, None]
anchors = Tensor.cat(*(
anchors = Tensor.cat(
*(
x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1),
), dim=1)
),
dim=1,
)
return anchors
@ -504,7 +563,7 @@ class RPNHead:
class BoxCoder(object):
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
self.weights = weights
self.bbox_xform_clip = bbox_xform_clip
@ -557,7 +616,11 @@ class BoxCoder(object):
y = pred_ctr_y - 0.5 * pred_h
w = pred_ctr_x + 0.5 * pred_w - 1
h = pred_ctr_y + 0.5 * pred_h - 1
pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
pred_boxes = (
Tensor.stack([x, y, w, h])
.permute(1, 2, 0)
.reshape(rel_codes.shape[0], rel_codes.shape[1])
)
return pred_boxes
@ -578,9 +641,7 @@ def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
def remove_small_boxes(boxlist, min_size):
xywh_boxes = boxlist.convert("xywh").bbox
_, _, ws, hs = xywh_boxes.chunk(4, dim=1)
keep = ((
(ws >= min_size) * (hs >= min_size)
) > 0).reshape(-1)
keep = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1)
if keep.sum().numpy() == len(boxlist):
return boxlist
else:
@ -630,8 +691,12 @@ class RPNPostProcessor:
box_regression_list = []
concat_anchors_list = []
for batch_idx in range(N):
box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
box_regression_list.append(
tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])
)
concat_anchors_list.append(
tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])
)
box_regression = Tensor.stack(box_regression_list)
concat_anchors = Tensor.stack(concat_anchors_list)
@ -677,9 +742,7 @@ class RPNPostProcessor:
for i in range(num_images):
objectness = boxlists[i].get_field("objectness")
post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
_, inds_sorted = topk(objectness,
post_nms_top_n, dim=0, sorted=False
)
_, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
boxlists[i] = boxlists[i][inds_sorted]
return boxlists
@ -689,9 +752,7 @@ class RPN:
self.anchor_generator = AnchorGenerator()
in_channels = 256
head = RPNHead(
in_channels, self.anchor_generator.num_anchors_per_location()[0]
)
head = RPNHead(in_channels, self.anchor_generator.num_anchors_per_location()[0])
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
box_selector_test = RPNPostProcessor(
pre_nms_top_n=1000,
@ -699,7 +760,7 @@ class RPN:
nms_thresh=0.7,
min_size=0,
box_coder=rpn_box_coder,
fpn_post_nms_top_n=1000
fpn_post_nms_top_n=1000,
)
self.head = head
self.box_selector_test = box_selector_test
@ -725,7 +786,7 @@ def make_conv3x3(
stride=stride,
padding=dilation,
dilation=dilation,
bias=False if use_gn else True
bias=False if use_gn else True,
)
return conv
@ -746,10 +807,18 @@ class MaskRCNNFPNFeatureExtractor:
use_gn = False
layers = (256, 256, 256, 256)
dilation = 1
self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
self.mask_fcn1 = make_conv3x3(
input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn
)
self.mask_fcn2 = make_conv3x3(
layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn
)
self.mask_fcn3 = make_conv3x3(
layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn
)
self.mask_fcn4 = make_conv3x3(
layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn
)
self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
def __call__(self, x, proposals):
@ -833,7 +902,9 @@ def _bilinear_interpolate(
y = Tensor.where(ymask[:, None, :], y, 0)
x = Tensor.where(xmask[:, None, :], x, 0)
key1 = roi_batch_ind[:, None, None, None, None, None]
key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
key2 = Tensor.arange(channels, device=input.device)[
None, :, None, None, None, None
]
key3 = y[:, None, :, None, :, None]
key4 = x[:, None, None, :, None, :]
return tensor_getitem(input, key1, key2, key3, key4) # [K, C, PH, PW, IY, IX]
@ -855,8 +926,11 @@ def _bilinear_interpolate(
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
def _roi_align(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned
):
orig_dtype = input.dtype
_, _, height, width = input.shape
ph = Tensor.arange(pooled_height, device=input.device)
@ -879,8 +953,12 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
bin_size_w = roi_width / pooled_width
exact_sampling = sampling_ratio > 0
roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
roi_bin_grid_h = (
sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
)
roi_bin_grid_w = (
sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
)
if exact_sampling:
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
@ -923,6 +1001,7 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
output = output.cast(orig_dtype)
return output
class ROIAlign:
def __init__(self, output_size, spatial_scale, sampling_ratio):
self.output_size = output_size
@ -931,7 +1010,13 @@ class ROIAlign:
def __call__(self, input, rois):
output = _roi_align(
input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
input,
rois,
self.spatial_scale,
self.output_size[0],
self.output_size[1],
self.sampling_ratio,
aligned=False,
)
return output
@ -1002,7 +1087,16 @@ class Pooler:
all_idxs.extend(idx_in_level)
results.append(pooler_output)
return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
return tensor_gather(
Tensor.cat(*results),
[
x[0]
for x in sorted(
{i: idx for i, idx in enumerate(all_idxs)}.items(),
key=lambda x: x[1],
)
],
)
class FPNPredictor:
@ -1027,13 +1121,13 @@ class PostProcessor:
nms=0.5,
detections_per_img=100,
box_coder=None,
cls_agnostic_bbox_reg=False
cls_agnostic_bbox_reg=False,
):
self.score_thresh = score_thresh
self.nms = nms
self.detections_per_img = detections_per_img
if box_coder is None:
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
self.box_coder = box_coder
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
@ -1090,9 +1184,7 @@ class PostProcessor:
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
boxlist_for_class.add_field("scores", scores_j)
if len(boxlist_for_class):
boxlist_for_class = boxlist_nms(
boxlist_for_class, self.nms
)
boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
num_labels = len(boxlist_for_class)
boxlist_for_class.add_field(
"labels", Tensor.full((num_labels,), j, device=device)
@ -1119,8 +1211,8 @@ class RoIBoxHead:
score_thresh=0.05,
nms=0.5,
detections_per_img=100,
box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
cls_agnostic_bbox_reg=False
box_coder=BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)),
cls_agnostic_bbox_reg=False,
)
def __call__(self, features, proposals, targets=None):
@ -1210,7 +1302,6 @@ def to_image_list(tensors, size_divisible=32):
elif isinstance(tensors, (tuple, list)):
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
if size_divisible > 0:
stride = size_divisible
max_size = list(max_size)
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
@ -1237,10 +1328,13 @@ class MaskRCNN:
self.roi_heads = RoIHeads(self.backbone.out_channels)
def load_from_pretrained(self):
fn = Path('./') / "weights/maskrcnn.pt"
fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
fn = Path("./") / "weights/maskrcnn.pt"
fetch(
"https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth",
fn,
)
state_dict = torch_load(fn)['model']
state_dict = torch_load(fn)["model"]
loaded_keys = []
for k, v in state_dict.items():
if "module." in k:
@ -1265,7 +1359,7 @@ class MaskRCNN:
return result
if __name__ == '__main__':
if __name__ == "__main__":
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
model = MaskRCNN(backbone=resnet)
model.load_from_pretrained()


@ -3,20 +3,33 @@ from tinygrad.tensor import Tensor
from tinygrad.nn.state import torch_load
from tinygrad.helpers import fetch, get_child
class BasicBlock:
expansion = 1
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
assert (
groups == 1 and base_width == 64
), "BasicBlock only supports groups=1 and base_width=64"
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, stride=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = []
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = [
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(self.expansion * planes),
]
def __call__(self, x):
@ -31,20 +44,44 @@ class Bottleneck:
# NOTE: stride_in_1x1=False, this is the v1.5 variant
expansion = 4
def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
def __init__(
self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64
):
width = int(planes * (base_width / 64.0)) * groups
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
self.conv1 = nn.Conv2d(
in_planes,
width,
kernel_size=1,
stride=stride if stride_in_1x1 else 1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
self.conv2 = nn.Conv2d(
width,
width,
kernel_size=3,
padding=1,
stride=1 if stride_in_1x1 else stride,
groups=groups,
bias=False,
)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(
width, self.expansion * planes, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
self.downsample = []
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = [
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(self.expansion * planes),
]
def __call__(self, x):
@ -55,15 +92,18 @@ class Bottleneck:
out = out.relu()
return out
class ResNet:
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
def __init__(
self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False
):
self.num = num
self.block = {
18: BasicBlock,
34: BasicBlock,
50: Bottleneck,
101: Bottleneck,
152: Bottleneck
152: Bottleneck,
}[num]
self.num_blocks = {
@ -71,7 +111,7 @@ class ResNet:
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3,8,36,3]
152: [3, 8, 36, 3],
}[num]
self.in_planes = 64
@ -80,36 +120,64 @@ class ResNet:
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
self.layer1 = self._make_layer(
self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1
)
self.layer2 = self._make_layer(
self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1
)
self.layer3 = self._make_layer(
self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1
)
self.layer4 = self._make_layer(
self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1
)
self.fc = (
nn.Linear(512 * self.block.expansion, num_classes)
if num_classes is not None
else None
)
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
if block == Bottleneck:
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
layers.append(
block(
self.in_planes,
planes,
stride,
stride_in_1x1,
self.groups,
self.base_width,
)
)
else:
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
layers.append(
block(self.in_planes, planes, stride, self.groups, self.base_width)
)
self.in_planes = planes * block.expansion
return layers
def forward(self, x):
is_feature_only = self.fc is None
if is_feature_only: features = []
if is_feature_only:
features = []
out = self.bn1(self.conv1(x)).relu()
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
out = out.sequential(self.layer1)
if is_feature_only: features.append(out)
if is_feature_only:
features.append(out)
out = out.sequential(self.layer2)
if is_feature_only: features.append(out)
if is_feature_only:
features.append(out)
out = out.sequential(self.layer3)
if is_feature_only: features.append(out)
if is_feature_only:
features.append(out)
out = out.sequential(self.layer4)
if is_feature_only: features.append(out)
if is_feature_only:
features.append(out)
if not is_feature_only:
out = out.mean([2, 3])
out = self.fc(out).log_softmax()
@ -123,12 +191,16 @@ class ResNet:
# TODO replace with fake torch load
model_urls = {
(18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
(34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
(50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
(101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
(152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
(18, 1, 64): "https://download.pytorch.org/models/resnet18-5c106cde.pth",
(34, 1, 64): "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
(50, 1, 64): "https://download.pytorch.org/models/resnet50-19c8e357.pth",
(
50,
32,
4,
): "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
(101, 1, 64): "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
(152, 1, 64): "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
self.url = model_urls[(self.num, self.groups, self.base_width)]
@ -136,17 +208,24 @@ class ResNet:
obj: Tensor = get_child(self, k)
dat = v.detach().numpy()
if 'fc.' in k and obj.shape != dat.shape:
if "fc." in k and obj.shape != dat.shape:
print("skipping fully connected layer")
continue # Skip FC if transfer learning
# TODO: remove or when #777 is merged
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (
k,
obj.shape,
dat.shape,
)
obj.assign(dat)
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(
50, num_classes=num_classes, groups=32, width_per_group=4
)


@ -4,6 +4,7 @@ import tinygrad.nn as nn
from extra.models.resnet import ResNet
import numpy as np
def nms(boxes, scores, thresh=0.5):
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
@ -15,11 +16,14 @@ def nms(boxes, scores, thresh=0.5):
inter_y1 = np.maximum(y1[cur], y1[to_process])
inter_x2 = np.minimum(x2[cur], x2[to_process])
inter_y2 = np.minimum(y2[cur], y2[to_process])
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(
0, inter_y2 - inter_y1 + 1
)
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
to_process = to_process[np.where(iou <= thresh)[0]]
return keep
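# A small numpy sketch (illustrative boxes) of the IoU computed inside nms, using the same
# +1 pixel convention, for two (x1, y1, x2, y2) boxes.
import numpy as np

box_a = np.array([0.0, 0.0, 9.0, 9.0])  # 10x10 box -> area 100
box_b = np.array([5.0, 5.0, 14.0, 14.0])  # 10x10 box shifted by 5 -> 5x5 overlap
inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]) + 1)
inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1)
iou = inter_w * inter_h / (100 + 100 - inter_w * inter_h)
assert abs(iou - 25 / 175) < 1e-6  # ~0.14, below the default thresh=0.5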
def decode_bbox(offsets, anchors):
dx, dy, dw, dh = np.rollaxis(offsets, 1)
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
@ -30,6 +34,7 @@ def decode_bbox(offsets, anchors):
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
anchors = []
@ -41,39 +46,87 @@ def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
shifts_x, shifts_y = np.meshgrid(
np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h
)
shifts_x = shifts_x.reshape(-1)
shifts_y = shifts_y.reshape(-1)
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
shifts = np.stack(
[shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32
)
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
return anchors
class RetinaNet:
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
def __init__(
self,
backbone: ResNet,
num_classes=264,
num_anchors=9,
scales=None,
aspect_ratios=None,
):
assert isinstance(backbone, ResNet)
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
scales = (
tuple(
(i, int(i * 2 ** (1 / 3)), int(i * 2 ** (2 / 3)))
for i in 2 ** np.arange(5, 10)
)
if scales is None
else scales
)
aspect_ratios = (
((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
)
self.num_anchors, self.num_classes = num_anchors, num_classes
assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
assert len(scales) == len(aspect_ratios) and all(
self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)
)
self.backbone = ResNetFPN(backbone)
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
self.head = RetinaHead(
self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes
)
self.anchor_gen = lambda input_size: generate_anchors(
input_size,
self.backbone.compute_grid_sizes(input_size),
scales,
aspect_ratios,
)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
return self.head(self.backbone(x))
def load_from_pretrained(self):
model_urls = {
(50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
(50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
(
50,
1,
64,
): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
(
50,
32,
4,
): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
}
self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
self.url = model_urls[
(
self.backbone.body.num,
self.backbone.body.groups,
self.backbone.body.base_width,
)
]
from torch.hub import load_state_dict_from_url
state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
state_dict = load_state_dict_from_url(
self.url, progress=True, map_location="cpu"
)
state_dict = state_dict["model"] if "model" in state_dict.keys() else state_dict
for k, v in state_dict.items():
obj = get_child(self, k)
dat = v.detach().numpy()
@ -81,10 +134,21 @@ class RetinaNet:
obj.assign(dat)
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
def postprocess_detections(
self,
predictions,
input_size=(800, 800),
image_sizes=None,
orig_image_sizes=None,
score_thresh=0.05,
topk_candidates=1000,
nms_thresh=0.5,
):
anchors = self.anchor_gen(input_size)
grid_sizes = self.backbone.compute_grid_sizes(input_size)
split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
split_idx = np.cumsum(
[int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]
)
detections = []
for i, predictions_per_image in enumerate(predictions):
h, w = input_size if image_sizes is None else image_sizes[i]
@ -94,7 +158,9 @@ class RetinaNet:
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
image_boxes, image_scores, image_labels = [], [], []
for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
for offsets_per_level, scores_per_level, anchors_per_level in zip(
offsets_per_image, scores_per_image, anchors
):
# remove low scoring boxes
scores_per_level = scores_per_level.flatten()
keep_idxs = scores_per_level > score_thresh
@ -104,16 +170,23 @@ class RetinaNet:
topk_idxs = np.where(keep_idxs)[0]
num_topk = min(len(topk_idxs), topk_candidates)
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]
topk_idxs, scores_per_level = (
topk_idxs[sort_idxs],
scores_per_level[sort_idxs],
)
# bbox coords from offsets
anchor_idxs = topk_idxs // self.num_classes
labels_per_level = topk_idxs % self.num_classes
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
boxes_per_level = decode_bbox(
offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
)
# clip to image size
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(
-1, 4
)
image_boxes.append(boxes_per_level)
image_scores.append(scores_per_level)
@ -127,7 +200,9 @@ class RetinaNet:
keep_mask = np.zeros_like(image_scores, dtype=bool)
for class_id in np.unique(image_labels):
curr_indices = np.where(image_labels == class_id)[0]
curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
curr_keep_indices = nms(
image_boxes[curr_indices], image_scores[curr_indices], nms_thresh
)
keep_mask[curr_indices[curr_keep_indices]] = True
keep = np.where(keep_mask)[0]
keep = keep[image_scores[keep].argsort()[::-1]]
@ -139,42 +214,91 @@ class RetinaNet:
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
# xywh format
image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)
image_boxes = np.concatenate(
[image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1
)
detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]})
detections.append(
{
"boxes": image_boxes,
"scores": image_scores[keep],
"labels": image_labels[keep],
}
)
return detections
class ClassificationHead:
def __init__(self, in_channels, num_anchors, num_classes):
self.num_classes = num_classes
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
self.conv = flatten(
[
(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
lambda x: x.relu(),
)
for _ in range(4)
]
)
self.cls_logits = nn.Conv2d(
in_channels, num_anchors * num_classes, kernel_size=3, padding=1
)
def __call__(self, x):
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
out = [
self.cls_logits(feat.sequential(self.conv))
.permute(0, 2, 3, 1)
.reshape(feat.shape[0], -1, self.num_classes)
for feat in x
]
return out[0].cat(*out[1:], dim=1).sigmoid()
class RegressionHead:
def __init__(self, in_channels, num_anchors):
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
self.conv = flatten(
[
(
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
lambda x: x.relu(),
)
for _ in range(4)
]
)
self.bbox_reg = nn.Conv2d(
in_channels, num_anchors * 4, kernel_size=3, padding=1
)
def __call__(self, x):
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
out = [
self.bbox_reg(feat.sequential(self.conv))
.permute(0, 2, 3, 1)
.reshape(feat.shape[0], -1, 4)
for feat in x
]
return out[0].cat(*out[1:], dim=1)
class RetinaHead:
def __init__(self, in_channels, num_anchors, num_classes):
self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
self.classification_head = ClassificationHead(
in_channels, num_anchors, num_classes
)
self.regression_head = RegressionHead(in_channels, num_anchors)
def __call__(self, x):
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
out = pred_bbox.cat(pred_class, dim=-1)
return out
class ResNetFPN:
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
self.out_channels = out_channels
self.body = resnet
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
in_channels_list = [
(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers
]
self.fpn = FPN(in_channels_list, out_channels)
# this is needed to decouple inference from postprocessing (anchors generation)
@ -190,10 +314,15 @@ class ResNetFPN:
p5 = p4.sequential(self.body.layer4)
return self.fpn([p3, p4, p5])
class ExtraFPNBlock:
def __init__(self, in_channels, out_channels):
self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.p6 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, padding=1
)
self.p7 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
self.use_P5 = in_channels == out_channels
def __call__(self, p, c):
@ -204,13 +333,20 @@ class ExtraFPNBlock:
p.extend([p6, p7])
return p
class FPN:
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
self.inner_blocks, self.layer_blocks = [], []
for in_channels in in_channels_list:
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
self.inner_blocks.append(
nn.Conv2d(in_channels, out_channels, kernel_size=1)
)
self.layer_blocks.append(
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
)
self.extra_blocks = (
ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
)
def __call__(self, x):
last_inner = self.inner_blocks[-1](x[-1])
@ -219,9 +355,17 @@ class FPN:
inner_lateral = self.inner_blocks[idx](x[idx])
# upsample to inner_lateral's shape
(ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
(ih, iw), (oh, ow), prefix = (
last_inner.shape[-2:],
inner_lateral.shape[-2:],
last_inner.shape[:-2],
)
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]
inner_top_down = (
last_inner.reshape(*prefix, ih, 1, iw, 1)
.expand(*prefix, ih, eh, iw, ew)
.reshape(*prefix, ih * eh, iw * ew)[:, :, :oh, :ow]
)
last_inner = inner_lateral + inner_top_down
results.insert(0, self.layer_blocks[idx](last_inner))
@ -229,8 +373,10 @@ class FPN:
results = self.extra_blocks(results, x)
return results
if __name__ == "__main__":
from extra.models.resnet import ResNeXt50_32X4D
backbone = ResNeXt50_32X4D()
retina = RetinaNet(backbone)
retina.load_from_pretrained()


@ -7,10 +7,31 @@ from pathlib import Path
class RNNT:
def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
def __init__(
self,
input_features=240,
vocab_size=29,
enc_hidden_size=1024,
pred_hidden_size=320,
joint_hidden_size=512,
pre_enc_layers=2,
post_enc_layers=3,
pred_layers=2,
stack_time_factor=2,
dropout=0.32,
):
self.encoder = Encoder(
input_features,
enc_hidden_size,
pre_enc_layers,
post_enc_layers,
stack_time_factor,
dropout,
)
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)
self.joint = Joint(
vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
)
@TinyJit
def __call__(self, x, y, hc=None):
@ -30,7 +51,12 @@ class RNNT:
return outputs
def _greedy_decode(self, logits, logit_len):
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
hc = Tensor.zeros(
self.prediction.rnn.layers,
2,
self.prediction.hidden_size,
requires_grad=False,
)
labels = []
label = Tensor.zeros(1, 1, requires_grad=False)
mask = Tensor.zeros(1, requires_grad=False)
@ -41,7 +67,14 @@ class RNNT:
while not_blank and added < 30:
if len(labels) > 0:
mask = (mask + 1).clip(0, 1)
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
label = (
Tensor(
[[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]],
requires_grad=False,
)
+ 1
- 1
)
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
k = jhc[0, 0, :29].argmax(axis=0).numpy()
not_blank = k != 28
@ -61,31 +94,59 @@ class RNNT:
def load_from_pretrained(self):
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)
fetch(
"https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1",
fn,
)
import torch
with open(fn, "rb") as f:
state_dict = torch.load(f, map_location="cpu")["state_dict"]
# encoder
for i in range(2):
self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
self.encoder.pre_rnn.cells[i].weights_ih.assign(
state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy()
)
self.encoder.pre_rnn.cells[i].weights_hh.assign(
state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy()
)
self.encoder.pre_rnn.cells[i].bias_ih.assign(
state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy()
)
self.encoder.pre_rnn.cells[i].bias_hh.assign(
state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy()
)
for i in range(3):
self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())
self.encoder.post_rnn.cells[i].weights_ih.assign(
state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy()
)
self.encoder.post_rnn.cells[i].weights_hh.assign(
state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy()
)
self.encoder.post_rnn.cells[i].bias_ih.assign(
state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy()
)
self.encoder.post_rnn.cells[i].bias_hh.assign(
state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy()
)
# prediction
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
for i in range(2):
self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())
self.prediction.rnn.cells[i].weights_ih.assign(
state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy()
)
self.prediction.rnn.cells[i].weights_hh.assign(
state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy()
)
self.prediction.rnn.cells[i].bias_ih.assign(
state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy()
)
self.prediction.rnn.cells[i].bias_hh.assign(
state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy()
)
# joint
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
@ -104,7 +165,9 @@ class LSTMCell:
self.bias_hh = Tensor.uniform(hidden_size * 4)
def __call__(self, x, hc):
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[: x.shape[0]].linear(
self.weights_hh.T, self.bias_hh
)
i, f, g, o = gates.chunk(4, 1)
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
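# A numpy sketch (illustrative sizes; the update below is the standard LSTM formula, not
# shown in this hunk) of what typically follows these gates: c' = f*c + i*g, h' = o*tanh(c').
import numpy as np

hidden = 4
i, f, g, o = (np.random.rand(1, hidden) for _ in range(4))
c = np.zeros((1, hidden))
c_new = f * c + i * g
h_new = o * np.tanh(c_new)
assert h_new.shape == c_new.shape == (1, hidden)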
@ -121,7 +184,12 @@ class LSTM:
self.hidden_size = hidden_size
self.layers = layers
self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]
self.cells = [
LSTMCell(input_size, hidden_size, dropout)
if i == 0
else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0)
for i in range(layers)
]
def __call__(self, x, hc):
@TinyJit
@ -129,7 +197,9 @@ class LSTM:
return self.do_step(x_, hc_)
if hc is None:
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
hc = Tensor.zeros(
self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False
)
output = None
for t in range(x.shape[0]):
@ -159,10 +229,20 @@ class StackTime:
class Encoder:
def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
def __init__(
self,
input_size,
hidden_size,
pre_layers,
post_layers,
stack_time_factor,
dropout,
):
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
self.stack_time = StackTime(stack_time_factor)
self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)
self.post_rnn = LSTM(
stack_time_factor * hidden_size, hidden_size, post_layers, dropout
)
def __call__(self, x, x_lens):
x, _ = self.pre_rnn(x, None)
@ -185,7 +265,9 @@ class Prediction:
class Joint:
def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
def __init__(
self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
):
self.dropout = dropout
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)


@ -1,8 +1,17 @@
import numpy as np
from tinygrad.tensor import Tensor
class TransformerBlock:
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
def __init__(
self,
embed_dim,
num_heads,
ff_dim,
prenorm=False,
act=lambda x: x.relu(),
dropout=0.1,
):
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.num_heads = num_heads
@ -10,11 +19,23 @@ class TransformerBlock:
self.prenorm, self.act = prenorm, act
self.dropout = dropout
self.query = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.key = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.value = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.query = (
Tensor.scaled_uniform(embed_dim, embed_dim),
Tensor.zeros(embed_dim),
)
self.key = (
Tensor.scaled_uniform(embed_dim, embed_dim),
Tensor.zeros(embed_dim),
)
self.value = (
Tensor.scaled_uniform(embed_dim, embed_dim),
Tensor.zeros(embed_dim),
)
self.out = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
self.out = (
Tensor.scaled_uniform(embed_dim, embed_dim),
Tensor.zeros(embed_dim),
)
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
@ -24,25 +45,41 @@ class TransformerBlock:
def attn(self, x):
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]]
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2)
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
query, key, value = [
x.linear(*y)
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size))
.transpose(1, 2)
for y in [self.query, self.key, self.value]
]
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(
1, 2
)
return attention.reshape(
shape=(x.shape[0], -1, self.num_heads * self.head_size)
).linear(*self.out)
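# A shape-only numpy sketch (illustrative sizes) of the multi-head split used above:
# (bs, time, embed_dim) -> (bs, num_heads, time, head_size) and back.
import numpy as np

bs, time, num_heads, head_size = 2, 5, 4, 8
q = np.zeros((bs, time, num_heads * head_size))
q_heads = q.reshape(bs, time, num_heads, head_size).transpose(0, 2, 1, 3)
assert q_heads.shape == (bs, num_heads, time, head_size)
merged = q_heads.transpose(0, 2, 1, 3).reshape(bs, time, num_heads * head_size)
assert merged.shape == q.shape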
def __call__(self, x):
if self.prenorm:
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(
*self.ff2
).dropout(self.dropout)
else:
x = x + self.attn(x).dropout(self.dropout)
x = x.layernorm().linear(*self.ln1)
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(
self.dropout
)
x = x.layernorm().linear(*self.ln2)
return x
class Transformer:
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
self.maxlen, self.syms = maxlen, syms
self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False)
self.embed = Tensor.scaled_uniform(
maxlen + syms, embed_dim, requires_grad=False
)
self.tbs = []
for i in range(layers):
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
@@ -57,8 +94,11 @@ class Transformer:
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms)
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
x = (
Tensor(onehot, device=x.device)
.dot(self.embed)
.reshape(shape=(bs, x.shape[1], -1))
)
x = x.sequential(self.tbs)
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
return x.reshape(shape=(bs, -1, x.shape[-1]))
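The forward pass above embeds tokens by building a one-hot over maxlen + syms and multiplying by self.embed; since the table has maxlen + syms rows, the first maxlen slots appear to be reserved for positions, so each token row is effectively two-hot (one position slot, one symbol slot). A small sketch of the trick, with made-up sizes:

import numpy as np
from tinygrad.tensor import Tensor

maxlen, syms, embed_dim = 6, 10, 8
embed = Tensor.scaled_uniform(maxlen + syms, embed_dim, requires_grad=False)
onehot = np.zeros((1, maxlen + syms), dtype=np.float32)
onehot[0, 2] = 1           # position slot 2
onehot[0, maxlen + 7] = 1  # symbol slot 7
emb = Tensor(onehot).dot(embed)  # (1, embed_dim): the sum of the two selected rows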

View File

@@ -4,25 +4,63 @@ from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch, get_child
class DownsampleBlock:
def __init__(self, c0, c1, stride=2):
self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
self.conv1 = [
nn.Conv2d(
c0,
c1,
kernel_size=(3, 3, 3),
stride=stride,
padding=(1, 1, 1, 1, 1, 1),
bias=False,
),
nn.InstanceNorm(c1),
Tensor.relu,
]
self.conv2 = [
nn.Conv2d(
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
),
nn.InstanceNorm(c1),
Tensor.relu,
]
def __call__(self, x):
return x.sequential(self.conv1).sequential(self.conv2)
class UpsampleBlock:
def __init__(self, c0, c1):
self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
self.upsample_conv = [
nn.ConvTranspose2d(c0, c1, kernel_size=(2, 2, 2), stride=2)
]
self.conv1 = [
nn.Conv2d(
2 * c1,
c1,
kernel_size=(3, 3, 3),
padding=(1, 1, 1, 1, 1, 1),
bias=False,
),
nn.InstanceNorm(c1),
Tensor.relu,
]
self.conv2 = [
nn.Conv2d(
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
),
nn.InstanceNorm(c1),
Tensor.relu,
]
def __call__(self, x, skip):
x = x.sequential(self.upsample_conv)
x = Tensor.cat(x, skip, dim=1)
return x.sequential(self.conv1).sequential(self.conv2)
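A small shape sketch of the block above (illustrative, with assumed sizes): the transposed conv doubles the spatial resolution, and the channel-wise cat with the skip connection is why conv1 takes 2 * c1 input channels.

from tinygrad.tensor import Tensor

up = UpsampleBlock(64, 32)
x = Tensor.randn(1, 64, 8, 8, 8)        # low-resolution decoder input
skip = Tensor.randn(1, 32, 16, 16, 16)  # matching encoder feature map
y = up(x, skip)                         # -> (1, 32, 16, 16, 16)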
class UNet3D:
def __init__(self, in_channels=1, n_class=3):
filters = [32, 64, 128, 256, 320]
@@ -30,7 +68,9 @@ class UNet3D:
self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [
UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])
]
self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
def __call__(self, x):
@@ -47,13 +87,17 @@ class UNet3D:
def load_from_pretrained(self):
fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
fetch(
"https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1",
fn,
)
state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
for k, v in state_dict.items():
obj = get_child(self, k)
assert obj.shape == v.shape, (k, obj.shape, v.shape)
obj.assign(v.numpy())
if __name__ == "__main__":
mdl = UNet3D()
mdl.load_from_pretrained()
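Continuing the __main__ block above, a hedged inference sketch (not part of the diff): the input layout is assumed to be (batch, channel, depth, height, width), and 64 is chosen only so every stride-2 stage divides evenly.

    x = Tensor.randn(1, 1, 64, 64, 64)
    out = mdl(x)
    print(out.shape)  # expected (1, 3, 64, 64, 64) with the default n_class=3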

Some files were not shown because too many files have changed in this diff