Reformat, uh, everything, with black
parent
01503ca90d
commit
661dcc5ed0
|
@ -4,13 +4,17 @@ import pathlib
|
|||
from hexdump import hexdump
|
||||
|
||||
fxn = None
|
||||
|
||||
|
||||
def disasm(buf):
|
||||
global fxn
|
||||
if fxn is None:
|
||||
shared = pathlib.Path(__file__).parent / "disasm.so"
|
||||
if not shared.is_file():
|
||||
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
|
||||
fxn = ctypes.CDLL(shared.as_posix())['disasm']
|
||||
os.system(
|
||||
f"cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so"
|
||||
)
|
||||
fxn = ctypes.CDLL(shared.as_posix())["disasm"]
|
||||
# hexdump(buf)
|
||||
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
|
||||
buf = buf[0x510:] # this right?
|
||||
|
|
|
@ -23,21 +23,24 @@ from abc import ABC
|
|||
|
||||
# we will be using the clang backend
|
||||
from tinygrad import Device
|
||||
|
||||
Device.DEFAULT = "CLANG"
|
||||
|
||||
# first, 2+3 as a Tensor, the highest level
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
a = Tensor([2])
|
||||
b = Tensor([3])
|
||||
result = a + b
|
||||
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
|
||||
assert result.numpy()[0] == 5.
|
||||
assert result.numpy()[0] == 5.0
|
||||
|
||||
# %%
|
||||
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
|
||||
# it's worth reading tinygrad/tensor.py. it's pretty beautiful
|
||||
import tinygrad.mlops as mlops
|
||||
|
||||
|
||||
# this is the good old familiar Tensor class
|
||||
class Tensor:
|
||||
# these two are pretty straightforward
|
||||
|
@ -51,10 +54,13 @@ class Tensor:
|
|||
lazydata: LazyBuffer
|
||||
|
||||
# high level ops (hlops) are defined on this class. example: relu
|
||||
def relu(self): return self.maximum(0)
|
||||
def relu(self):
|
||||
return self.maximum(0)
|
||||
|
||||
# log is an mlop, this is the wrapper function in Tensor
|
||||
def log(self): return mlops.Log.apply(self)
|
||||
def log(self):
|
||||
return mlops.Log.apply(self)
|
||||
|
||||
|
||||
# all the definitions of the derivatives are subclasses of Function (like mlops.Log)
|
||||
# there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10)
|
||||
|
@ -62,13 +68,18 @@ class Tensor:
|
|||
# you can differentiate the world using the chain rule
|
||||
class Function:
|
||||
# example types of forward and backward
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: pass
|
||||
def backward(self, x:LazyBuffer) -> LazyBuffer: pass
|
||||
def forward(self, x: LazyBuffer) -> LazyBuffer:
|
||||
pass
|
||||
|
||||
def backward(self, x: LazyBuffer) -> LazyBuffer:
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
|
||||
from tinygrad.helpers import DType
|
||||
|
||||
|
||||
# this is where the properties live that you thought were a part of Tensor
|
||||
# LazyBuffer is like a Tensor without derivatives, at the mlop layer
|
||||
class LazyBuffer:
|
||||
|
@ -91,6 +102,7 @@ class LazyBuffer:
|
|||
# this LazyOp describes the computation needed to realize this LazyBuffer
|
||||
op: Optional[LazyOp]
|
||||
|
||||
|
||||
# LazyOp (in tinygrad/ops.py, code 4/10)
|
||||
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
|
||||
class LazyOp:
|
||||
|
@ -98,13 +110,52 @@ class LazyOp:
|
|||
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
|
||||
arg: Optional[Any] = None # and an optional static argument
|
||||
|
||||
|
||||
# there's currently 26 Ops you have to implement for an accelerator.
|
||||
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto()
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
|
||||
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
|
||||
class UnaryOps(Enum):
|
||||
EXP2 = auto()
|
||||
LOG2 = auto()
|
||||
CAST = auto()
|
||||
SIN = auto()
|
||||
SQRT = auto()
|
||||
|
||||
|
||||
class BinaryOps(Enum):
|
||||
ADD = auto()
|
||||
SUB = auto()
|
||||
MUL = auto()
|
||||
DIV = auto()
|
||||
CMPLT = auto()
|
||||
MAX = auto()
|
||||
|
||||
|
||||
class ReduceOps(Enum):
|
||||
SUM = auto()
|
||||
MAX = auto()
|
||||
|
||||
|
||||
class MovementOps(Enum):
|
||||
RESHAPE = auto()
|
||||
PERMUTE = auto()
|
||||
EXPAND = auto()
|
||||
PAD = auto()
|
||||
SHRINK = auto()
|
||||
STRIDE = auto()
|
||||
|
||||
|
||||
class TernaryOps(Enum):
|
||||
MULACC = auto()
|
||||
WHERE = auto()
|
||||
|
||||
|
||||
class LoadOps(Enum):
|
||||
EMPTY = auto()
|
||||
CONST = auto()
|
||||
FROM = auto()
|
||||
CONTIGUOUS = auto()
|
||||
CUSTOM = auto()
|
||||
|
||||
|
||||
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
|
||||
# you do not need to implement the MovementOps
|
||||
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
|
||||
|
@ -135,7 +186,9 @@ assert len(lazyop.src) == 2
|
|||
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
|
||||
assert lazyop.src[0].op.op == LoadOps.FROM
|
||||
assert lazyop.src[0].op.src[0].device == "CPU"
|
||||
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert (
|
||||
lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2
|
||||
), "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
||||
|
||||
# now we realize the LazyBuffer
|
||||
|
@ -151,12 +204,15 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU,
|
|||
|
||||
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
|
||||
|
||||
|
||||
# Interpreted backends are very simple (example: CPU and TORCH)
|
||||
class Interpreted:
|
||||
# and they have a lookup table to functions for the Ops
|
||||
fxn_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.EXP2: lambda x: np.exp2(x),
|
||||
BinaryOps.ADD: lambda x,y: x+y}
|
||||
BinaryOps.ADD: lambda x, y: x + y,
|
||||
}
|
||||
|
||||
|
||||
# Compiled backends take a little more (example: GPU and LLVM)
|
||||
class Compiled:
|
||||
|
@ -166,26 +222,40 @@ class Compiled:
|
|||
# and a runtime, which runs the generated code
|
||||
runtime: Type[Runtime]
|
||||
|
||||
|
||||
# Runtime is what actually runs the kernels for a compiled backend
|
||||
class Runtime(ABC):
|
||||
# `name` is the name of the function, and `prg` is the code
|
||||
# the constructor compiles the code
|
||||
def __init__(self, name:str, prg:str): pass
|
||||
def __init__(self, name: str, prg: str):
|
||||
pass
|
||||
|
||||
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
|
||||
def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
|
||||
def __call__(
|
||||
self,
|
||||
*bufs: List[Buffer],
|
||||
global_size: Optional[List[int]],
|
||||
local_size: Optional[List[int]],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
# == Buffer (in tinygrad/device.py, code 6/10) ==
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Buffer is where the data is actually held. it's pretty close to just memory
|
||||
class Buffer(ABC):
|
||||
# create an empty rawbuffer that holds `size` elements of type `dtype`
|
||||
# `opaque` is an opaque container class
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass
|
||||
def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None):
|
||||
pass
|
||||
|
||||
# toCPU converts the RawBuffer to a numpy array with shape (size,)
|
||||
def toCPU(self) -> np.ndarray: pass
|
||||
def toCPU(self) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
# == Example: 2+3 in raw clang ==
|
||||
|
@ -205,6 +275,7 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
|
|||
# then we copy the numpy in to RawMallocBuffers
|
||||
# last, we create an empty output buffer
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4)
|
||||
output = MallocAllocator.alloc(4)
|
||||
|
||||
|
@ -214,7 +285,9 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B"))
|
|||
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
|
||||
|
||||
# compile the program, run it, and 2+3 does indeed equal 5
|
||||
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
|
||||
program = ClangProgram(
|
||||
"add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
|
||||
)
|
||||
program(output, input_a, input_b)
|
||||
numpy_out = np.empty(1, dtype=np.float32)
|
||||
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
|
||||
|
@ -229,7 +302,16 @@ np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
|
|||
# the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST
|
||||
# for that, we use the Linearizer, which turns an AST into a list of (linear) UOps
|
||||
|
||||
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto();
|
||||
|
||||
class UOps(Enum):
|
||||
LOOP = auto()
|
||||
DEFINE_LOCAL = auto()
|
||||
LOAD = auto()
|
||||
ALU = auto()
|
||||
CONST = auto()
|
||||
ENDLOOP = auto()
|
||||
STORE = auto()
|
||||
|
||||
|
||||
class UOp:
|
||||
uop: UOps
|
||||
|
@ -238,26 +320,34 @@ class UOp:
|
|||
arg: Any
|
||||
num: int # UOps are unique
|
||||
|
||||
|
||||
class Linearizer:
|
||||
# create the kernel with the AST
|
||||
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
|
||||
def __init__(self, ast:LazyOp): pass
|
||||
def linearize(self): pass
|
||||
def __init__(self, ast: LazyOp):
|
||||
pass
|
||||
|
||||
def linearize(self):
|
||||
pass
|
||||
|
||||
# when linearize is run, it fills in this list
|
||||
uops: List[UOp]
|
||||
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
result = Tensor(2).realize() + Tensor(3).realize()
|
||||
|
||||
# use the real Linearizer to linearize 2+3
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
||||
sched = result.lazydata.schedule()
|
||||
linearizer = Linearizer(sched[-1].ast)
|
||||
linearizer.linearize()
|
||||
|
||||
# print the uops
|
||||
for uop in linearizer.uops: print(uop)
|
||||
for uop in linearizer.uops:
|
||||
print(uop)
|
||||
|
||||
# output:
|
||||
"""
|
||||
|
@ -275,11 +365,13 @@ for uop in linearizer.uops: print(uop)
|
|||
# here, we have an example where we fetch the generated code from the JIT
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
result = Tensor(2) + Tensor(3)
|
||||
|
||||
# we have a global cache used by the JIT
|
||||
# from there, we can see the generated clang code
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
CacheCollector.start() # enables the cache
|
||||
result.realize() # create the program and runs it
|
||||
cache_saved = CacheCollector.finish() # disable the cache
|
||||
|
@ -319,7 +411,9 @@ print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
|||
# we can then reshape it, and the strides change again
|
||||
# note how the permute stays applied
|
||||
a = a.reshape((5, 2, 5, 2))
|
||||
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||
print(
|
||||
a
|
||||
) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||
|
||||
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
|
||||
a = a.reshape((100,))
|
||||
|
|
|
@ -49,14 +49,21 @@ b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I",
|
|||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||
|
||||
# describe the computation
|
||||
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
ld_1 = LazyOp(
|
||||
BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
ld_2 = LazyOp(
|
||||
BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
||||
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
st_0 = LazyOp(
|
||||
BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
|
||||
# convert the computation to a "linearized" format (print the format)
|
||||
lin = Device[DEVICE].get_linearizer(st_0).linearize()
|
||||
for u in lin.uops: print(u)
|
||||
for u in lin.uops:
|
||||
print(u)
|
||||
|
||||
# compile a program (and print the source)
|
||||
fxn = Device[DEVICE].to_program(lin)
|
||||
|
@ -79,6 +86,7 @@ from tinygrad.realize import run_schedule
|
|||
# allocate some values + load in values
|
||||
# TODO: remove numpy here
|
||||
import numpy as np
|
||||
|
||||
a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE)
|
||||
b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE)
|
||||
|
||||
|
@ -87,10 +95,12 @@ out = a.e(BinaryOps.ADD, b)
|
|||
|
||||
# schedule the computation as a list of kernels
|
||||
sched = out.schedule()
|
||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
for si in sched:
|
||||
print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
|
||||
# DEBUGGING: print the compute ast as a tree
|
||||
from tinygrad.graph import print_tree
|
||||
|
||||
print_tree(sched[-1].ast)
|
||||
# NOTE: sched[-1].ast is the same as st_0 above
|
||||
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
from typing import Tuple
|
||||
import time
|
||||
from tinygrad import Tensor, TinyJit, nn, Variable
|
||||
from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype
|
||||
from tinygrad.helpers import (
|
||||
dtypes,
|
||||
) # TODO: wouldn't need this if argmax returned the right dtype
|
||||
import gymnasium as gym
|
||||
from tqdm import trange
|
||||
import numpy as np # TODO: remove numpy import
|
||||
|
||||
|
||||
class ActorCritic:
|
||||
def __init__(self, in_features, out_features, hidden_state=32):
|
||||
self.l1 = nn.Linear(in_features, hidden_state)
|
||||
|
@ -20,6 +23,7 @@ class ActorCritic:
|
|||
x = self.c1(obs).relu()
|
||||
return act, self.c2(x)
|
||||
|
||||
|
||||
def evaluate(model: ActorCritic, test_env: gym.Env) -> float:
|
||||
(obs, _), terminated, truncated = test_env.reset(), False, False
|
||||
total_rew = 0.0
|
||||
|
@ -29,21 +33,26 @@ def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
|
|||
total_rew += float(rew)
|
||||
return total_rew
|
||||
|
||||
|
||||
# TODO: time should be < 5s on M1 Max
|
||||
if __name__ == "__main__":
|
||||
env = gym.make('CartPole-v1')
|
||||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
|
||||
|
||||
@TinyJit
|
||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
def train_step(
|
||||
x: Tensor, selected_action: Tensor, reward: Tensor, old_log_dist: Tensor
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
with Tensor.train():
|
||||
log_dist, value = model(x)
|
||||
|
||||
# get advantage
|
||||
advantage = reward.reshape(-1, 1) - value
|
||||
mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
|
||||
mask = selected_action.reshape(-1, 1) == Tensor.arange(
|
||||
log_dist.shape[1]
|
||||
).reshape(1, -1).expand(selected_action.shape[0], -1)
|
||||
masked_advantage = mask * advantage.detach()
|
||||
|
||||
# PPO
|
||||
|
@ -51,7 +60,9 @@ if __name__ == "__main__":
|
|||
clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage
|
||||
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
|
||||
|
||||
entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
|
||||
entropy_loss = (
|
||||
(log_dist.exp() * log_dist).sum(-1).mean()
|
||||
) # this encourages diversity
|
||||
critic_loss = advantage.square().mean()
|
||||
opt.zero_grad()
|
||||
(action_loss + entropy_loss * 0.0005 + critic_loss).backward()
|
||||
|
@ -96,7 +107,11 @@ if __name__ == "__main__":
|
|||
discounts = np.power(0.99, np.arange(len(rews)))
|
||||
Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))]
|
||||
|
||||
Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
|
||||
Xn, An, Rn = (
|
||||
Xn[-MAX_REPLAY_BUFFER:],
|
||||
An[-MAX_REPLAY_BUFFER:],
|
||||
Rn[-MAX_REPLAY_BUFFER:],
|
||||
)
|
||||
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
|
||||
|
||||
# TODO: make this work
|
||||
|
@ -105,10 +120,16 @@ if __name__ == "__main__":
|
|||
|
||||
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
|
||||
for i in range(5):
|
||||
samples = Tensor.randint(BS, high=X.shape[0]).realize() # TODO: remove the need for this
|
||||
samples = Tensor.randint(
|
||||
BS, high=X.shape[0]
|
||||
).realize() # TODO: remove the need for this
|
||||
# TODO: is this recompiling based on the shape?
|
||||
action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
|
||||
t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")
|
||||
action_loss, entropy_loss, critic_loss = train_step(
|
||||
X[samples], A[samples], R[samples], old_log_dist[samples]
|
||||
)
|
||||
t.set_description(
|
||||
f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}"
|
||||
)
|
||||
|
||||
test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
|
||||
test_rew = evaluate(model, gym.make("CartPole-v1", render_mode="human"))
|
||||
print(f"test reward: {test_rew}")
|
||||
|
|
|
@ -4,18 +4,29 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
|||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm2d(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm2d(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
nn.Conv2d(1, 32, 5),
|
||||
Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5),
|
||||
Tensor.relu,
|
||||
nn.BatchNorm2d(32),
|
||||
Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3),
|
||||
Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3),
|
||||
Tensor.relu,
|
||||
nn.BatchNorm2d(64),
|
||||
Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1),
|
||||
nn.Linear(576, 10),
|
||||
]
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return x.sequential(self.layers)
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||
|
@ -29,17 +40,25 @@ if __name__ == "__main__":
|
|||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
loss = (
|
||||
model(X_train[samples])
|
||||
.sparse_categorical_crossentropy(Y_train[samples])
|
||||
.backward()
|
||||
)
|
||||
opt.step()
|
||||
return loss.realize()
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
|
||||
def get_test_acc() -> Tensor:
|
||||
return ((model(X_test).argmax(axis=1) == Y_test).mean() * 100).realize()
|
||||
|
||||
test_acc = float('nan')
|
||||
test_acc = float("nan")
|
||||
for i in (t := trange(70)):
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
samples = Tensor.randint(512, high=X_train.shape[0]) # TODO: put this in the JIT when rand is fixed
|
||||
samples = Tensor.randint(
|
||||
512, high=X_train.shape[0]
|
||||
) # TODO: put this in the JIT when rand is fixed
|
||||
loss = train_step(samples)
|
||||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
if i % 10 == 9:
|
||||
test_acc = get_test_acc().item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
|
|
|
@ -10,9 +10,11 @@ from tinygrad.helpers import GlobalCounters
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
|
||||
def tensors_allocated():
|
||||
return sum(isinstance(x, Tensor) for x in gc.get_objects())
|
||||
|
||||
|
||||
NUM = getenv("NUM", 2)
|
||||
BS = getenv("BS", 8)
|
||||
CNT = getenv("CNT", 10)
|
||||
|
@ -25,9 +27,12 @@ if __name__ == "__main__":
|
|||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
||||
parameters = get_parameters(model)
|
||||
for p in parameters: p.realize()
|
||||
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
|
||||
else: optimizer = optim.SGD(parameters, lr=0.001)
|
||||
for p in parameters:
|
||||
p.realize()
|
||||
if ADAM:
|
||||
optimizer = optim.Adam(parameters, lr=0.001)
|
||||
else:
|
||||
optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
Tensor.training = TRAINING
|
||||
Tensor.no_grad = not BACKWARD
|
||||
|
@ -42,7 +47,8 @@ if __name__ == "__main__":
|
|||
st = time.monotonic()
|
||||
out = model.forward(x_train)
|
||||
loss = out.log_softmax().mul(y_train).mean()
|
||||
if i == 2 and CLCACHE: CacheCollector.start()
|
||||
if i == 2 and CLCACHE:
|
||||
CacheCollector.start()
|
||||
if BACKWARD:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
@ -54,7 +60,8 @@ if __name__ == "__main__":
|
|||
et = time.monotonic()
|
||||
else:
|
||||
st = mt = time.monotonic()
|
||||
for prg, args in cl_cache: prg(*args)
|
||||
for prg, args in cl_cache:
|
||||
prg(*args)
|
||||
et = time.monotonic()
|
||||
|
||||
if i == 2 and CLCACHE:
|
||||
|
@ -64,4 +71,6 @@ if __name__ == "__main__":
|
|||
loss_cpu = loss.detach().numpy()
|
||||
cl = time.monotonic()
|
||||
|
||||
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
print(
|
||||
f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import os, sys, traceback
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from io import StringIO
|
||||
|
@ -9,27 +10,46 @@ from tinygrad.helpers import Timing, colored, getenv, fetch
|
|||
from extra.models.llama import Transformer, convert_from_huggingface
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
def create_fixed_tokenizer(output_file):
|
||||
print("creating fixed tokenizer")
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
|
||||
mp.ParseFromString(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true"
|
||||
).read_bytes()
|
||||
)
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(mp.SerializeToString())
|
||||
|
||||
|
||||
# TODO: make loading bf16 fast so we can remove this
|
||||
def create_model_cache(output_file, model):
|
||||
print(f"creating model cache at {output_file}")
|
||||
# TODO: add read only Tensors
|
||||
with Timing("download weights: "):
|
||||
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
|
||||
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
|
||||
part1 = nn.state.torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"
|
||||
)
|
||||
)
|
||||
part2 = nn.state.torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"
|
||||
)
|
||||
)
|
||||
|
||||
with Timing("weights -> model: "):
|
||||
nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
|
||||
nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)
|
||||
nn.state.load_state_dict(
|
||||
model, convert_from_huggingface(part1, model, 32, 8), strict=False
|
||||
)
|
||||
nn.state.load_state_dict(
|
||||
model, convert_from_huggingface(part2, model, 32, 8), strict=False
|
||||
)
|
||||
|
||||
with Timing("saving float16 cache: "):
|
||||
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
|
||||
|
@ -37,27 +57,44 @@ def create_model_cache(output_file, model):
|
|||
print("cache created, rerun to use")
|
||||
exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
||||
with Timing("create model: "):
|
||||
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
|
||||
model = Transformer(
|
||||
4096,
|
||||
14336,
|
||||
n_heads=32,
|
||||
n_layers=32,
|
||||
norm_eps=1e-5,
|
||||
vocab_size=32002,
|
||||
n_kv_heads=8,
|
||||
max_context=4096,
|
||||
)
|
||||
|
||||
cached_model = "/tmp/cached_openhermes.safetensors"
|
||||
if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
|
||||
if not os.path.isfile(cached_model):
|
||||
create_model_cache(cached_model, model)
|
||||
with Timing("loading float16 cache: "):
|
||||
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
|
||||
|
||||
if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
|
||||
if not os.path.isfile("/tmp/tokenizer.model"):
|
||||
create_fixed_tokenizer("/tmp/tokenizer.model")
|
||||
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
|
||||
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
||||
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
IM_END = 32000
|
||||
IM_START = 32001
|
||||
def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
||||
def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
|
||||
|
||||
def encode_prompt(k, v):
|
||||
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||
|
||||
def start_prompt(k):
|
||||
return [IM_START] + spp.encode(f"{k}\n")
|
||||
|
||||
def output(outputted, toks, color):
|
||||
cur = spp.decode(toks)[len(outputted) :]
|
||||
sys.stdout.write(colored(cur, color))
|
||||
|
@ -67,7 +104,10 @@ if __name__ == "__main__":
|
|||
|
||||
# *** app below this line ***
|
||||
|
||||
toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
|
||||
toks = [spp.bos_id()] + encode_prompt(
|
||||
"system",
|
||||
"You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input",
|
||||
)
|
||||
|
||||
PROMPT = getenv("PROMPT", 1)
|
||||
temperature = getenv("TEMP", 0.7)
|
||||
|
@ -83,24 +123,34 @@ if __name__ == "__main__":
|
|||
turn = not turn
|
||||
old_output_len = len(outputted)
|
||||
while 1:
|
||||
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
|
||||
tok = (
|
||||
model(Tensor([toks[start_pos:]]), start_pos, temperature)
|
||||
.multinomial()
|
||||
.item()
|
||||
)
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
outputted = output(outputted, toks, "blue" if not turn else "cyan")
|
||||
if tok == IM_END: break
|
||||
if tok == spp.eos_id(): break
|
||||
if tok == IM_END:
|
||||
break
|
||||
if tok == spp.eos_id():
|
||||
break
|
||||
new_output = outputted[old_output_len:]
|
||||
|
||||
if new_output.endswith("```") and '```python\n' in new_output:
|
||||
python_code = new_output.split('```python\n')[1].split("```")[0]
|
||||
if new_output.endswith("```") and "```python\n" in new_output:
|
||||
python_code = new_output.split("```python\n")[1].split("```")[0]
|
||||
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
|
||||
if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
|
||||
if (
|
||||
input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower()
|
||||
== "y"
|
||||
):
|
||||
my_stdout = StringIO()
|
||||
try:
|
||||
with redirect_stdout(my_stdout): exec(python_code)
|
||||
with redirect_stdout(my_stdout):
|
||||
exec(python_code)
|
||||
result = my_stdout.getvalue()
|
||||
except Exception as e:
|
||||
result = ''.join(traceback.format_exception_only(e))
|
||||
result = "".join(traceback.format_exception_only(e))
|
||||
toks += spp.encode(f"\nOutput:\n```\n{result}```")
|
||||
outputted = output(outputted, toks, "yellow")
|
||||
old_output_len = len(outputted)
|
||||
|
|
|
@ -9,8 +9,16 @@ import ast
|
|||
if __name__ == "__main__":
|
||||
model = EfficientNet(0)
|
||||
model.load_from_pretrained()
|
||||
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
|
||||
mode = (
|
||||
"clang"
|
||||
if getenv("CLANG", "") != ""
|
||||
else "webgpu"
|
||||
if getenv("WEBGPU", "") != ""
|
||||
else ""
|
||||
)
|
||||
prg, inp_sizes, out_sizes, state = export_model(
|
||||
model, mode, Tensor.randn(1, 3, 224, 224)
|
||||
)
|
||||
dirname = Path(__file__).parent
|
||||
if getenv("CLANG", "") == "":
|
||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||
|
@ -20,19 +28,33 @@ if __name__ == "__main__":
|
|||
else:
|
||||
cprog = [prg]
|
||||
# image library!
|
||||
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
|
||||
cprog += [
|
||||
"#define STB_IMAGE_IMPLEMENTATION",
|
||||
fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h")
|
||||
.read_text()
|
||||
.replace("half", "_half"),
|
||||
]
|
||||
|
||||
# imagenet labels, move to datasets?
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
lbls = ['"' + lbls[i] + '"' for i in range(1000)]
|
||||
inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
|
||||
outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
|
||||
inputs = "\n".join(
|
||||
[f"float {inp}[{inp_size}];" for inp, inp_size in inp_sizes.items()]
|
||||
)
|
||||
outputs = "\n".join(
|
||||
[f"float {out}[{out_size}];" for out, out_size in out_sizes.items()]
|
||||
)
|
||||
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
|
||||
cprog.append(inputs)
|
||||
cprog.append(outputs)
|
||||
|
||||
# buffers (empty + weights)
|
||||
cprog.append("""
|
||||
cprog.append(
|
||||
"""
|
||||
int main(int argc, char* argv[]) {
|
||||
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
|
||||
int X=0, Y=0, chan=0;
|
||||
|
@ -62,8 +84,9 @@ if __name__ == "__main__":
|
|||
}
|
||||
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
|
||||
else printf("%s\\n", lbls[best_idx]);
|
||||
}""")
|
||||
}"""
|
||||
)
|
||||
|
||||
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||
# category : 281 (tabby, tabby cat) with 9.452788
|
||||
print('\n'.join(cprog))
|
||||
print("\n".join(cprog))
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# An example to compile a small Tensorflow model to extremely portable C code
|
||||
|
||||
import os, sys
|
||||
os.environ["CLANG"] = '1'
|
||||
os.environ["GPU"] = '1'
|
||||
|
||||
os.environ["CLANG"] = "1"
|
||||
os.environ["GPU"] = "1"
|
||||
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
@ -12,32 +13,42 @@ from examples.compile_efficientnet import compile_net
|
|||
from extra.onnx import get_run_onnx
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
def get_uncompiled_model2(dataset_size=32, output_size=4):
|
||||
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
|
||||
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
|
||||
x = tf.keras.layers.BatchNormalization()(x)
|
||||
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
|
||||
outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x)
|
||||
outputs = tf.keras.layers.Dense(
|
||||
output_size, activation="sigmoid", name="predictions"
|
||||
)(x)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
return model
|
||||
|
||||
|
||||
def create_onnx_model(keras_model):
|
||||
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
|
||||
input_signature = [tf.TensorSpec([1, 32], tf.float32, name="x")]
|
||||
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def compile_onnx_model(onnx_model):
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
@TinyJit
|
||||
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()
|
||||
def run(x):
|
||||
return run_onnx({"x": x}, debug=False)["predictions"].realize()
|
||||
|
||||
the_input = Tensor.randn(1, 32)
|
||||
the_output = run(the_input)
|
||||
the_output = run(the_input)
|
||||
|
||||
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
|
||||
special_names = {
|
||||
id(the_input.lazydata.realized.cl): "input",
|
||||
id(the_output.lazydata.realized.cl): "outputs",
|
||||
}
|
||||
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
|
||||
|
||||
|
@ -60,7 +71,8 @@ def compile_onnx_model(onnx_model):
|
|||
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
|
||||
|
||||
# test program
|
||||
cprog.append(f"""int main(int argc, char *argv[]) {{
|
||||
cprog.append(
|
||||
f"""int main(int argc, char *argv[]) {{
|
||||
// read in the weights from disk
|
||||
FILE *f = fopen("/tmp/tf_weights", "rb");
|
||||
float *weights = (float *)malloc({len(weights)});
|
||||
|
@ -75,25 +87,38 @@ def compile_onnx_model(onnx_model):
|
|||
for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
|
||||
float *outputs = infer(input);
|
||||
printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
|
||||
}}""")
|
||||
}}"""
|
||||
)
|
||||
|
||||
# ready the program
|
||||
prg = '\n'.join(cprog)
|
||||
prg = "\n".join(cprog)
|
||||
print(prg)
|
||||
|
||||
# add test weights
|
||||
subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8'))
|
||||
subprocess.check_output(
|
||||
["clang", "-O2", "-lm", "-fPIC", "-x", "c", "-", "-o", "/tmp/tf_test"],
|
||||
input=prg.encode("utf-8"),
|
||||
)
|
||||
|
||||
tinygrad_output = [x for x in the_output.numpy()[0]]
|
||||
print("tinygrad:", tinygrad_output, file=sys.stderr)
|
||||
|
||||
c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n"
|
||||
c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")]
|
||||
c_input = " ".join(["%f" % x for x in the_input[0].numpy()]) + "\n"
|
||||
c_output = [
|
||||
float(x)
|
||||
for x in subprocess.check_output(
|
||||
["/tmp/tf_test"], input=c_input.encode("utf-8")
|
||||
)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
.split(" ")
|
||||
]
|
||||
print("compiled:", c_output, file=sys.stderr)
|
||||
|
||||
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
|
||||
return the_input.numpy(), c_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
keras_model = get_uncompiled_model2()
|
||||
onnx_model = create_onnx_model(keras_model)
|
||||
|
@ -101,4 +126,3 @@ if __name__ == "__main__":
|
|||
tf_output = keras_model(test_input).numpy()[0]
|
||||
print("keras: ", tf_output, file=sys.stderr)
|
||||
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
|
|
@ -12,7 +12,14 @@ import pyaudio
|
|||
import yaml
|
||||
from llama import LLaMa
|
||||
from vits import MODELS as VITS_MODELS
|
||||
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
|
||||
from vits import (
|
||||
Y_LENGTH_ESTIMATE_SCALARS,
|
||||
HParams,
|
||||
Synthesizer,
|
||||
TextMapper,
|
||||
get_hparams_from_file,
|
||||
load_model,
|
||||
)
|
||||
from whisper import init_whisper, transcribe_waveform
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
@ -29,16 +36,26 @@ IM_END = 32002
|
|||
|
||||
|
||||
# Functions for encoding prompts to chatml md
|
||||
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
||||
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
|
||||
def encode_prompt(spp, k, v):
|
||||
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||
|
||||
|
||||
def start_prompt(spp, k):
|
||||
return [IM_START] + spp.encode(f"{k}\n")
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
for i in range(0, len(lst), n): yield lst[i:i + n]
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def create_fixed_tokenizer():
|
||||
"""Function needed for extending tokenizer with additional chat tokens"""
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
|
||||
|
||||
tokenizer_path = fetch(
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model"
|
||||
)
|
||||
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
||||
print("creating fixed tokenizer")
|
||||
mp = spb2.ModelProto()
|
||||
|
@ -50,16 +67,28 @@ def create_fixed_tokenizer():
|
|||
tokenizer_path.write_bytes(mp.SerializeToString())
|
||||
return tokenizer_path
|
||||
|
||||
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]:
|
||||
|
||||
def llama_prepare(
|
||||
llama: LLaMa, temperature: float, pre_prompt_path: Path
|
||||
) -> tuple[list[int], str, str, str]:
|
||||
"""Prepares a llama model from a specified pre-prompt file"""
|
||||
with open(str(pre_prompt_path)) as f:
|
||||
config = yaml.safe_load(f.read())
|
||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
|
||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(
|
||||
llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")
|
||||
)
|
||||
for i in config["examples"]:
|
||||
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
||||
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
||||
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
||||
return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
|
||||
return (
|
||||
toks,
|
||||
config["user_delim"],
|
||||
config["resp_delim"],
|
||||
len(toks),
|
||||
llama.tokenizer.decode(toks),
|
||||
)
|
||||
|
||||
|
||||
def llama_generate(
|
||||
llama: LLaMa,
|
||||
|
@ -70,7 +99,7 @@ def llama_generate(
|
|||
user_delim: str,
|
||||
resp_delim: str,
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
max_tokens=1000,
|
||||
):
|
||||
"""Generates an output for the specified prompt"""
|
||||
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
||||
|
@ -79,7 +108,9 @@ def llama_generate(
|
|||
outputted = llama.tokenizer.decode(toks)
|
||||
init_length = len(outputted)
|
||||
for _ in range(max_tokens):
|
||||
probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy()
|
||||
probs_np = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||
).numpy()
|
||||
token = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(token)
|
||||
|
@ -90,12 +121,14 @@ def llama_generate(
|
|||
sys.stdout.write(cur[len(outputted) :])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
if toks[-1] == IM_END: break
|
||||
if toks[-1] == IM_END:
|
||||
break
|
||||
else:
|
||||
toks.append(IM_END)
|
||||
print() # because the output is flushed
|
||||
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
||||
|
||||
|
||||
def tts(
|
||||
text_to_synthesize: str,
|
||||
synth: Synthesizer,
|
||||
|
@ -110,24 +143,45 @@ def tts(
|
|||
text_mapper: TextMapper,
|
||||
model_has_multiple_speakers: bool,
|
||||
batch_size=600,
|
||||
vits_batch_size=1000
|
||||
vits_batch_size=1000,
|
||||
):
|
||||
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
if model_to_use == "mmts-tts":
|
||||
text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
|
||||
# Convert the input text to a tensor.
|
||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
||||
stn_tst = text_mapper.get_text(
|
||||
text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners
|
||||
)
|
||||
init_shape = stn_tst.shape
|
||||
assert init_shape[0] < batch_size, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(
|
||||
0
|
||||
), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = (
|
||||
Tensor([speaker_id], dtype=dtypes.int64)
|
||||
if model_has_multiple_speakers
|
||||
else None
|
||||
)
|
||||
|
||||
# Perform inference.
|
||||
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
|
||||
audio_tensor = synth.infer(
|
||||
x_tst,
|
||||
x_tst_lengths,
|
||||
sid,
|
||||
noise_scale,
|
||||
length_scale,
|
||||
noise_scale_w,
|
||||
emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use]
|
||||
if estimate_max_y_length
|
||||
else None,
|
||||
batch_size=vits_batch_size,
|
||||
)[0, 0]
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
return audio_data
|
||||
|
||||
|
||||
def init_vits(
|
||||
model_to_use: str,
|
||||
emotion_path: Path,
|
||||
|
@ -142,21 +196,44 @@ def init_vits(
|
|||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
if model_has_multiple_speakers:
|
||||
if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||
if speaker_id >= hps.data.n_speakers:
|
||||
raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||
speakers = hps.speakers
|
||||
if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
if isinstance(speakers, list):
|
||||
speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
|
||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||
emotion_embedding = None
|
||||
if emotion_path is not None:
|
||||
if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
|
||||
else: raise ValueError("Emotion path must be a .npy file.")
|
||||
if emotion_path.endswith(".npy"):
|
||||
emotion_embedding = Tensor(
|
||||
np.load(emotion_path), dtype=dtypes.int64
|
||||
).unsqueeze(0)
|
||||
else:
|
||||
raise ValueError("Emotion path must be a .npy file.")
|
||||
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
||||
elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
||||
if hps.__contains__("symbols"):
|
||||
symbols = hps.symbols
|
||||
elif model_to_use == "mmts-tts":
|
||||
symbols = [
|
||||
x.replace("\n", "")
|
||||
for x in fetch(
|
||||
"https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"
|
||||
)
|
||||
.open(encoding="utf-8")
|
||||
.readlines()
|
||||
]
|
||||
else:
|
||||
symbols = (
|
||||
["_"]
|
||||
+ list(';:,.!?¡¿—…"«»“” ')
|
||||
+ list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
|
||||
+ list(
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
)
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
|
||||
# Load the model.
|
||||
|
@ -168,18 +245,23 @@ def init_vits(
|
|||
|
||||
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def output_stream(num_channels: int, sample_rate: int):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True
|
||||
)
|
||||
yield stream
|
||||
except KeyboardInterrupt: pass
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_writer():
|
||||
try:
|
||||
|
@ -191,10 +273,17 @@ def log_writer():
|
|||
print(*logs, sep="\n")
|
||||
print(sep)
|
||||
|
||||
|
||||
def listener(q: mp.Queue, event: mp.Event):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK,
|
||||
)
|
||||
did_print = False
|
||||
while True:
|
||||
data = stream.read(CHUNK) # read data to avoid overflow
|
||||
|
@ -210,7 +299,10 @@ def listener(q: mp.Queue, event: mp.Event):
|
|||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
|
||||
|
||||
def mp_output_stream(
|
||||
q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int
|
||||
):
|
||||
with output_stream(num_channels, sample_rate) as stream:
|
||||
while True:
|
||||
try:
|
||||
|
@ -219,8 +311,10 @@ def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_r
|
|||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import nltk
|
||||
|
||||
nltk.download("punkt")
|
||||
Tensor.no_grad = True
|
||||
# Parse CLI arguments
|
||||
|
@ -230,75 +324,212 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
||||
|
||||
# LLAMA args
|
||||
parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
|
||||
parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
|
||||
parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
||||
parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
|
||||
parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
|
||||
parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
|
||||
parser.add_argument(
|
||||
"--llama_pre_prompt_path",
|
||||
type=Path,
|
||||
default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml",
|
||||
help="Path to yaml file which contains all pre-prompt data needed. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_count", type=int, default=1000, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Temperature in the softmax",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_quantize",
|
||||
action="store_true",
|
||||
help="Quantize the weights to int8 in memory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_model",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_gen",
|
||||
type=str,
|
||||
default="tiny",
|
||||
required=False,
|
||||
help="Generation of the model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_size",
|
||||
type=str,
|
||||
default="1B-Chat",
|
||||
required=False,
|
||||
help="Size of model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_tokenizer",
|
||||
type=Path,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to llama tokenizer.model",
|
||||
)
|
||||
|
||||
# vits args
|
||||
parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
|
||||
parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
|
||||
parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
|
||||
parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
|
||||
parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
|
||||
parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
|
||||
parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
|
||||
parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
|
||||
parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
|
||||
parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
|
||||
parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
|
||||
parser.add_argument(
|
||||
"--vits_model_to_use",
|
||||
default="vctk",
|
||||
help="Specify the model to use. Default is 'vctk'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_speaker_id",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Specify the speaker ID. Default is 6.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_noise_scale",
|
||||
type=float,
|
||||
default=0.667,
|
||||
help="Specify the noise scale. Default is 0.667.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_noise_scale_w",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Specify the noise scale w. Default is 0.8.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_length_scale",
|
||||
type=float,
|
||||
default=1,
|
||||
help="Specify the length scale. Default is 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Specify the seed (set to None if no seed). Default is 1337.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_num_channels",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Specify the number of audio output channels. Default is 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_sample_width",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Specify the number of bytes per sample, adjust if necessary. Default is 2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_emotion_path",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Specify the path to emotion reference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_estimate_max_y_length",
|
||||
type=str,
|
||||
default=False,
|
||||
help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary."
|
||||
)
|
||||
|
||||
# conversation args
|
||||
parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
|
||||
parser.add_argument(
|
||||
"--max_sentence_length",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max words in one sentence to pass to vits",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init models
|
||||
model, enc = init_whisper(args.whisper_model_name)
|
||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
|
||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(
|
||||
args.vits_model_to_use,
|
||||
args.vits_emotion_path,
|
||||
args.vits_speaker_id,
|
||||
args.vits_seed,
|
||||
)
|
||||
|
||||
# Download tinyllama chat as a default model
|
||||
if args.llama_model is None:
|
||||
args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
|
||||
args.llama_model = fetch(
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors",
|
||||
"tinyllamachat.safetensors",
|
||||
)
|
||||
args.llama_gen = "tiny"
|
||||
args.llama_size = "1B-Chat"
|
||||
# Add 3 more tokens to the tokenizer
|
||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
|
||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"):
|
||||
args.llama_tokenizer = create_fixed_tokenizer()
|
||||
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
||||
llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
|
||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
|
||||
llama = LLaMa.build(
|
||||
args.llama_model,
|
||||
tokenizer_path,
|
||||
args.llama_gen,
|
||||
args.llama_size,
|
||||
args.llama_quantize,
|
||||
)
|
||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(
|
||||
llama, args.llama_temperature, args.llama_pre_prompt_path
|
||||
)
|
||||
|
||||
# Start child process for mic input
|
||||
q = mp.Queue()
|
||||
is_listening_event = mp.Event()
|
||||
p = mp.Process(target=listener, args=(q, is_listening_event,))
|
||||
p = mp.Process(
|
||||
target=listener,
|
||||
args=(
|
||||
q,
|
||||
is_listening_event,
|
||||
),
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# Start child process for speaker output
|
||||
out_q = mp.Queue()
|
||||
out_counter = mp.Value("i", 0)
|
||||
out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
|
||||
out_p = mp.Process(
|
||||
target=mp_output_stream,
|
||||
args=(
|
||||
out_q,
|
||||
out_counter,
|
||||
args.vits_num_channels,
|
||||
hps.data.sampling_rate,
|
||||
),
|
||||
)
|
||||
out_p.daemon = True
|
||||
out_p.start()
|
||||
|
||||
# JIT tts
|
||||
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
||||
tts(
|
||||
i, synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
i,
|
||||
synth,
|
||||
hps,
|
||||
emotion_embedding,
|
||||
args.vits_speaker_id,
|
||||
args.vits_model_to_use,
|
||||
args.vits_noise_scale,
|
||||
args.vits_noise_scale_w,
|
||||
args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length,
|
||||
text_mapper,
|
||||
model_has_multiple_speakers,
|
||||
)
|
||||
|
||||
# Start the pipeline
|
||||
with log_writer() as log:
|
||||
while True:
|
||||
tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
tokens = [
|
||||
enc._special_tokens["<|startoftranscript|>"],
|
||||
enc._special_tokens["<|notimestamps|>"],
|
||||
]
|
||||
total = np.array([])
|
||||
out_counter.value = 0
|
||||
|
||||
|
@ -306,10 +537,12 @@ if __name__ == "__main__":
|
|||
is_listening_event.set()
|
||||
prev_text = None
|
||||
while True:
|
||||
for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
|
||||
for _ in range(RATE // CHUNK):
|
||||
total = np.concatenate([total, q.get()])
|
||||
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
||||
print(txt, end="\r")
|
||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
|
||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()):
|
||||
continue
|
||||
if prev_text is not None and prev_text == txt:
|
||||
is_listening_event.clear()
|
||||
break
|
||||
|
@ -320,9 +553,15 @@ if __name__ == "__main__":
|
|||
# Generate with llama
|
||||
with Timing("llama generation: "):
|
||||
outputted, start_pos, response = llama_generate(
|
||||
llama, toks, outputted, txt, start_pos,
|
||||
user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
|
||||
max_tokens=args.llama_count
|
||||
llama,
|
||||
toks,
|
||||
outputted,
|
||||
txt,
|
||||
start_pos,
|
||||
user_delim=user_delim,
|
||||
resp_delim=resp_delim,
|
||||
temperature=args.llama_temperature,
|
||||
max_tokens=args.llama_count,
|
||||
)
|
||||
log.append(f"{resp_delim.capitalize()}: {response}")
|
||||
|
||||
|
@ -333,12 +572,21 @@ if __name__ == "__main__":
|
|||
total = np.array([], dtype=np.int16)
|
||||
for j in chunks(i.split(), args.max_sentence_length):
|
||||
audio_data = tts(
|
||||
" ".join(j), synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
" ".join(j),
|
||||
synth,
|
||||
hps,
|
||||
emotion_embedding,
|
||||
args.vits_speaker_id,
|
||||
args.vits_model_to_use,
|
||||
args.vits_noise_scale,
|
||||
args.vits_noise_scale_w,
|
||||
args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length,
|
||||
text_mapper,
|
||||
model_has_multiple_speakers,
|
||||
)
|
||||
total = np.concatenate([total, audio_data])
|
||||
out_q.put(total.tobytes())
|
||||
while out_counter.value < len(sentences): continue
|
||||
while out_counter.value < len(sentences):
|
||||
continue
|
||||
log.append(f"Total: {time.perf_counter() - s}")
|
||||
|
|
|
@ -11,12 +11,14 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import getenv, fetch, Timing
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
# TODO: you should be able to put these in the jitted function
|
||||
bias = Tensor([0.485, 0.456, 0.406])
|
||||
scale = Tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
@TinyJit
|
||||
def _infer(model, img):
|
||||
img = img.permute((2, 0, 1))
|
||||
|
@ -25,10 +27,13 @@ def _infer(model, img):
|
|||
img = img / scale.reshape((1, -1, 1, 1))
|
||||
return model.forward(img).realize()
|
||||
|
||||
|
||||
def infer(model, img):
|
||||
# preprocess image
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
||||
img = img.resize(
|
||||
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||
)
|
||||
|
||||
img = np.array(img)
|
||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||
|
@ -52,18 +57,28 @@ def infer(model, img):
|
|||
"""
|
||||
return out, retimg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet(getenv("NUM", 0))
|
||||
model.load_from_pretrained()
|
||||
|
||||
# category labels
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
|
||||
# load image and preprocess
|
||||
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||
if url == 'webcam':
|
||||
url = (
|
||||
sys.argv[1]
|
||||
if len(sys.argv) >= 2
|
||||
else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||
)
|
||||
if url == "webcam":
|
||||
import cv2
|
||||
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
|
@ -72,12 +87,17 @@ if __name__ == "__main__":
|
|||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||
lt = time.monotonic_ns()
|
||||
out, retimg = infer(model, img)
|
||||
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
print(
|
||||
f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms",
|
||||
np.argmax(out),
|
||||
np.max(out),
|
||||
lbls[np.argmax(out)],
|
||||
)
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow('capture', retimg)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
cv2.imshow("capture", retimg)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
|
|
@ -3,6 +3,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import dtypes
|
||||
from tinygrad import Device
|
||||
|
||||
|
||||
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
|
||||
def bit_extract(x, s, e) -> Tensor:
|
||||
# extract the top bits we don't want
|
||||
|
@ -10,11 +11,16 @@ def bit_extract(x, s, e) -> Tensor:
|
|||
x = (x - top_bits) / (1 << e)
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def u16_to_f16(x):
|
||||
sign = bit_extract(x, 15, 15).float()
|
||||
exponent = bit_extract(x, 14, 10).float()
|
||||
fraction = bit_extract(x, 9, 0).float()
|
||||
return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
|
||||
return sign.where(-1, 1) * exponent.where(
|
||||
(exponent - 15).exp2() * (1 + fraction / 0x400),
|
||||
6.103515625e-5 * (fraction / 0x400),
|
||||
)
|
||||
|
||||
|
||||
def u32_to_f16(oo):
|
||||
oo1 = (oo / 0x10000).floor().contiguous()
|
||||
|
@ -24,6 +30,7 @@ def u32_to_f16(oo):
|
|||
f2 = u16_to_f16(oo2)
|
||||
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# random float16
|
||||
Tensor.manual_seed(2)
|
||||
|
|
222
examples/gpt2.py
222
examples/gpt2.py
|
@ -10,11 +10,20 @@ from tinygrad.shape.symbolic import Variable
|
|||
from tinygrad.jit import TinyJit
|
||||
import tiktoken
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
|
||||
from tinygrad.helpers import (
|
||||
GlobalCounters,
|
||||
Timing,
|
||||
DEBUG,
|
||||
getenv,
|
||||
fetch,
|
||||
colored,
|
||||
dtypes,
|
||||
)
|
||||
|
||||
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
|
||||
HALF = getenv("HALF")
|
||||
|
||||
|
||||
class Attention:
|
||||
def __init__(self, dim, n_heads):
|
||||
self.c_attn = Linear(dim, 3 * dim, bias=True)
|
||||
|
@ -23,18 +32,27 @@ class Attention:
|
|||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
|
||||
def __call__(
|
||||
self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
if mask is not None:
|
||||
# no symbolic shape qkv when consuming prompts
|
||||
start_pos = start_pos.val
|
||||
|
||||
xqkv = self.c_attn(x)
|
||||
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
|
||||
xq, xk, xv = [
|
||||
xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape(
|
||||
xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||
self.cache_k, self.cache_v = Tensor.zeros(
|
||||
bsz, MAX_CONTEXT, self.n_heads, self.head_dim
|
||||
), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||
if HALF:
|
||||
self.cache_k = self.cache_k.half()
|
||||
self.cache_v = self.cache_v.half()
|
||||
|
@ -43,11 +61,28 @@ class Attention:
|
|||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
||||
# update the cache
|
||||
self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_k.assign(
|
||||
keys.pad(
|
||||
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
values.pad(
|
||||
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
|
||||
xq, keys, values = (
|
||||
xq.transpose(1, 2),
|
||||
keys.transpose(1, 2),
|
||||
values.transpose(1, 2),
|
||||
)
|
||||
return self.c_proj(
|
||||
xq.scaled_dot_product_attention(keys, values, mask)
|
||||
.transpose(1, 2)
|
||||
.reshape(bsz, seqlen, -1)
|
||||
)
|
||||
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim):
|
||||
|
@ -57,6 +92,7 @@ class FeedForward:
|
|||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.c_proj(self.c_fc(x).gelu())
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim, n_heads, norm_eps):
|
||||
self.attn = Attention(dim, n_heads)
|
||||
|
@ -66,7 +102,8 @@ class TransformerBlock:
|
|||
|
||||
def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
|
||||
h = x + self.attn(self.ln_1(x), start_pos, mask)
|
||||
return (h + self.mlp(self.ln_2(h)))
|
||||
return h + self.mlp(self.ln_2(h))
|
||||
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
||||
|
@ -78,7 +115,8 @@ class Transformer:
|
|||
self.forward_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
|
||||
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
||||
if not hasattr(self, "allpos"):
|
||||
self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
||||
_bsz, seqlen = tokens.shape
|
||||
|
||||
# NOTE: cannot convert token indices into half due to precision
|
||||
|
@ -86,44 +124,73 @@ class Transformer:
|
|||
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
|
||||
h = tok_emb + pos_emb
|
||||
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
|
||||
mask = (
|
||||
Tensor.full((1, 1, seqlen, start_pos.val + seqlen), float("-inf"))
|
||||
.triu(start_pos.val + 1)
|
||||
.realize()
|
||||
if seqlen > 1
|
||||
else None
|
||||
)
|
||||
|
||||
if HALF:
|
||||
h = h.half()
|
||||
if mask is not None: mask = mask.half()
|
||||
if mask is not None:
|
||||
mask = mask.half()
|
||||
|
||||
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
|
||||
for hi in self.h:
|
||||
h = hi(h, start_pos=start_pos, mask=mask)
|
||||
|
||||
logits = self.lm_head(self.ln_f(h))
|
||||
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize()
|
||||
|
||||
# TODO: fix empty token
|
||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
|
||||
return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
|
||||
def __call__(
|
||||
self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
|
||||
) -> Tensor:
|
||||
return (
|
||||
self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward
|
||||
)(tokens, start_pos, temperature)
|
||||
|
||||
|
||||
VOCAB_SIZE = 50257
|
||||
MODEL_PARAMS = {
|
||||
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
|
||||
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
|
||||
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
|
||||
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
|
||||
"gpt2": dict(
|
||||
n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 124M params
|
||||
"gpt2-medium": dict(
|
||||
n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 350M params
|
||||
"gpt2-large": dict(
|
||||
n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 774M params
|
||||
"gpt2-xl": dict(
|
||||
n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 1558M params
|
||||
}
|
||||
|
||||
|
||||
class GPT2:
|
||||
@staticmethod
|
||||
def build(model_size="gpt2"):
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
model = Transformer(**MODEL_PARAMS[model_size])
|
||||
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
|
||||
weights = torch_load(
|
||||
fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin")
|
||||
)
|
||||
# special treatment for the Conv1D weights we need to transpose
|
||||
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
||||
transposed = [
|
||||
"attn.c_attn.weight",
|
||||
"attn.c_proj.weight",
|
||||
"mlp.c_fc.weight",
|
||||
"mlp.c_proj.weight",
|
||||
]
|
||||
for k in weights.keys():
|
||||
if any(k.endswith(w) for w in transposed):
|
||||
weights[k] = Tensor(weights[k].numpy().T)
|
||||
# lm head and wte are tied
|
||||
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())
|
||||
weights["lm_head.weight"] = Tensor(weights["wte.weight"].numpy())
|
||||
|
||||
load_state_dict(model, weights)
|
||||
return GPT2(model, tokenizer)
|
||||
|
@ -132,42 +199,98 @@ class GPT2:
|
|||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
|
||||
def greedy_until(
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: int,
|
||||
temperature: float,
|
||||
timing: bool = False,
|
||||
batch_size: int = 1,
|
||||
):
|
||||
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
||||
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
||||
start_pos = 0
|
||||
for _ in trange(max_length, disable=(timing == True)):
|
||||
GlobalCounters.reset()
|
||||
if timing: print("")
|
||||
if timing:
|
||||
print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Timing("total ", enabled=timing):
|
||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
|
||||
probs = self.model(Tensor([x[start_pos:] for x in toks]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
|
||||
with Timing(
|
||||
"ran model in ",
|
||||
on_exit=(
|
||||
lambda et: (
|
||||
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||
+ (
|
||||
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if DEBUG
|
||||
else None,
|
||||
enabled=timing,
|
||||
):
|
||||
probs = self.model(
|
||||
Tensor([x[start_pos:] for x in toks]),
|
||||
Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(
|
||||
start_pos
|
||||
),
|
||||
temperature,
|
||||
)
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().flatten().numpy().tolist()
|
||||
start_pos = len(toks[0])
|
||||
for i,t in enumerate(tok): toks[i].append(t)
|
||||
for i, t in enumerate(tok):
|
||||
toks[i].append(t)
|
||||
output = [self.tokenizer.decode(x) for x in toks]
|
||||
return output
|
||||
|
||||
|
||||
# **** main code ****
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with")
|
||||
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
|
||||
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
|
||||
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per token")
|
||||
parser.add_argument('--seed', type=int, help="Set the random seed")
|
||||
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
|
||||
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run GPT2 in tinygrad",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="What is the answer to life, the universe, and everything?",
|
||||
help="Phrase to start with",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=100, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.8, help="Temperature in the softmax"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
type=str,
|
||||
default="gpt2-medium",
|
||||
help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]",
|
||||
)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument("--seed", type=int, help="Set the random seed")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Set the input batch size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Benchmark GPT with the given number of tokens",
|
||||
)
|
||||
parser.add_argument("--noshow", action="store_true", help="Don't show the output")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.seed is not None:
|
||||
|
@ -182,11 +305,22 @@ if __name__ == "__main__":
|
|||
l.assign(l.cast(dtypes.float16).realize())
|
||||
|
||||
if args.benchmark != -1:
|
||||
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
|
||||
gpt2.model(
|
||||
Tensor.rand(args.batch_size, args.benchmark),
|
||||
Variable("a", 0, MAX_CONTEXT).bind(0),
|
||||
).realize()
|
||||
else:
|
||||
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
|
||||
texts = gpt2.greedy_until(
|
||||
args.prompt,
|
||||
args.count,
|
||||
args.temperature,
|
||||
timing=args.timing,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
if not args.noshow:
|
||||
print('Generating text...')
|
||||
if len(texts) == 1: print(texts[0])
|
||||
print("Generating text...")
|
||||
if len(texts) == 1:
|
||||
print(texts[0])
|
||||
else:
|
||||
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
|
||||
for i, text in enumerate(texts):
|
||||
print(colored(f"Response {i}:", "green"), text)
|
||||
|
|
|
@ -28,7 +28,8 @@ if __name__ == "__main__":
|
|||
sched = [x for x in sched if x.ast.op not in LoadOps]
|
||||
|
||||
# focus on one kernel
|
||||
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
|
||||
if getenv("KERNEL", -1) >= 0:
|
||||
sched = sched[getenv("KERNEL", -1) : getenv("KERNEL", -1) + 1]
|
||||
|
||||
# work with the schedule
|
||||
total_tm = 0
|
||||
|
@ -52,20 +53,33 @@ if __name__ == "__main__":
|
|||
# try a beam search
|
||||
if getenv("BEAM"):
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
lin = beam_search(
|
||||
lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1))
|
||||
)
|
||||
lins.append(lin)
|
||||
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
for lin in lins:
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
|
||||
gflops = (
|
||||
sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})
|
||||
* 1e-9
|
||||
/ tm
|
||||
)
|
||||
choices.append((tm, gflops, lin.linearize()))
|
||||
|
||||
# print all kernels
|
||||
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
||||
if DEBUG >= 1:
|
||||
print(
|
||||
f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||
)
|
||||
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
|
||||
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
||||
print(
|
||||
f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||
)
|
||||
total_tm += tm
|
||||
running_gflops += gflops * tm
|
||||
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
|
||||
print(
|
||||
f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS"
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# setup for distributed
|
||||
from extra import dist
|
||||
from tinygrad.helpers import getenv, dtypes
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("DIST"):
|
||||
dist.preinit()
|
||||
|
@ -24,7 +25,7 @@ from tinygrad.shape.symbolic import Node
|
|||
from extra.lr_scheduler import OneCycleLR
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
|
||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv("EVAL_BS", 500), getenv("STEPS", 1000)
|
||||
|
||||
if getenv("HALF", 0):
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
@ -33,16 +34,28 @@ else:
|
|||
Tensor.default_type = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
|
||||
|
||||
class BatchNorm(nn.BatchNorm2d):
|
||||
def __init__(self, num_features):
|
||||
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
|
||||
super().__init__(
|
||||
num_features,
|
||||
track_running_stats=False,
|
||||
eps=1e-12,
|
||||
momentum=0.85,
|
||||
affine=True,
|
||||
)
|
||||
self.weight.requires_grad = False
|
||||
self.bias.requires_grad = True
|
||||
|
||||
|
||||
class ConvGroup:
|
||||
def __init__(self, channels_in, channels_out):
|
||||
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
channels_in, channels_out, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
channels_out, channels_out, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
|
||||
self.norm1 = BatchNorm(channels_out)
|
||||
self.norm2 = BatchNorm(channels_out)
|
||||
|
@ -63,6 +76,7 @@ class ConvGroup:
|
|||
|
||||
return x + residual
|
||||
|
||||
|
||||
class SpeedyResNet:
|
||||
def __init__(self, W):
|
||||
self.whitening = W
|
||||
|
@ -74,54 +88,58 @@ class SpeedyResNet:
|
|||
ConvGroup(256, 512),
|
||||
lambda x: x.max((2, 3)),
|
||||
nn.Linear(512, 10, bias=False),
|
||||
lambda x: x.mul(1./9)
|
||||
lambda x: x.mul(1.0 / 9),
|
||||
]
|
||||
|
||||
def __call__(self, x, training=True):
|
||||
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
|
||||
# TODO: remove the pad but instead let the kernel optimizer itself
|
||||
forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
|
||||
return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
|
||||
forward = (
|
||||
lambda x: x.conv2d(self.whitening).pad2d((1, 0, 0, 1)).sequential(self.net)
|
||||
)
|
||||
return (
|
||||
forward(x) if training else forward(x) * 0.5 + forward(x[..., ::-1]) * 0.5
|
||||
)
|
||||
|
||||
|
||||
def train_cifar():
|
||||
|
||||
# hyper-parameters were exactly the same as the original repo
|
||||
bias_scaler = 58
|
||||
hyp: Dict[str, Any] = {
|
||||
'seed' : 209,
|
||||
'opt': {
|
||||
'bias_lr': 1.76 * bias_scaler/512,
|
||||
'non_bias_lr': 1.76 / 512,
|
||||
'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
|
||||
'non_bias_decay': 1.08 * 6.45e-4 * BS,
|
||||
'final_lr_ratio': 0.025,
|
||||
'initial_div_factor': 1e16,
|
||||
'label_smoothing': 0.20,
|
||||
'momentum': 0.85,
|
||||
'percent_start': 0.23,
|
||||
'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
||||
"seed": 209,
|
||||
"opt": {
|
||||
"bias_lr": 1.76 * bias_scaler / 512,
|
||||
"non_bias_lr": 1.76 / 512,
|
||||
"bias_decay": 1.08 * 6.45e-4 * BS / bias_scaler,
|
||||
"non_bias_decay": 1.08 * 6.45e-4 * BS,
|
||||
"final_lr_ratio": 0.025,
|
||||
"initial_div_factor": 1e16,
|
||||
"label_smoothing": 0.20,
|
||||
"momentum": 0.85,
|
||||
"percent_start": 0.23,
|
||||
"loss_scale_scaler": 1.0 / 128, # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
||||
},
|
||||
'net': {
|
||||
'kernel_size': 2, # kernel size for the whitening layer
|
||||
'cutmix_size': 3,
|
||||
'cutmix_steps': 499,
|
||||
'pad_amount': 2
|
||||
"net": {
|
||||
"kernel_size": 2, # kernel size for the whitening layer
|
||||
"cutmix_size": 3,
|
||||
"cutmix_steps": 499,
|
||||
"pad_amount": 2,
|
||||
},
|
||||
"ema": {
|
||||
"steps": 399,
|
||||
"decay_base": 0.95,
|
||||
"decay_pow": 1.6,
|
||||
"every_n_steps": 5,
|
||||
},
|
||||
'ema': {
|
||||
'steps': 399,
|
||||
'decay_base': .95,
|
||||
'decay_pow': 1.6,
|
||||
'every_n_steps': 5,
|
||||
}
|
||||
}
|
||||
|
||||
def set_seed(seed):
|
||||
Tensor.manual_seed(getenv('SEED', seed))
|
||||
random.seed(getenv('SEED', seed))
|
||||
Tensor.manual_seed(getenv("SEED", seed))
|
||||
random.seed(getenv("SEED", seed))
|
||||
|
||||
# ========== Model ==========
|
||||
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
|
||||
def whitening(X, kernel_size=hyp['net']['kernel_size']):
|
||||
def whitening(X, kernel_size=hyp["net"]["kernel_size"]):
|
||||
def _cov(X):
|
||||
X = X / np.sqrt(X.shape[0] - 1)
|
||||
return X.T @ X
|
||||
|
@ -130,12 +148,18 @@ def train_cifar():
|
|||
h, w = patch_size
|
||||
c = data.shape[1]
|
||||
axis: SupportsIndex = (2, 3) # type: ignore
|
||||
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
|
||||
return (
|
||||
np.lib.stride_tricks.sliding_window_view(
|
||||
data, window_shape=(h, w), axis=axis
|
||||
)
|
||||
.transpose((0, 3, 2, 1, 4, 5))
|
||||
.reshape((-1, c, h, w))
|
||||
)
|
||||
|
||||
def _eigens(patches):
|
||||
n, c, h, w = patches.shape
|
||||
Σ = _cov(patches.reshape(n, c * h * w))
|
||||
Λ, V = np.linalg.eigh(Σ, UPLO='U')
|
||||
Λ, V = np.linalg.eigh(Σ, UPLO="U")
|
||||
return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0)
|
||||
|
||||
Λ, V = _eigens(_patches(X.numpy()))
|
||||
|
@ -144,12 +168,16 @@ def train_cifar():
|
|||
return Tensor(W.astype(np_dtype), requires_grad=False)
|
||||
|
||||
# ========== Loss ==========
|
||||
def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
|
||||
def cross_entropy(
|
||||
x: Tensor, y: Tensor, reduction: str = "mean", label_smoothing: float = 0.0
|
||||
) -> Tensor:
|
||||
divisor = y.shape[1]
|
||||
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
||||
y = (1 - label_smoothing) * y + label_smoothing / divisor
|
||||
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||
if reduction == "none":
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||
if reduction == "sum":
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
||||
|
||||
# ========== Preprocessing ==========
|
||||
|
@ -161,12 +189,20 @@ def train_cifar():
|
|||
p = padding[3]
|
||||
s = X.shape[3]
|
||||
|
||||
X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
|
||||
X_lr = X[..., :, 1 : 1 + p[0]].flip(3).pad(
|
||||
((0, 0), (0, 0), (0, 0), (0, s + p[0]))
|
||||
) + X[..., :, -1 - p[1] : -1].flip(3).pad(
|
||||
((0, 0), (0, 0), (0, 0), (s + p[1], 0))
|
||||
)
|
||||
X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr
|
||||
|
||||
p = padding[2]
|
||||
s = X.shape[2]
|
||||
X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
|
||||
X_lr = X[..., 1 : 1 + p[0], :].flip(2).pad(
|
||||
((0, 0), (0, 0), (0, s + p[0]), (0, 0))
|
||||
) + X[..., -1 - p[1] : -1, :].flip(2).pad(
|
||||
((0, 0), (0, 0), (s + p[1], 0), (0, 0))
|
||||
)
|
||||
X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr
|
||||
|
||||
return X
|
||||
|
@ -176,10 +212,18 @@ def train_cifar():
|
|||
is_even = int(mask_size % 2 == 0)
|
||||
center_max = shape[-2] - mask_size // 2 - is_even
|
||||
center_min = mask_size // 2 - is_even
|
||||
center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
||||
center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
||||
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
|
||||
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
|
||||
center_x = (
|
||||
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||
).floor()
|
||||
center_y = (
|
||||
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||
).floor()
|
||||
d_x = Tensor.arange(0, shape[-1]).reshape(
|
||||
(1, 1, 1, shape[-1])
|
||||
) - center_x.reshape((-1, 1, 1, 1))
|
||||
d_y = Tensor.arange(0, shape[-2]).reshape(
|
||||
(1, 1, shape[-2], 1)
|
||||
) - center_y.reshape((-1, 1, 1, 1))
|
||||
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
|
||||
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
|
||||
mask = d_y * d_x
|
||||
|
@ -200,7 +244,7 @@ def train_cifar():
|
|||
Y_patch = Tensor(Y.numpy()[order])
|
||||
X_cutmix = Tensor.where(mask, X_patch, X)
|
||||
mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1])
|
||||
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
|
||||
Y_cutmix = mix_portion * Y_patch + (1.0 - mix_portion) * Y
|
||||
return X_cutmix, Y_cutmix
|
||||
|
||||
# the operations that remain inside batch fetcher is the ones that involves random operations
|
||||
|
@ -213,11 +257,16 @@ def train_cifar():
|
|||
random.shuffle(order)
|
||||
if is_train:
|
||||
X = random_crop(X, crop_size=32)
|
||||
X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
|
||||
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
|
||||
X = Tensor.where(
|
||||
Tensor.rand(X.shape[0], 1, 1, 1) < 0.5, X[..., ::-1], X
|
||||
) # flip LR
|
||||
if step >= hyp["net"]["cutmix_steps"]:
|
||||
X, Y = cutmix(X, Y, mask_size=hyp["net"]["cutmix_size"])
|
||||
X, Y = X.numpy(), Y.numpy()
|
||||
et = time.monotonic()
|
||||
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
|
||||
print(
|
||||
f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})"
|
||||
)
|
||||
for i in range(0, X.shape[0], BS):
|
||||
# pad the last batch
|
||||
batch_end = min(i + BS, Y.shape[0])
|
||||
|
@ -226,18 +275,24 @@ def train_cifar():
|
|||
step += 1
|
||||
yield x, y
|
||||
cnt += 1
|
||||
if not is_train: break
|
||||
if not is_train:
|
||||
break
|
||||
|
||||
transform = [
|
||||
lambda x: x / 255.0,
|
||||
lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1))
|
||||
lambda x: (
|
||||
x.reshape((-1, 3, 32, 32)) - Tensor(cifar_mean).reshape((1, 3, 1, 1))
|
||||
)
|
||||
/ Tensor(cifar_std).reshape((1, 3, 1, 1)),
|
||||
]
|
||||
|
||||
class modelEMA():
|
||||
class modelEMA:
|
||||
def __init__(self, w, net):
|
||||
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
||||
self.net_ema = SpeedyResNet(w)
|
||||
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
|
||||
for net_ema_param, net_param in zip(
|
||||
get_state_dict(self.net_ema).values(), get_state_dict(net).values()
|
||||
):
|
||||
net_ema_param.requires_grad = False
|
||||
net_ema_param.assign(net_param.numpy())
|
||||
|
||||
|
@ -245,23 +300,37 @@ def train_cifar():
|
|||
def update(self, net, decay):
|
||||
# TODO with Tensor.no_grad()
|
||||
Tensor.no_grad = True
|
||||
for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
|
||||
for net_ema_param, (param_name, net_param) in zip(
|
||||
get_state_dict(self.net_ema).values(), get_state_dict(net).items()
|
||||
):
|
||||
# batchnorm currently is not being tracked
|
||||
if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
|
||||
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
|
||||
if not ("num_batches_tracked" in param_name) and not (
|
||||
"running" in param_name
|
||||
):
|
||||
net_ema_param.assign(
|
||||
net_ema_param.detach() * decay
|
||||
+ net_param.detach() * (1.0 - decay)
|
||||
).realize()
|
||||
Tensor.no_grad = False
|
||||
|
||||
set_seed(hyp['seed'])
|
||||
set_seed(hyp["seed"])
|
||||
|
||||
# this import needs to be done here because this is running in a subprocess
|
||||
from extra.dist import OOB
|
||||
|
||||
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
|
||||
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
||||
# load data and label into GPU and convert to dtype accordingly
|
||||
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
|
||||
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
|
||||
X_train, X_test = (
|
||||
X_train.to(device=Device.DEFAULT).float(),
|
||||
X_test.to(device=Device.DEFAULT).float(),
|
||||
)
|
||||
Y_train, Y_test = (
|
||||
Y_train.to(device=Device.DEFAULT).float(),
|
||||
Y_test.to(device=Device.DEFAULT).float(),
|
||||
)
|
||||
# one-hot encode labels
|
||||
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
||||
# preprocess data
|
||||
|
@ -274,10 +343,15 @@ def train_cifar():
|
|||
model = SpeedyResNet(W)
|
||||
|
||||
# padding is not timed in the original repo since it can be done all at once
|
||||
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
|
||||
X_train = pad_reflect(X_train, size=hyp["net"]["pad_amount"])
|
||||
|
||||
# Convert data and labels to the default dtype
|
||||
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
|
||||
X_train, Y_train, X_test, Y_test = (
|
||||
X_train.cast(Tensor.default_type),
|
||||
Y_train.cast(Tensor.default_type),
|
||||
X_test.cast(Tensor.default_type),
|
||||
Y_test.cast(Tensor.default_type),
|
||||
)
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
|
@ -285,26 +359,60 @@ def train_cifar():
|
|||
params_non_bias = []
|
||||
for params in params_dict:
|
||||
if params_dict[params].requires_grad is not False:
|
||||
if 'bias' in params:
|
||||
if "bias" in params:
|
||||
params_bias.append(params_dict[params])
|
||||
else:
|
||||
params_non_bias.append(params_dict[params])
|
||||
|
||||
opt_bias = optim.SGD(params_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
|
||||
opt_bias = optim.SGD(
|
||||
params_bias,
|
||||
lr=0.01,
|
||||
momentum=hyp["opt"]["momentum"],
|
||||
nesterov=True,
|
||||
weight_decay=hyp["opt"]["bias_decay"],
|
||||
)
|
||||
opt_non_bias = optim.SGD(
|
||||
params_non_bias,
|
||||
lr=0.01,
|
||||
momentum=hyp["opt"]["momentum"],
|
||||
nesterov=True,
|
||||
weight_decay=hyp["opt"]["non_bias_decay"],
|
||||
)
|
||||
|
||||
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
|
||||
initial_div_factor = hyp['opt']['initial_div_factor']
|
||||
final_lr_ratio = hyp['opt']['final_lr_ratio']
|
||||
pct_start = hyp['opt']['percent_start']
|
||||
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
||||
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
||||
initial_div_factor = hyp["opt"]["initial_div_factor"]
|
||||
final_lr_ratio = hyp["opt"]["final_lr_ratio"]
|
||||
pct_start = hyp["opt"]["percent_start"]
|
||||
lr_sched_bias = OneCycleLR(
|
||||
opt_bias,
|
||||
max_lr=hyp["opt"]["bias_lr"],
|
||||
pct_start=pct_start,
|
||||
div_factor=initial_div_factor,
|
||||
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||
total_steps=STEPS,
|
||||
)
|
||||
lr_sched_non_bias = OneCycleLR(
|
||||
opt_non_bias,
|
||||
max_lr=hyp["opt"]["non_bias_lr"],
|
||||
pct_start=pct_start,
|
||||
div_factor=initial_div_factor,
|
||||
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||
total_steps=STEPS,
|
||||
)
|
||||
|
||||
loss_batchsize_scaler = 512 / BS
|
||||
|
||||
@TinyJit
|
||||
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
|
||||
out = model(X)
|
||||
loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||
loss = (
|
||||
cross_entropy(
|
||||
out, Y, reduction="none", label_smoothing=hyp["opt"]["label_smoothing"]
|
||||
)
|
||||
.mul(hyp["opt"]["loss_scale_scaler"] * loss_batchsize_scaler)
|
||||
.sum()
|
||||
.div(hyp["opt"]["loss_scale_scaler"])
|
||||
)
|
||||
|
||||
if not getenv("DISABLE_BACKWARD"):
|
||||
# index 0 for bias and 1 for non-bias
|
||||
|
@ -316,11 +424,16 @@ def train_cifar():
|
|||
# sync gradients across ranks
|
||||
bucket, offset = [], 0
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None: bucket.append(v.grad.flatten())
|
||||
if v.grad is not None:
|
||||
bucket.append(v.grad.flatten())
|
||||
grads = collectives.allreduce(Tensor.cat(*bucket))
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None:
|
||||
v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
|
||||
v.grad.assign(
|
||||
grads[offset : offset + v.grad.numel()].reshape(
|
||||
*v.grad.shape
|
||||
)
|
||||
)
|
||||
offset += v.grad.numel()
|
||||
|
||||
optimizer[0].step()
|
||||
|
@ -331,9 +444,10 @@ def train_cifar():
|
|||
|
||||
def eval_step(model, X, Y):
|
||||
out = model(X, training=False)
|
||||
loss = cross_entropy(out, Y, reduction='mean')
|
||||
loss = cross_entropy(out, Y, reduction="mean")
|
||||
correct = out.argmax(axis=1) == Y.argmax(axis=1)
|
||||
return correct.realize(), loss.realize()
|
||||
|
||||
eval_step_jitted = TinyJit(eval_step)
|
||||
eval_step_ema_jitted = TinyJit(eval_step)
|
||||
|
||||
|
@ -347,7 +461,7 @@ def train_cifar():
|
|||
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
||||
|
||||
model_ema: Optional[modelEMA] = None
|
||||
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
|
||||
projected_ema_decay_val = hyp["ema"]["decay_base"] ** hyp["ema"]["every_n_steps"]
|
||||
i = 0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
|
@ -363,24 +477,37 @@ def train_cifar():
|
|||
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
||||
# further split batch if distributed
|
||||
if getenv("DIST"):
|
||||
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
|
||||
Xt, Yt = (
|
||||
Xt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||
Yt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||
)
|
||||
|
||||
correct, loss = eval_step_jitted(model, Xt, Yt)
|
||||
losses.append(loss.numpy().tolist())
|
||||
corrects.extend(correct.numpy().tolist())
|
||||
if model_ema:
|
||||
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
|
||||
correct_ema, loss_ema = eval_step_ema_jitted(
|
||||
model_ema.net_ema, Xt, Yt
|
||||
)
|
||||
losses_ema.append(loss_ema.numpy().tolist())
|
||||
corrects_ema.extend(correct_ema.numpy().tolist())
|
||||
|
||||
# collect accuracy across ranks
|
||||
correct_sum, correct_len = sum(corrects), len(corrects)
|
||||
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
|
||||
if model_ema:
|
||||
correct_sum_ema, correct_len_ema = sum(corrects_ema), len(
|
||||
corrects_ema
|
||||
)
|
||||
if getenv("DIST"):
|
||||
if rank == 0:
|
||||
for j in range(1, min(world_size, 5)):
|
||||
if model_ema:
|
||||
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
|
||||
(
|
||||
recv_sum,
|
||||
recv_len,
|
||||
recv_sum_ema,
|
||||
recv_len_ema,
|
||||
) = OOB.recv(j)
|
||||
else:
|
||||
recv_sum, recv_len = OOB.recv(j)
|
||||
correct_sum += recv_sum
|
||||
|
@ -390,55 +517,95 @@ def train_cifar():
|
|||
correct_len_ema += recv_len_ema
|
||||
elif rank < min(world_size, 5):
|
||||
if model_ema:
|
||||
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
|
||||
OOB.send(
|
||||
(
|
||||
correct_sum,
|
||||
correct_len,
|
||||
correct_sum_ema,
|
||||
correct_len_ema,
|
||||
),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
OOB.send((correct_sum, correct_len), 0)
|
||||
|
||||
# only rank 0 prints
|
||||
if rank == 0:
|
||||
acc = correct_sum / correct_len * 100.0
|
||||
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
|
||||
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
|
||||
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
|
||||
if model_ema:
|
||||
acc_ema = correct_sum_ema / correct_len_ema * 100.0
|
||||
print(
|
||||
f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)"
|
||||
)
|
||||
if model_ema:
|
||||
print(
|
||||
f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}"
|
||||
)
|
||||
|
||||
if STEPS == 0 or i==STEPS: break
|
||||
if STEPS == 0 or i == STEPS:
|
||||
break
|
||||
X, Y = next(batcher)
|
||||
if getenv("DIST"):
|
||||
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
|
||||
GlobalCounters.reset()
|
||||
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
|
||||
loss = train_step_jitted(
|
||||
model,
|
||||
[opt_bias, opt_non_bias],
|
||||
[lr_sched_bias, lr_sched_non_bias],
|
||||
X,
|
||||
Y,
|
||||
)
|
||||
et = time.monotonic()
|
||||
loss_cpu = loss.numpy()
|
||||
# EMA for network weights
|
||||
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
|
||||
if i > hyp["ema"]["steps"] and (i + 1) % hyp["ema"]["every_n_steps"] == 0:
|
||||
if model_ema is None:
|
||||
model_ema = modelEMA(W, model)
|
||||
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
|
||||
model_ema.update(
|
||||
model,
|
||||
Tensor(
|
||||
[
|
||||
projected_ema_decay_val
|
||||
* (i / STEPS) ** hyp["ema"]["decay_pow"]
|
||||
]
|
||||
),
|
||||
)
|
||||
cl = time.monotonic()
|
||||
if not getenv("DIST"):
|
||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
print(
|
||||
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
else:
|
||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
print(
|
||||
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
st = cl
|
||||
i += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not getenv("DIST"):
|
||||
train_cifar()
|
||||
else: # distributed
|
||||
if getenv("HIP"):
|
||||
from tinygrad.runtime.ops_hip import HIP
|
||||
|
||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||
else:
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
|
||||
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||
world_size = len(devices)
|
||||
|
||||
# ensure that the batch size is divisible by the number of devices
|
||||
assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
|
||||
assert (
|
||||
BS % world_size == 0
|
||||
), f"batch size {BS} is not divisible by world size {world_size}"
|
||||
|
||||
# ensure that the evaluation batch size is divisible by the number of devices
|
||||
assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
||||
assert (
|
||||
EVAL_BS % min(world_size, 5) == 0
|
||||
), f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
||||
|
||||
# init out-of-band communication
|
||||
dist.init_oob(world_size)
|
||||
|
@ -447,4 +614,5 @@ if __name__ == "__main__":
|
|||
processes = []
|
||||
for rank, device in enumerate(devices):
|
||||
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
|
||||
for p in processes: p.join()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
from pathlib import Path
|
||||
import sys, argparse, json
|
||||
import numpy as np
|
||||
|
||||
np.set_printoptions(linewidth=200)
|
||||
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
|
||||
from tinygrad import Device
|
||||
|
@ -24,84 +25,225 @@ MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
|
|||
MODEL_PARAMS = {
|
||||
"1": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_heads": 32,
|
||||
"n_layers": 32,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_heads": 40,
|
||||
"n_layers": 40,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"30B": {
|
||||
"args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
|
||||
"args": {
|
||||
"dim": 6656,
|
||||
"n_heads": 52,
|
||||
"n_layers": 60,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 17920,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"65B": {
|
||||
"args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_heads": 64,
|
||||
"n_layers": 80,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 8,
|
||||
},
|
||||
},
|
||||
"2": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_heads": 32,
|
||||
"n_layers": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_heads": 40,
|
||||
"n_layers": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"70B": {
|
||||
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"n_layers": 80,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 28672,
|
||||
},
|
||||
"files": 8,
|
||||
},
|
||||
},
|
||||
"code": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Python": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Instruct": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Python": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Instruct": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"34B": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Python": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Instruct": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
},
|
||||
"tiny": {
|
||||
"1B": {
|
||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
|
||||
"args": {
|
||||
"dim": 2048,
|
||||
"n_layers": 22,
|
||||
"n_heads": 32,
|
||||
"n_kv_heads": 4,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 5632,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"1B-Chat": {
|
||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
|
||||
"args": {
|
||||
"dim": 2048,
|
||||
"n_layers": 22,
|
||||
"n_heads": 32,
|
||||
"n_kv_heads": 4,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32003,
|
||||
"hidden_dim": 5632,
|
||||
},
|
||||
"files": 1,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
@ -111,21 +253,37 @@ def concat_weights(models):
|
|||
disk_tensors = [model[name] for model in models]
|
||||
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
||||
return disk_tensors[0].to(device=Device.DEFAULT)
|
||||
axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
|
||||
axis = (
|
||||
1
|
||||
if name.startswith("tok_embeddings.")
|
||||
or name.endswith(".attention.wo.weight")
|
||||
or name.endswith(".feed_forward.w2.weight")
|
||||
else 0
|
||||
)
|
||||
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
|
||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||
return {name: convert(name) for name in {name: None for model in models for name in model}}
|
||||
|
||||
return {
|
||||
name: convert(name)
|
||||
for name in {name: None for model in models for name in model}
|
||||
}
|
||||
|
||||
|
||||
def load(fn: str):
|
||||
if fn.endswith('.index.json'):
|
||||
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
|
||||
parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
|
||||
if fn.endswith(".index.json"):
|
||||
with open(fn) as fp:
|
||||
weight_map = json.load(fp)["weight_map"]
|
||||
parts = {
|
||||
n: load(str(Path(fn).parent / Path(n).name))
|
||||
for n in set(weight_map.values())
|
||||
}
|
||||
return {k: parts[n][k] for k, n in weight_map.items()}
|
||||
elif fn.endswith(".safetensors"):
|
||||
return safe_load(fn)
|
||||
else:
|
||||
return torch_load(fn)
|
||||
|
||||
|
||||
class AbsmaxQuantizedLinear:
|
||||
def __init__(self, in_features, out_features, bias=False):
|
||||
assert bias == False
|
||||
|
@ -139,34 +297,63 @@ class AbsmaxQuantizedLinear:
|
|||
def quantize(tensors):
|
||||
new_tensors = {}
|
||||
for name, v in tensors.items():
|
||||
if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
|
||||
if (
|
||||
"feed_forward" in name
|
||||
or ("attention.w") in name
|
||||
or name == "output.weight"
|
||||
):
|
||||
scale = v.abs().max(axis=1) / 127.0
|
||||
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)
|
||||
new_tensors[name] = int8_weight
|
||||
new_tensors[name.replace('weight', 'scale')] = scale
|
||||
new_tensors[name.replace("weight", "scale")] = scale
|
||||
else:
|
||||
new_tensors[name] = v
|
||||
return new_tensors
|
||||
|
||||
|
||||
class LLaMa:
|
||||
@staticmethod
|
||||
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
|
||||
def build(
|
||||
model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False
|
||||
):
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||
assert (
|
||||
sp_model.vocab_size() == params["args"]["vocab_size"]
|
||||
), f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||
model = (
|
||||
Transformer(
|
||||
**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT
|
||||
)
|
||||
if quantize
|
||||
else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||
)
|
||||
|
||||
if model_path.is_dir():
|
||||
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
weights = concat_weights(
|
||||
[
|
||||
load(filename)
|
||||
for filename in [
|
||||
f"{model_path}/consolidated.{i:02d}.pth"
|
||||
for i in range(params["files"])
|
||||
]
|
||||
]
|
||||
)
|
||||
else:
|
||||
weights = load(str(model_path))
|
||||
if "model.embed_tokens.weight" in weights:
|
||||
weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
|
||||
weights = convert_from_huggingface(
|
||||
weights,
|
||||
model,
|
||||
params["args"]["n_heads"],
|
||||
params["args"].get("n_kv_heads", params["args"]["n_heads"]),
|
||||
)
|
||||
|
||||
if quantize:
|
||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||
for _,v in weights.items(): v.realize()
|
||||
for _, v in weights.items():
|
||||
v.realize()
|
||||
load_state_dict(model, weights, strict=False)
|
||||
|
||||
return LLaMa(model, sp_model)
|
||||
|
@ -179,18 +366,23 @@ class LLaMa:
|
|||
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
|
||||
start_pos = 0
|
||||
for i in range(max_length):
|
||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
|
||||
probs = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||
).realize()
|
||||
probs_np = probs.numpy()
|
||||
tok = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
|
||||
if tok == self.tokenizer.eos_id(): break
|
||||
if tok == self.tokenizer.eos_id():
|
||||
break
|
||||
output = self.tokenizer.decode(toks)
|
||||
for s in until:
|
||||
if output.endswith(s): return output[0:-len(s)]
|
||||
if output.endswith(s):
|
||||
return output[0 : -len(s)]
|
||||
return output
|
||||
|
||||
|
||||
# **** main code ****
|
||||
"""
|
||||
test:
|
||||
|
@ -256,21 +448,58 @@ if __name__ == "__main__":
|
|||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
|
||||
parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
|
||||
parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
|
||||
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run LLaMA in tinygrad",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Phrase to start with. Without this, it goes into chatbot mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=1000, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--personality",
|
||||
type=str,
|
||||
default="Stacy",
|
||||
help="Personality, can be Stacy, George, Gary, or Lexie",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.7, help="Temperature in the softmax"
|
||||
)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
|
||||
parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""")
|
||||
parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""")
|
||||
parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
||||
parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
parser.add_argument(
|
||||
"--profile", action="store_true", help="Output profile data to out.prof"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen",
|
||||
default="1",
|
||||
help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize", action="store_true", help="Quantize the weights to int8 in memory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation")
|
||||
if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
||||
if args.gen not in MODEL_PARAMS:
|
||||
raise ValueError("Invalid model generation")
|
||||
if args.size is None:
|
||||
args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
||||
chatbot = args.prompt == None
|
||||
|
||||
# *** prompt engineers work here ****
|
||||
|
@ -294,9 +523,13 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
user_delim = "\nUser: "
|
||||
resp_delim = "Stacy: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "george":
|
||||
print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.")
|
||||
print(
|
||||
"WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter."
|
||||
)
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
|
||||
You are an AI version of George Hotz. You act as much as you can like George.
|
||||
You are one of the greatest computer experts in the world.
|
||||
|
@ -312,13 +545,15 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
|
||||
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
|
||||
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
|
||||
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk"
|
||||
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk",
|
||||
}
|
||||
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "George: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "gary":
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
|
||||
You are Gary!
|
||||
|
@ -331,13 +566,15 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
"""
|
||||
examples = {
|
||||
"What is your name?": "I am Gary. I used to sell cars.",
|
||||
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla"
|
||||
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla",
|
||||
}
|
||||
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "Gary: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "lexie":
|
||||
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
|
||||
You are Lexie!
|
||||
|
@ -352,21 +589,34 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
examples = {
|
||||
"hi lexie": "hi chad, glad we finally met up!",
|
||||
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
|
||||
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress"
|
||||
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress",
|
||||
}
|
||||
|
||||
user_delim = "\nChad: "
|
||||
resp_delim = "Lexie: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
|
||||
# *** prompt engineers stop here ****
|
||||
|
||||
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
|
||||
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
|
||||
MODEL_PATH = (
|
||||
args.model
|
||||
or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||
)
|
||||
TOKENIZER_PATH = (
|
||||
MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent
|
||||
) / "tokenizer.model"
|
||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
||||
llama = LLaMa.build(
|
||||
MODEL_PATH,
|
||||
TOKENIZER_PATH,
|
||||
model_gen=args.gen,
|
||||
model_size=args.size,
|
||||
quantize=args.quantize,
|
||||
)
|
||||
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
||||
|
||||
if chatbot:
|
||||
|
@ -375,7 +625,9 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
|
||||
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
||||
with Timing():
|
||||
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used
|
||||
llama.model(
|
||||
Tensor([toks]), 0, args.temperature
|
||||
).realize() # NOTE: outputs are not used
|
||||
start_pos = len(toks)
|
||||
else:
|
||||
# non chat bot mode
|
||||
|
@ -403,14 +655,37 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
for i in range(args.count):
|
||||
GlobalCounters.reset()
|
||||
|
||||
if args.timing or args.profile: print("")
|
||||
if args.timing or args.profile:
|
||||
print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Profiling(enabled=args.profile):
|
||||
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
|
||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
|
||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
|
||||
with Timing(
|
||||
"total ",
|
||||
enabled=args.timing,
|
||||
on_exit=lambda x: f", {1e9/x:.2f} tok/sec",
|
||||
):
|
||||
with Timing(
|
||||
"ran model in ",
|
||||
on_exit=(
|
||||
lambda et: (
|
||||
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||
+ (
|
||||
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if DEBUG
|
||||
else None,
|
||||
enabled=args.timing,
|
||||
):
|
||||
probs = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, args.temperature
|
||||
).realize()
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().item()
|
||||
|
||||
|
@ -427,5 +702,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
outputted = cur
|
||||
|
||||
# stop after you have your answer
|
||||
if chatbot and outputted.endswith(end_delim): break
|
||||
if not chatbot: break
|
||||
if chatbot and outputted.endswith(end_delim):
|
||||
break
|
||||
if not chatbot:
|
||||
break
|
||||
|
|
|
@ -63,21 +63,23 @@ class Normalize:
|
|||
image = Ft.normalize(image, mean=self.mean, std=self.std)
|
||||
return image
|
||||
|
||||
|
||||
transforms = lambda size_scale: T.Compose(
|
||||
[
|
||||
Resize(int(800 * size_scale), int(1333 * size_scale)),
|
||||
T.ToTensor(),
|
||||
Normalize(
|
||||
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
|
||||
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_bgr255=True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def expand_boxes(boxes, scale):
|
||||
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
|
||||
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
|
||||
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
|
||||
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
|
||||
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
|
||||
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
|
||||
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
|
||||
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
|
||||
|
||||
w_half *= scale
|
||||
h_half *= scale
|
||||
|
@ -118,7 +120,7 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
|
|||
mask = mask.expand((1, 1, -1, -1))
|
||||
|
||||
mask = mask.to(torch.float32)
|
||||
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
|
||||
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
|
||||
mask = mask[0][0]
|
||||
|
||||
if thresh >= 0:
|
||||
|
@ -169,11 +171,13 @@ class Masker:
|
|||
|
||||
masker = Masker(threshold=0.5, padding=1)
|
||||
|
||||
|
||||
def select_top_predictions(predictions, confidence_threshold=0.9):
|
||||
scores = predictions.get_field("scores").numpy()
|
||||
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
|
||||
return predictions[keep]
|
||||
|
||||
|
||||
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
|
||||
image = transforms(size_scale)(original_image).numpy()
|
||||
image = Tensor(image, requires_grad=False)
|
||||
|
@ -189,6 +193,7 @@ def compute_prediction(original_image, model, confidence_threshold, size_scale=1
|
|||
prediction.add_field("mask", masks)
|
||||
return prediction
|
||||
|
||||
|
||||
def compute_prediction_batched(batch, model, size_scale=1.0):
|
||||
imgs = []
|
||||
for img in batch:
|
||||
|
@ -198,21 +203,25 @@ def compute_prediction_batched(batch, model, size_scale=1.0):
|
|||
del image
|
||||
return predictions
|
||||
|
||||
|
||||
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
|
||||
|
||||
|
||||
def findContours(*args, **kwargs):
|
||||
if cv2.__version__.startswith('4'):
|
||||
if cv2.__version__.startswith("4"):
|
||||
contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
elif cv2.__version__.startswith('3'):
|
||||
elif cv2.__version__.startswith("3"):
|
||||
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
return contours, hierarchy
|
||||
|
||||
|
||||
def compute_colors_for_labels(labels):
|
||||
l = labels[:, None]
|
||||
colors = l * palette
|
||||
colors = (colors % 255).astype("uint8")
|
||||
return colors
|
||||
|
||||
|
||||
def overlay_mask(image, predictions):
|
||||
image = np.asarray(image)
|
||||
masks = predictions.get_field("mask").numpy()
|
||||
|
@ -231,17 +240,92 @@ def overlay_mask(image, predictions):
|
|||
|
||||
return composite
|
||||
|
||||
|
||||
CATEGORIES = [
|
||||
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
||||
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
|
||||
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
|
||||
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
|
||||
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
|
||||
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
|
||||
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
|
||||
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
|
||||
"__background",
|
||||
"person",
|
||||
"bicycle",
|
||||
"car",
|
||||
"motorcycle",
|
||||
"airplane",
|
||||
"bus",
|
||||
"train",
|
||||
"truck",
|
||||
"boat",
|
||||
"traffic light",
|
||||
"fire hydrant",
|
||||
"stop sign",
|
||||
"parking meter",
|
||||
"bench",
|
||||
"bird",
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"sheep",
|
||||
"cow",
|
||||
"elephant",
|
||||
"bear",
|
||||
"zebra",
|
||||
"giraffe",
|
||||
"backpack",
|
||||
"umbrella",
|
||||
"handbag",
|
||||
"tie",
|
||||
"suitcase",
|
||||
"frisbee",
|
||||
"skis",
|
||||
"snowboard",
|
||||
"sports ball",
|
||||
"kite",
|
||||
"baseball bat",
|
||||
"baseball glove",
|
||||
"skateboard",
|
||||
"surfboard",
|
||||
"tennis racket",
|
||||
"bottle",
|
||||
"wine glass",
|
||||
"cup",
|
||||
"fork",
|
||||
"knife",
|
||||
"spoon",
|
||||
"bowl",
|
||||
"banana",
|
||||
"apple",
|
||||
"sandwich",
|
||||
"orange",
|
||||
"broccoli",
|
||||
"carrot",
|
||||
"hot dog",
|
||||
"pizza",
|
||||
"donut",
|
||||
"cake",
|
||||
"chair",
|
||||
"couch",
|
||||
"potted plant",
|
||||
"bed",
|
||||
"dining table",
|
||||
"toilet",
|
||||
"tv",
|
||||
"laptop",
|
||||
"mouse",
|
||||
"remote",
|
||||
"keyboard",
|
||||
"cell phone",
|
||||
"microwave",
|
||||
"oven",
|
||||
"toaster",
|
||||
"sink",
|
||||
"refrigerator",
|
||||
"book",
|
||||
"clock",
|
||||
"vase",
|
||||
"scissors",
|
||||
"teddy bear",
|
||||
"hair drier",
|
||||
"toothbrush",
|
||||
]
|
||||
|
||||
|
||||
def overlay_boxes(image, predictions):
|
||||
labels = predictions.get_field("labels").numpy()
|
||||
boxes = predictions.bbox
|
||||
|
@ -258,6 +342,7 @@ def overlay_boxes(image, predictions):
|
|||
|
||||
return image
|
||||
|
||||
|
||||
def overlay_class_names(image, predictions):
|
||||
scores = predictions.get_field("scores").numpy().tolist()
|
||||
labels = predictions.get_field("labels").numpy().tolist()
|
||||
|
@ -269,26 +354,35 @@ def overlay_class_names(image, predictions):
|
|||
x, y = box[:2]
|
||||
s = template.format(label, score)
|
||||
x, y = int(x), int(y)
|
||||
cv2.putText(
|
||||
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
|
||||
)
|
||||
cv2.putText(image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--image', type=str, help="Path of the image to run")
|
||||
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
|
||||
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
|
||||
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run MaskRCNN",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--image", type=str, help="Path of the image to run")
|
||||
parser.add_argument(
|
||||
"--threshold", type=float, default=0.7, help="Detector threshold"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size_scale", type=float, default=1.0, help="Image resize multiplier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out", type=str, default="/tmp/rendered.png", help="Output filename"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||
model_tiny = MaskRCNN(resnet)
|
||||
model_tiny.load_from_pretrained()
|
||||
img = Image.open(args.image)
|
||||
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
|
||||
top_result_tiny = compute_prediction(
|
||||
img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale
|
||||
)
|
||||
bbox_image = overlay_boxes(img, top_result_tiny)
|
||||
mask_image = overlay_mask(bbox_image, top_result_tiny)
|
||||
final_image = overlay_class_names(mask_image, top_result_tiny)
|
||||
|
|
|
@ -3,6 +3,7 @@ import unicodedata
|
|||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
|
||||
def gaussian_kernel(n, std):
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
|
@ -12,14 +13,18 @@ def gaussian_kernel(n, std):
|
|||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
|
||||
|
||||
def prepare_arrays(image, roi_shape=(128, 128, 128)):
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
image_shape = list(image.shape[2:])
|
||||
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
|
||||
norm_map = np.zeros_like(result)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(
|
||||
norm_map.dtype
|
||||
)
|
||||
return result, norm_map, norm_patch
|
||||
|
||||
|
||||
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
assert 0 < overlap_factor < 1
|
||||
|
@ -31,25 +36,35 @@ def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
|||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
yield i, j, k
|
||||
|
||||
|
||||
def _get_best_indices(logits, n_best_size):
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
|
||||
if (
|
||||
(cp := ord(char)) in range(33, 48)
|
||||
or cp in range(58, 65)
|
||||
or cp in range(91, 97)
|
||||
or cp in range(123, 127)
|
||||
):
|
||||
return True
|
||||
return unicodedata.category(char).startswith("P")
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
return unicodedata.category(char) == "Zs"
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
return unicodedata.category(char).startswith("C")
|
||||
|
||||
|
||||
def _run_split_on_punc(text):
|
||||
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
return [text]
|
||||
|
@ -66,6 +81,7 @@ def _run_split_on_punc(text):
|
|||
output[-1].append(char)
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
|
||||
def _run_strip_accents(text):
|
||||
output = []
|
||||
for char in unicodedata.normalize("NFD", text):
|
||||
|
@ -73,13 +89,15 @@ def _run_strip_accents(text):
|
|||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def _clean_text(text):
|
||||
output = []
|
||||
for char in text:
|
||||
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
|
||||
if not ((cp := ord(char)) == 0 or cp == 0xFFFD or _is_control(char)):
|
||||
output.append(" " if _is_whitespace(char) else char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def _get_final_text(pred_text, orig_text):
|
||||
def _strip_spaces(text):
|
||||
ns_text = ""
|
||||
|
@ -128,32 +146,46 @@ def _get_final_text(pred_text, orig_text):
|
|||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
|
||||
def get_bert_qa_prediction(features, example, start_end_logits):
|
||||
prelim_predictions = []
|
||||
for i, feature in enumerate(features):
|
||||
for start_index in _get_best_indices(start_end_logits[i][0], 20):
|
||||
for end_index in _get_best_indices(start_end_logits[i][1], 20):
|
||||
if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
|
||||
if start_index >= len(feature["tokens"]) or end_index >= len(
|
||||
feature["tokens"]
|
||||
):
|
||||
continue
|
||||
if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
|
||||
if (
|
||||
start_index not in feature["token_to_orig_map"]
|
||||
or end_index not in feature["token_to_orig_map"]
|
||||
):
|
||||
continue
|
||||
if not feature["token_is_max_context"].get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index or end_index - start_index + 1 > 30:
|
||||
continue
|
||||
|
||||
prelim_predictions.append({
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"feature_index": i,
|
||||
"start_index": start_index,
|
||||
"end_index": end_index,
|
||||
"start_logit": start_end_logits[i][0, start_index],
|
||||
"end_logit": start_end_logits[i][1, end_index]
|
||||
})
|
||||
predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
|
||||
"end_logit": start_end_logits[i][1, end_index],
|
||||
}
|
||||
)
|
||||
predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x["start_logit"] + x["end_logit"]),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
if len(predictions) > 0:
|
||||
feature = features[predictions[0]["feature_index"]]
|
||||
tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
|
||||
tok_tokens = feature["tokens"][
|
||||
predictions[0]["start_index"] : (predictions[0]["end_index"] + 1)
|
||||
]
|
||||
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
|
||||
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
|
||||
orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)]
|
||||
|
|
|
@ -3,6 +3,7 @@ import string
|
|||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
|
||||
def levenshtein(a, b):
|
||||
n, m = len(a), len(b)
|
||||
if n > m:
|
||||
|
@ -20,6 +21,7 @@ def levenshtein(a, b):
|
|||
|
||||
return current[n]
|
||||
|
||||
|
||||
def word_error_rate(x, y):
|
||||
scores = words = 0
|
||||
for h, r in zip(x, y):
|
||||
|
@ -29,12 +31,14 @@ def word_error_rate(x, y):
|
|||
scores += levenshtein(h_list, r_list)
|
||||
return float(scores) / words, float(scores), words
|
||||
|
||||
|
||||
def one_hot(arr, num_classes=3):
|
||||
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
|
||||
arr = res.reshape(list(arr.shape) + [num_classes])
|
||||
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
|
||||
return arr
|
||||
|
||||
|
||||
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
|
||||
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
||||
prediction = prediction.argmax(axis=channel_axis)
|
||||
|
@ -42,14 +46,18 @@ def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr
|
|||
intersection = np.sum(prediction * target, axis=reduce_axis)
|
||||
target_sum = np.sum(target, axis=reduce_axis)
|
||||
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
||||
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
|
||||
result = (2.0 * intersection + smooth_nr) / (
|
||||
target_sum + prediction_sum + smooth_dr
|
||||
)
|
||||
return result[0]
|
||||
|
||||
|
||||
def normalize_string(s):
|
||||
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
||||
s = re.sub(r'\b(a|an|the)\b', ' ', s)
|
||||
s = re.sub(r"\b(a|an|the)\b", " ", s)
|
||||
return " ".join(s.split())
|
||||
|
||||
|
||||
def f1_score(x, y):
|
||||
xt = normalize_string(x).split()
|
||||
yt = normalize_string(y).split()
|
||||
|
|
|
@ -6,15 +6,18 @@ from tinygrad.jit import TinyJit
|
|||
from tinygrad.helpers import getenv, dtypes, GlobalCounters
|
||||
from examples.mlperf import helpers
|
||||
|
||||
|
||||
def eval_resnet():
|
||||
# Resnet50-v1.5
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.models.resnet import ResNet50
|
||||
|
||||
mdl = ResNet50()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
|
||||
def input_fixup(x):
|
||||
x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0
|
||||
x -= input_mean
|
||||
|
@ -48,14 +51,18 @@ def eval_resnet():
|
|||
et = time.perf_counter()
|
||||
n += (t == y).sum()
|
||||
d += len(t)
|
||||
print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS")
|
||||
print(
|
||||
f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS"
|
||||
)
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_unet3d():
|
||||
# UNet3D
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||
from examples.mlperf.metrics import get_dice_score
|
||||
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
s = 0
|
||||
|
@ -69,15 +76,18 @@ def eval_unet3d():
|
|||
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_retinanet():
|
||||
# RetinaNet with ResNeXt50_32X4D
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
from extra.models.retinanet import RetinaNet
|
||||
|
||||
mdl = RetinaNet(ResNeXt50_32X4D())
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
|
||||
def input_fixup(x):
|
||||
x = x.permute([0, 3, 1, 2]) / 255.0
|
||||
x -= input_mean
|
||||
|
@ -88,11 +98,18 @@ def eval_retinanet():
|
|||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
coco = COCO(openimages())
|
||||
coco_eval = COCOeval(coco, iouType="bbox")
|
||||
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
|
||||
coco_evalimgs, evaluated_imgs, ncats, narea = (
|
||||
[],
|
||||
[],
|
||||
len(coco_eval.params.catIds),
|
||||
len(coco_eval.params.areaRng),
|
||||
)
|
||||
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
||||
|
||||
n, bs = 0, 8
|
||||
|
@ -106,19 +123,35 @@ def eval_retinanet():
|
|||
mdlrun.jit_cache = None
|
||||
outs = mdl(input_fixup(dat)).numpy()
|
||||
et = time.perf_counter()
|
||||
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
|
||||
predictions = mdl.postprocess_detections(
|
||||
outs,
|
||||
input_size=dat.shape[1:3],
|
||||
orig_image_sizes=[t["image_size"] for t in targets],
|
||||
)
|
||||
ext = time.perf_counter()
|
||||
n += len(targets)
|
||||
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
|
||||
print(
|
||||
f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing"
|
||||
)
|
||||
img_ids = [t["image_id"] for t in targets]
|
||||
coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
|
||||
for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
|
||||
coco_results = [
|
||||
{
|
||||
"image_id": targets[i]["image_id"],
|
||||
"category_id": label,
|
||||
"bbox": box,
|
||||
"score": score,
|
||||
}
|
||||
for i, prediction in enumerate(predictions)
|
||||
for box, score, label in zip(*prediction.values())
|
||||
]
|
||||
with redirect_stdout(None):
|
||||
coco_eval.cocoDt = coco.loadRes(coco_results)
|
||||
coco_eval.params.imgIds = img_ids
|
||||
coco_eval.evaluate()
|
||||
evaluated_imgs.extend(img_ids)
|
||||
coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
coco_evalimgs.append(
|
||||
np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))
|
||||
)
|
||||
st = time.perf_counter()
|
||||
|
||||
coco_eval.params.imgIds = evaluated_imgs
|
||||
|
@ -127,16 +160,47 @@ def eval_retinanet():
|
|||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
|
||||
|
||||
def eval_rnnt():
|
||||
# RNN-T
|
||||
from extra.models.rnnt import RNNT
|
||||
|
||||
mdl = RNNT()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
from extra.datasets.librispeech import iterate
|
||||
from examples.mlperf.metrics import word_error_rate
|
||||
|
||||
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
|
||||
LABELS = [
|
||||
" ",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"o",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"'",
|
||||
]
|
||||
|
||||
c = 0
|
||||
scores = 0
|
||||
|
@ -149,16 +213,20 @@ def eval_rnnt():
|
|||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
for n, t in enumerate(tt):
|
||||
tnp = np.array(t)
|
||||
_, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
|
||||
_, scores_, words_ = word_error_rate(
|
||||
["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]]
|
||||
)
|
||||
scores += scores_
|
||||
words += words_
|
||||
c += len(tt)
|
||||
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_bert():
|
||||
# Bert-QA
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
|
||||
mdl = BertForQuestionAnswering()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
|
@ -180,9 +248,17 @@ def eval_bert():
|
|||
mt = time.perf_counter()
|
||||
outs = []
|
||||
for x in X:
|
||||
outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
|
||||
outs.append(
|
||||
run(
|
||||
Tensor(x["input_ids"]),
|
||||
Tensor(x["input_mask"]),
|
||||
Tensor(x["segment_ids"]),
|
||||
).numpy()
|
||||
)
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
|
||||
print(
|
||||
f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features"
|
||||
)
|
||||
|
||||
pred = get_bert_qa_prediction(X, Y, outs)
|
||||
print(f"pred: {pred}\nans: {Y['answers']}")
|
||||
|
@ -192,17 +268,27 @@ def eval_bert():
|
|||
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_mrcnn():
|
||||
from tqdm import tqdm
|
||||
from extra.models.mask_rcnn import MaskRCNN
|
||||
from extra.models.resnet import ResNet
|
||||
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
||||
from extra.datasets.coco import (
|
||||
BASEDIR,
|
||||
images,
|
||||
convert_prediction_to_coco_bbox,
|
||||
convert_prediction_to_coco_mask,
|
||||
accumulate_predictions_for_coco,
|
||||
evaluate_predictions_on_coco,
|
||||
iterate,
|
||||
)
|
||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
bbox_output = '/tmp/results_bbox.json'
|
||||
mask_output = '/tmp/results_mask.json'
|
||||
bbox_output = "/tmp/results_bbox.json"
|
||||
mask_output = "/tmp/results_mask.json"
|
||||
|
||||
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
||||
accumulate_predictions_for_coco([], mask_output, rm=True)
|
||||
|
@ -213,12 +299,12 @@ def eval_mrcnn():
|
|||
for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs):
|
||||
batch_imgs = []
|
||||
for image_row in batch:
|
||||
image_name = image_row['file_name']
|
||||
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
|
||||
image_name = image_row["file_name"]
|
||||
img = Image.open(BASEDIR / f"val2017/{image_name}").convert("RGB")
|
||||
batch_imgs.append(img)
|
||||
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
||||
for image_row, result in zip(batch, batch_result):
|
||||
image_name = image_row['file_name']
|
||||
image_name = image_row["file_name"]
|
||||
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
||||
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
||||
accumulate_predictions_for_coco(box_pred, bbox_output)
|
||||
|
@ -226,8 +312,9 @@ def eval_mrcnn():
|
|||
del batch_imgs
|
||||
del batch_result
|
||||
|
||||
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
|
||||
evaluate_predictions_on_coco(mask_output, iou_type='segm')
|
||||
evaluate_predictions_on_coco(bbox_output, iou_type="bbox")
|
||||
evaluate_predictions_on_coco(mask_output, iou_type="segm")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# inference only
|
||||
|
|
|
@ -3,46 +3,60 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import GlobalCounters, getenv
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_model(model, *inputs):
|
||||
GlobalCounters.reset()
|
||||
out = model(*inputs)
|
||||
if isinstance(out, Tensor): out = out.numpy()
|
||||
if isinstance(out, Tensor):
|
||||
out = out.numpy()
|
||||
# TODO: return event future to still get the time_sum_s without DEBUG=2
|
||||
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
|
||||
print(
|
||||
f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms"
|
||||
)
|
||||
|
||||
|
||||
def spec_resnet():
|
||||
# Resnet50-v1.5
|
||||
from extra.models.resnet import ResNet50
|
||||
|
||||
mdl = ResNet50()
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_retinanet():
|
||||
# Retinanet with ResNet backbone
|
||||
from extra.models.resnet import ResNet50
|
||||
from extra.models.retinanet import RetinaNet
|
||||
|
||||
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_unet3d():
|
||||
# 3D UNET
|
||||
from extra.models.unet3d import UNet3D
|
||||
|
||||
mdl = UNet3D()
|
||||
# mdl.load_from_pretrained()
|
||||
img = Tensor.randn(1, 1, 128, 128, 128)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_rnnt():
|
||||
from extra.models.rnnt import RNNT
|
||||
|
||||
mdl = RNNT()
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(220, 1, 240)
|
||||
y = Tensor.randn(1, 220)
|
||||
test_model(mdl, x, y)
|
||||
|
||||
|
||||
def spec_bert():
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
|
||||
mdl = BertForQuestionAnswering()
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(1, 384)
|
||||
|
@ -50,13 +64,16 @@ def spec_bert():
|
|||
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
||||
test_model(mdl, x, am, tt)
|
||||
|
||||
|
||||
def spec_mrcnn():
|
||||
from extra.models.mask_rcnn import MaskRCNN, ResNet
|
||||
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(3, 224, 224)
|
||||
test_model(mdl, [x])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# inference only for now
|
||||
Tensor.training = False
|
||||
|
@ -67,4 +84,3 @@ if __name__ == "__main__":
|
|||
if nm in globals():
|
||||
print(f"testing {m}")
|
||||
globals()[nm]()
|
||||
|
||||
|
|
|
@ -1,36 +1,43 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
def train_resnet():
|
||||
# TODO: Resnet50-v1.5
|
||||
pass
|
||||
|
||||
|
||||
def train_retinanet():
|
||||
# TODO: Retinanet
|
||||
pass
|
||||
|
||||
|
||||
def train_unet3d():
|
||||
# TODO: Unet3d
|
||||
pass
|
||||
|
||||
|
||||
def train_rnnt():
|
||||
# TODO: RNN-T
|
||||
pass
|
||||
|
||||
|
||||
def train_bert():
|
||||
# TODO: BERT
|
||||
pass
|
||||
|
||||
|
||||
def train_maskrcnn():
|
||||
# TODO: Mask RCNN
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Tensor.train():
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(
|
||||
","
|
||||
):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
print(f"training {m}")
|
||||
globals()[nm]()
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from tinygrad.helpers import getenv
|
|||
from tinygrad.nn import optim
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
|
||||
class LinearGen:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(128, 256)
|
||||
|
@ -23,6 +24,7 @@ class LinearGen:
|
|||
x = x.dot(self.l4).tanh()
|
||||
return x
|
||||
|
||||
|
||||
class LinearDisc:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(784, 1024)
|
||||
|
@ -38,16 +40,21 @@ class LinearDisc:
|
|||
x = x.dot(self.l4).log_softmax()
|
||||
return x
|
||||
|
||||
|
||||
def make_batch(images):
|
||||
sample = np.random.randint(0, len(images), size=(batch_size))
|
||||
image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0
|
||||
return Tensor(image_b)
|
||||
|
||||
|
||||
def make_labels(bs, col, val=-2.0):
|
||||
y = np.zeros((bs, 2), np.float32)
|
||||
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||
y[
|
||||
range(bs), [col] * bs
|
||||
] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||
return Tensor(y)
|
||||
|
||||
|
||||
def train_discriminator(optimizer, data_real, data_fake):
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
fake_labels = make_labels(batch_size, 0)
|
||||
|
@ -61,6 +68,7 @@ def train_discriminator(optimizer, data_real, data_fake):
|
|||
optimizer.step()
|
||||
return (loss_real + loss_fake).numpy()
|
||||
|
||||
|
||||
def train_generator(optimizer, data_fake):
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
optimizer.zero_grad()
|
||||
|
@ -70,6 +78,7 @@ def train_generator(optimizer, data_fake):
|
|||
optimizer.step()
|
||||
return loss.numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# data for training and validation
|
||||
images_real = np.vstack(fetch_mnist()[::2])
|
||||
|
@ -85,7 +94,9 @@ if __name__ == "__main__":
|
|||
output_dir = Path(".").resolve() / "outputs"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
# optimizers
|
||||
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_g = optim.Adam(
|
||||
get_parameters(generator), lr=0.0002, b1=0.5
|
||||
) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
for epoch in (t := trange(epochs)):
|
||||
|
@ -102,6 +113,11 @@ if __name__ == "__main__":
|
|||
if (epoch + 1) % sample_interval == 0:
|
||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
|
||||
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
|
||||
save_image(
|
||||
make_grid(torch.tensor(fake_images)),
|
||||
output_dir / f"image_{epoch+1}.jpg",
|
||||
)
|
||||
t.set_description(
|
||||
f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}"
|
||||
)
|
||||
print("Training Completed!")
|
||||
|
|
|
@ -9,10 +9,12 @@ from tinygrad.helpers import getenv
|
|||
from extra.datasets import fetch_mnist
|
||||
from extra.augment import augment_img
|
||||
from extra.training import train, evaluate
|
||||
|
||||
GPU = getenv("GPU")
|
||||
QUICK = getenv("QUICK")
|
||||
DEBUG = getenv("DEBUG")
|
||||
|
||||
|
||||
class SqueezeExciteBlock2D:
|
||||
def __init__(self, filters):
|
||||
self.filters = filters
|
||||
|
@ -22,7 +24,9 @@ class SqueezeExciteBlock2D:
|
|||
self.bias2 = Tensor.scaled_uniform(1, self.filters)
|
||||
|
||||
def __call__(self, input):
|
||||
se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
|
||||
se = input.avg_pool2d(
|
||||
kernel_size=(input.shape[2], input.shape[3])
|
||||
) # GlobalAveragePool2D
|
||||
se = se.reshape(shape=(-1, self.filters))
|
||||
se = se.dot(self.weight1) + self.bias1
|
||||
se = se.relu()
|
||||
|
@ -31,12 +35,16 @@ class SqueezeExciteBlock2D:
|
|||
se = input.mul(se)
|
||||
return se
|
||||
|
||||
|
||||
class ConvBlock:
|
||||
def __init__(self, h, w, inp, filters=128, conv=3):
|
||||
self.h, self.w = h, w
|
||||
self.inp = inp
|
||||
# init weights
|
||||
self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
|
||||
self.cweights = [
|
||||
Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv)
|
||||
for i in range(3)
|
||||
]
|
||||
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
|
||||
# init layers
|
||||
self._bn = BatchNorm2d(128)
|
||||
|
@ -50,9 +58,14 @@ class ConvBlock:
|
|||
x = self._seb(x)
|
||||
return x
|
||||
|
||||
|
||||
class BigConvNet:
|
||||
def __init__(self):
|
||||
self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
|
||||
self.conv = [
|
||||
ConvBlock(28, 28, 1),
|
||||
ConvBlock(28, 28, 128),
|
||||
ConvBlock(14, 14, 128),
|
||||
]
|
||||
self.weight1 = Tensor.scaled_uniform(128, 10)
|
||||
self.weight2 = Tensor.scaled_uniform(128, 10)
|
||||
|
||||
|
@ -63,19 +76,19 @@ class BigConvNet:
|
|||
for par in pars:
|
||||
print(par.shape)
|
||||
no_pars += np.prod(par.shape)
|
||||
print('no of parameters', no_pars)
|
||||
print("no of parameters", no_pars)
|
||||
return pars
|
||||
else:
|
||||
return get_parameters(self)
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename+'.npy', 'wb') as f:
|
||||
with open(filename + ".npy", "wb") as f:
|
||||
for par in get_parameters(self):
|
||||
# if par.requires_grad:
|
||||
np.save(f, par.numpy())
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename+'.npy', 'rb') as f:
|
||||
with open(filename + ".npy", "rb") as f:
|
||||
for par in get_parameters(self):
|
||||
# if par.requires_grad:
|
||||
try:
|
||||
|
@ -83,7 +96,7 @@ class BigConvNet:
|
|||
if GPU:
|
||||
par.gpu()
|
||||
except:
|
||||
print('Could not load parameter')
|
||||
print("Could not load parameter")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv[0](x)
|
||||
|
@ -102,7 +115,10 @@ if __name__ == "__main__":
|
|||
BS = 32
|
||||
|
||||
lmbd = 0.00025
|
||||
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
|
||||
lossfn = (
|
||||
lambda out, y: out.sparse_categorical_crossentropy(y)
|
||||
+ lmbd * (model.weight1.abs() + model.weight2.abs()).sum()
|
||||
)
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
|
@ -133,4 +149,4 @@ if __name__ == "__main__":
|
|||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
|
||||
model.save(f"examples/checkpoint{accuracy * 1e6:.0f}")
|
||||
|
|
|
@ -6,14 +6,14 @@ from tinygrad.nn.state import get_parameters
|
|||
|
||||
if __name__ == "__main__":
|
||||
with Tensor.train():
|
||||
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||
for t in get_parameters([x, conv, bn]): t.realize()
|
||||
for t in get_parameters([x, conv, bn]):
|
||||
t.realize()
|
||||
|
||||
print("running network")
|
||||
x.sequential([conv, bn]).numpy()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -7,9 +7,17 @@ import soundfile
|
|||
import numpy as np
|
||||
import parselmouth
|
||||
|
||||
|
||||
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
|
||||
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
|
||||
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = (
|
||||
hop_length,
|
||||
f0_min,
|
||||
f0_max,
|
||||
sampling_rate,
|
||||
"pm",
|
||||
)
|
||||
|
||||
def interpolate_f0(self, f0):
|
||||
vuv_vector = np.zeros_like(f0, dtype=np.float32)
|
||||
vuv_vector[f0 > 0.0] = 1.0
|
||||
|
@ -19,79 +27,142 @@ class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
|||
nzindex = nzindex.astype(np.float32)
|
||||
time_org = self.hop_length / self.sampling_rate * nzindex
|
||||
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
|
||||
if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
|
||||
if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
|
||||
if data.shape[0] <= 0:
|
||||
return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector
|
||||
if data.shape[0] == 1:
|
||||
return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector
|
||||
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
|
||||
return f0, vuv_vector
|
||||
|
||||
def compute_f0(self, wav, p_len=None):
|
||||
x = wav
|
||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
if p_len is None:
|
||||
p_len = x.shape[0] // self.hop_length
|
||||
else:
|
||||
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = parselmouth.Sound(x, self.sampling_rate) \
|
||||
.to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
|
||||
.selected_array['frequency']
|
||||
f0 = (
|
||||
parselmouth.Sound(x, self.sampling_rate)
|
||||
.to_pitch_ac(
|
||||
time_step=time_step / 1000,
|
||||
voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min,
|
||||
pitch_ceiling=self.f0_max,
|
||||
)
|
||||
.selected_array["frequency"]
|
||||
)
|
||||
pad_size = (p_len - len(f0) + 1) // 2
|
||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
||||
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||
f0, uv = self.interpolate_f0(f0)
|
||||
return f0
|
||||
|
||||
def compute_f0_uv(self, wav, p_len=None):
|
||||
x = wav
|
||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
if p_len is None:
|
||||
p_len = x.shape[0] // self.hop_length
|
||||
else:
|
||||
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
|
||||
time_step=time_step / 1000, voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
|
||||
f0 = (
|
||||
parselmouth.Sound(x, self.sampling_rate)
|
||||
.to_pitch_ac(
|
||||
time_step=time_step / 1000,
|
||||
voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min,
|
||||
pitch_ceiling=self.f0_max,
|
||||
)
|
||||
.selected_array["frequency"]
|
||||
)
|
||||
pad_size = (p_len - len(f0) + 1) // 2
|
||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
||||
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||
f0, uv = self.interpolate_f0(f0)
|
||||
return f0, uv
|
||||
|
||||
|
||||
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||
def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
|
||||
def __init__(
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 5000,
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 5000,
|
||||
):
|
||||
if not min_length >= min_interval >= hop_size:
|
||||
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
|
||||
)
|
||||
if not max_sil_kept >= hop_size:
|
||||
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: max_sil_kept >= hop_size"
|
||||
)
|
||||
min_interval = sr * min_interval / 1000
|
||||
self.threshold = 10 ** (threshold / 20.)
|
||||
self.threshold = 10 ** (threshold / 20.0)
|
||||
self.hop_size = round(sr * hop_size / 1000)
|
||||
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
||||
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
||||
self.min_interval = round(min_interval / self.hop_size)
|
||||
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
||||
|
||||
def _apply_slice(self, waveform, begin, end):
|
||||
if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
|
||||
else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
|
||||
if len(waveform.shape) > 1:
|
||||
return waveform[
|
||||
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
|
||||
]
|
||||
else:
|
||||
return waveform[
|
||||
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
|
||||
]
|
||||
|
||||
def slice(self, waveform):
|
||||
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
|
||||
if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
||||
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
if samples.shape[0] <= self.min_length:
|
||||
return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
||||
rms_list = librosa.feature.rms(
|
||||
y=samples, frame_length=self.win_size, hop_length=self.hop_size
|
||||
).squeeze(0)
|
||||
sil_tags, silence_start, clip_start = [], None, 0
|
||||
for i, rms in enumerate(rms_list):
|
||||
if rms < self.threshold: # Keep looping while frame is silent.
|
||||
if silence_start is None: # Record start of silent frames.
|
||||
silence_start = i
|
||||
continue
|
||||
if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded.
|
||||
if silence_start is None:
|
||||
continue # Keep looping while frame is not silent and silence start has not been recorded.
|
||||
# Clear recorded silence start if interval is not enough or clip is too short
|
||||
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
||||
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
|
||||
need_slice_middle = (
|
||||
i - silence_start >= self.min_interval
|
||||
and i - clip_start >= self.min_length
|
||||
)
|
||||
if not is_leading_silence and not need_slice_middle:
|
||||
silence_start = None
|
||||
continue
|
||||
if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed.
|
||||
if (
|
||||
i - silence_start <= self.max_sil_kept
|
||||
): # Need slicing. Record the range of silent frames to be removed.
|
||||
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
||||
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
||||
clip_start = pos
|
||||
elif i - silence_start <= self.max_sil_kept * 2:
|
||||
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
|
||||
pos = rms_list[
|
||||
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
pos += i - self.max_sil_kept
|
||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos_r))
|
||||
clip_start = pos_r
|
||||
|
@ -99,41 +170,105 @@ class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
|||
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
||||
clip_start = max(pos_r, pos)
|
||||
else:
|
||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
|
||||
clip_start = pos_r
|
||||
silence_start = None
|
||||
total_frames = rms_list.shape[0]
|
||||
if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence.
|
||||
if (
|
||||
silence_start is not None
|
||||
and total_frames - silence_start >= self.min_interval
|
||||
): # Deal with trailing silence.
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices.
|
||||
if len(sil_tags) == 0:
|
||||
return {
|
||||
"0": {"slice": False, "split_time": f"0,{len(waveform)}"}
|
||||
} # Apply and return slices.
|
||||
chunks = []
|
||||
if sil_tags[0][0]:
|
||||
chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
for i in range(0, len(sil_tags)):
|
||||
if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
|
||||
chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
|
||||
if i:
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
chunks.append(
|
||||
{
|
||||
"slice": True,
|
||||
"split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
if sil_tags[-1][1] * self.hop_size < len(waveform):
|
||||
chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}",
|
||||
}
|
||||
)
|
||||
chunk_dict = {}
|
||||
for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
|
||||
for i in range(len(chunks)):
|
||||
chunk_dict[str(i)] = chunks[i]
|
||||
return chunk_dict
|
||||
|
||||
|
||||
# sinc_interp_hann audio resampling
|
||||
class Resample:
|
||||
def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
|
||||
self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
|
||||
def __init__(
|
||||
self,
|
||||
orig_freq: int = 16000,
|
||||
new_freq: int = 16000,
|
||||
lowpass_filter_width: int = 6,
|
||||
rolloff: float = 0.99,
|
||||
beta: Optional[float] = None,
|
||||
dtype: Optional[dtypes] = None,
|
||||
):
|
||||
(
|
||||
self.orig_freq,
|
||||
self.new_freq,
|
||||
self.lowpass_filter_width,
|
||||
self.rolloff,
|
||||
self.beta,
|
||||
) = (orig_freq, new_freq, lowpass_filter_width, rolloff, beta)
|
||||
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
|
||||
self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
|
||||
self.kernel, self.width = (
|
||||
self._get_sinc_resample_kernel(dtype)
|
||||
if self.orig_freq != self.new_freq
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
def __call__(self, waveform: Tensor) -> Tensor:
|
||||
if self.orig_freq == self.new_freq: return waveform
|
||||
if self.orig_freq == self.new_freq:
|
||||
return waveform
|
||||
return self._apply_sinc_resample_kernel(waveform)
|
||||
|
||||
def _apply_sinc_resample_kernel(self, waveform: Tensor):
|
||||
if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
||||
if not waveform.is_floating_point():
|
||||
raise TypeError(
|
||||
f"Waveform tensor expected to be of type float, but received {waveform.dtype}."
|
||||
)
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||
int(self.new_freq) // self.gcd
|
||||
)
|
||||
shape = waveform.shape
|
||||
waveform = waveform.reshape(-1, shape[-1]) # pack batch
|
||||
num_wavs, length = waveform.shape
|
||||
|
@ -144,34 +279,58 @@ class Resample:
|
|||
resampled = resampled[..., :target_length]
|
||||
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
|
||||
return resampled
|
||||
|
||||
def _get_sinc_resample_kernel(self, dtype=None):
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
||||
if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||
int(self.new_freq) // self.gcd
|
||||
)
|
||||
if self.lowpass_filter_width <= 0:
|
||||
raise ValueError("Low pass filter width should be positive.")
|
||||
base_freq = min(orig_freq, new_freq)
|
||||
base_freq *= self.rolloff
|
||||
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
|
||||
idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
|
||||
idx = (
|
||||
Tensor.arange(
|
||||
-width,
|
||||
width + orig_freq,
|
||||
dtype=(dtype if dtype is not None else dtypes.float32),
|
||||
)[None, None]
|
||||
/ orig_freq
|
||||
)
|
||||
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
|
||||
t *= base_freq
|
||||
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
|
||||
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
|
||||
t *= math.pi
|
||||
scale = base_freq / orig_freq
|
||||
kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
|
||||
kernels = Tensor.where(
|
||||
t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t
|
||||
)
|
||||
kernels *= window * scale
|
||||
if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
|
||||
if dtype is None:
|
||||
kernels = kernels.cast(dtype=dtypes.float32)
|
||||
return kernels, width
|
||||
|
||||
def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
|
||||
|
||||
def sinc_interp_resample(
|
||||
x: Tensor,
|
||||
orig_freq: int = 16000,
|
||||
new_freq: int = 1600,
|
||||
lowpass_filter_width: int = 6,
|
||||
rolloff: float = 0.99,
|
||||
beta: Optional[float] = None,
|
||||
):
|
||||
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
|
||||
return resamp(x)
|
||||
|
||||
|
||||
def cut(audio_path, db_thresh=-30, min_len=5000):
|
||||
audio, sr = librosa.load(audio_path, sr=None)
|
||||
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
|
||||
chunks = slicer.slice(audio)
|
||||
return chunks
|
||||
|
||||
|
||||
def chunks2audio(audio_path, chunks):
|
||||
chunks = dict(chunks)
|
||||
audio, sr = load_audiofile(audio_path)
|
||||
|
@ -185,19 +344,30 @@ def chunks2audio(audio_path, chunks):
|
|||
result.append((v["slice"], audio[int(tag[0]) : int(tag[1])]))
|
||||
return result, sr
|
||||
|
||||
def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
|
||||
|
||||
def load_audiofile(
|
||||
filepath: str,
|
||||
frame_offset: int = 0,
|
||||
num_frames: int = -1,
|
||||
channels_first: bool = True,
|
||||
):
|
||||
with soundfile.SoundFile(filepath, "r") as file_:
|
||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||
waveform = file_.read(frames, "float32", always_2d=True)
|
||||
sample_rate = file_.samplerate
|
||||
waveform = Tensor(waveform)
|
||||
if channels_first: waveform = waveform.transpose(0, 1)
|
||||
if channels_first:
|
||||
waveform = waveform.transpose(0, 1)
|
||||
return waveform, sample_rate
|
||||
|
||||
def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
|
||||
|
||||
def get_unit_f0(
|
||||
wav: Tensor, tran, hop_length, target_sample, f0_filter=False
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
|
||||
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
|
||||
if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
|
||||
if f0_filter and sum(f0) == 0:
|
||||
raise RuntimeError("No voice detected")
|
||||
f0 = Tensor(f0.astype(np.float32)).float()
|
||||
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
|
||||
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
|
||||
|
|
|
@ -14,6 +14,7 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
|||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
self.norm = GroupNorm(32, in_channels)
|
||||
|
@ -30,22 +31,32 @@ class AttnBlock:
|
|||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = [x.reshape(b, c, h * w).transpose(1, 2) for x in (q, k, v)]
|
||||
h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
|
||||
h_ = (
|
||||
Tensor.scaled_dot_product_attention(q, k, v)
|
||||
.transpose(1, 2)
|
||||
.reshape(b, c, h, w)
|
||||
)
|
||||
return x + self.proj_out(h_)
|
||||
|
||||
|
||||
class ResnetBlock:
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
self.norm1 = GroupNorm(32, in_channels)
|
||||
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
|
||||
self.norm2 = GroupNorm(32, out_channels)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
|
||||
self.nin_shortcut = (
|
||||
Conv2d(in_channels, out_channels, 1)
|
||||
if in_channels != out_channels
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
h = self.conv1(self.norm1(x).swish())
|
||||
h = self.conv2(self.norm2(h).swish())
|
||||
return self.nin_shortcut(x) + h
|
||||
|
||||
|
||||
class Mid:
|
||||
def __init__(self, block_in):
|
||||
self.block_1 = ResnetBlock(block_in, block_in)
|
||||
|
@ -55,6 +66,7 @@ class Mid:
|
|||
def __call__(self, x):
|
||||
return x.sequential([self.block_1, self.attn_1, self.block_2])
|
||||
|
||||
|
||||
class Decoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
|
||||
|
@ -63,11 +75,17 @@ class Decoder:
|
|||
|
||||
arr = []
|
||||
for i, s in enumerate(sz):
|
||||
arr.append({"block":
|
||||
[ResnetBlock(s[1], s[0]),
|
||||
arr.append(
|
||||
{
|
||||
"block": [
|
||||
ResnetBlock(s[1], s[0]),
|
||||
ResnetBlock(s[0], s[0]),
|
||||
ResnetBlock(s[0], s[0])]})
|
||||
if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
||||
ResnetBlock(s[0], s[0]),
|
||||
]
|
||||
}
|
||||
)
|
||||
if i != 0:
|
||||
arr[-1]["upsample"] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
||||
self.up = arr
|
||||
|
||||
self.norm_out = GroupNorm(32, 128)
|
||||
|
@ -79,16 +97,22 @@ class Decoder:
|
|||
|
||||
for l in self.up[::-1]:
|
||||
print("decode", x.shape)
|
||||
for b in l['block']: x = b(x)
|
||||
if 'upsample' in l:
|
||||
for b in l["block"]:
|
||||
x = b(x)
|
||||
if "upsample" in l:
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
|
||||
bs, c, py, px = x.shape
|
||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
||||
x = l['upsample']['conv'](x)
|
||||
x = (
|
||||
x.reshape(bs, c, py, 1, px, 1)
|
||||
.expand(bs, c, py, 2, px, 2)
|
||||
.reshape(bs, c, py * 2, px * 2)
|
||||
)
|
||||
x = l["upsample"]["conv"](x)
|
||||
x.realize()
|
||||
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self):
|
||||
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
|
||||
|
@ -96,10 +120,11 @@ class Encoder:
|
|||
|
||||
arr = []
|
||||
for i, s in enumerate(sz):
|
||||
arr.append({"block":
|
||||
[ResnetBlock(s[0], s[1]),
|
||||
ResnetBlock(s[1], s[1])]})
|
||||
if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
|
||||
arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
|
||||
if i != 3:
|
||||
arr[-1]["downsample"] = {
|
||||
"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))
|
||||
}
|
||||
self.down = arr
|
||||
|
||||
self.mid = Mid(512)
|
||||
|
@ -111,12 +136,15 @@ class Encoder:
|
|||
|
||||
for l in self.down:
|
||||
print("encode", x.shape)
|
||||
for b in l['block']: x = b(x)
|
||||
if 'downsample' in l: x = l['downsample']['conv'](x)
|
||||
for b in l["block"]:
|
||||
x = b(x)
|
||||
if "downsample" in l:
|
||||
x = l["downsample"]["conv"](x)
|
||||
|
||||
x = self.mid(x)
|
||||
return self.conv_out(self.norm_out(x).swish())
|
||||
|
||||
|
||||
class AutoencoderKL:
|
||||
def __init__(self):
|
||||
self.encoder = Encoder()
|
||||
|
@ -132,25 +160,27 @@ class AutoencoderKL:
|
|||
latent = self.post_quant_conv(latent)
|
||||
return self.decoder(latent)
|
||||
|
||||
|
||||
# not to be confused with ResnetBlock
|
||||
class ResBlock:
|
||||
def __init__(self, channels, emb_channels, out_channels):
|
||||
self.in_layers = [
|
||||
GroupNorm(32, channels),
|
||||
Tensor.silu,
|
||||
Conv2d(channels, out_channels, 3, padding=1)
|
||||
]
|
||||
self.emb_layers = [
|
||||
Tensor.silu,
|
||||
Linear(emb_channels, out_channels)
|
||||
Conv2d(channels, out_channels, 3, padding=1),
|
||||
]
|
||||
self.emb_layers = [Tensor.silu, Linear(emb_channels, out_channels)]
|
||||
self.out_layers = [
|
||||
GroupNorm(32, out_channels),
|
||||
Tensor.silu,
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Conv2d(out_channels, out_channels, 3, padding=1)
|
||||
Conv2d(out_channels, out_channels, 3, padding=1),
|
||||
]
|
||||
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
|
||||
self.skip_connection = (
|
||||
Conv2d(channels, out_channels, 1)
|
||||
if channels != out_channels
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
def __call__(self, x, emb):
|
||||
h = x.sequential(self.in_layers)
|
||||
|
@ -160,6 +190,7 @@ class ResBlock:
|
|||
ret = self.skip_connection(x) + h
|
||||
return ret
|
||||
|
||||
|
||||
class CrossAttention:
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
|
||||
|
@ -172,11 +203,15 @@ class CrossAttention:
|
|||
def __call__(self, x, context=None):
|
||||
context = x if context is None else context
|
||||
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
|
||||
q, k, v = [
|
||||
y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1, 2)
|
||||
for y in (q, k, v)
|
||||
]
|
||||
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
|
||||
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
|
||||
return h_.sequential(self.to_out)
|
||||
|
||||
|
||||
class GEGLU:
|
||||
def __init__(self, dim_in, dim_out):
|
||||
self.proj = Linear(dim_in, dim_out * 2)
|
||||
|
@ -186,17 +221,19 @@ class GEGLU:
|
|||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * gate.gelu()
|
||||
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, mult=4):
|
||||
self.net = [
|
||||
GEGLU(dim, dim * mult),
|
||||
lambda x: x, # needed for weights loading code to work
|
||||
Linear(dim*mult, dim)
|
||||
Linear(dim * mult, dim),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
return x.sequential(self.net)
|
||||
|
||||
|
||||
class BasicTransformerBlock:
|
||||
def __init__(self, dim, context_dim, n_heads, d_head):
|
||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||
|
@ -212,12 +249,15 @@ class BasicTransformerBlock:
|
|||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer:
|
||||
def __init__(self, channels, context_dim, n_heads, d_head):
|
||||
self.norm = GroupNorm(32, channels)
|
||||
assert channels == n_heads * d_head
|
||||
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
|
||||
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
|
||||
self.transformer_blocks = [
|
||||
BasicTransformerBlock(channels, context_dim, n_heads, d_head)
|
||||
]
|
||||
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
|
||||
|
||||
def __call__(self, x, context=None):
|
||||
|
@ -232,6 +272,7 @@ class SpatialTransformer:
|
|||
ret = self.proj_out(x) + x_in
|
||||
return ret
|
||||
|
||||
|
||||
class Downsample:
|
||||
def __init__(self, channels):
|
||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||
|
@ -239,21 +280,28 @@ class Downsample:
|
|||
def __call__(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class Upsample:
|
||||
def __init__(self, channels):
|
||||
self.conv = Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
bs, c, py, px = x.shape
|
||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
||||
x = (
|
||||
x.reshape(bs, c, py, 1, px, 1)
|
||||
.expand(bs, c, py, 2, px, 2)
|
||||
.reshape(bs, c, py * 2, px * 2)
|
||||
)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||
args = timesteps * freqs
|
||||
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
|
||||
|
||||
|
||||
class UNetModel:
|
||||
def __init__(self):
|
||||
self.time_embed = [
|
||||
|
@ -273,12 +321,12 @@ class UNetModel:
|
|||
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[Downsample(1280)],
|
||||
[ResBlock(1280, 1280, 1280)],
|
||||
[ResBlock(1280, 1280, 1280)]
|
||||
[ResBlock(1280, 1280, 1280)],
|
||||
]
|
||||
self.middle_block = [
|
||||
ResBlock(1280, 1280, 1280),
|
||||
SpatialTransformer(1280, 768, 8, 160),
|
||||
ResBlock(1280, 1280, 1280)
|
||||
ResBlock(1280, 1280, 1280),
|
||||
]
|
||||
self.output_blocks = [
|
||||
[ResBlock(2560, 1280, 1280)],
|
||||
|
@ -286,10 +334,18 @@ class UNetModel:
|
|||
[ResBlock(2560, 1280, 1280), Upsample(1280)],
|
||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||
[ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
|
||||
[
|
||||
ResBlock(1920, 1280, 1280),
|
||||
SpatialTransformer(1280, 768, 8, 160),
|
||||
Upsample(1280),
|
||||
],
|
||||
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
|
||||
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
||||
[ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
|
||||
[
|
||||
ResBlock(960, 1280, 640),
|
||||
SpatialTransformer(640, 768, 8, 80),
|
||||
Upsample(640),
|
||||
],
|
||||
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||
|
@ -297,7 +353,7 @@ class UNetModel:
|
|||
self.out = [
|
||||
GroupNorm(32, 320),
|
||||
Tensor.silu,
|
||||
Conv2d(320, 4, kernel_size=3, padding=1)
|
||||
Conv2d(320, 4, kernel_size=3, padding=1),
|
||||
]
|
||||
|
||||
def __call__(self, x, timesteps=None, context=None):
|
||||
|
@ -306,9 +362,12 @@ class UNetModel:
|
|||
emb = t_emb.sequential(self.time_embed)
|
||||
|
||||
def run(x, bb):
|
||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
||||
elif isinstance(bb, SpatialTransformer): x = bb(x, context)
|
||||
else: x = bb(x)
|
||||
if isinstance(bb, ResBlock):
|
||||
x = bb(x, emb)
|
||||
elif isinstance(bb, SpatialTransformer):
|
||||
x = bb(x, context)
|
||||
else:
|
||||
x = bb(x)
|
||||
return x
|
||||
|
||||
saved_inputs = []
|
||||
|
@ -326,6 +385,7 @@ class UNetModel:
|
|||
x = run(x, bb)
|
||||
return x.sequential(self.out)
|
||||
|
||||
|
||||
class CLIPMLP:
|
||||
def __init__(self):
|
||||
self.fc1 = Linear(768, 3072)
|
||||
|
@ -337,6 +397,7 @@ class CLIPMLP:
|
|||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPAttention:
|
||||
def __init__(self):
|
||||
self.embed_dim = 768
|
||||
|
@ -349,10 +410,22 @@ class CLIPAttention:
|
|||
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
bsz, tgt_len, embed_dim = hidden_states.shape
|
||||
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
|
||||
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
|
||||
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
|
||||
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
|
||||
q, k, v = (
|
||||
self.q_proj(hidden_states),
|
||||
self.k_proj(hidden_states),
|
||||
self.v_proj(hidden_states),
|
||||
)
|
||||
q, k, v = [
|
||||
x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
for x in (q, k, v)
|
||||
]
|
||||
attn_output = Tensor.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=causal_attention_mask
|
||||
)
|
||||
return self.out_proj(
|
||||
attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
|
||||
)
|
||||
|
||||
|
||||
class CLIPEncoderLayer:
|
||||
def __init__(self):
|
||||
|
@ -374,6 +447,7 @@ class CLIPEncoderLayer:
|
|||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoder:
|
||||
def __init__(self):
|
||||
self.layers = [CLIPEncoderLayer() for i in range(12)]
|
||||
|
@ -383,6 +457,7 @@ class CLIPEncoder:
|
|||
hidden_states = l(hidden_states, causal_attention_mask)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPTextEmbeddings:
|
||||
def __init__(self):
|
||||
self.token_embedding = Embedding(49408, 768)
|
||||
|
@ -391,6 +466,7 @@ class CLIPTextEmbeddings:
|
|||
def __call__(self, input_ids, position_ids):
|
||||
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
||||
|
||||
|
||||
class CLIPTextTransformer:
|
||||
def __init__(self):
|
||||
self.embeddings = CLIPTextEmbeddings()
|
||||
|
@ -402,9 +478,15 @@ class CLIPTextTransformer:
|
|||
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1))
|
||||
return self.final_layer_norm(x)
|
||||
|
||||
|
||||
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
||||
@lru_cache()
|
||||
def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
|
||||
def default_bpe():
|
||||
return fetch(
|
||||
"https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz",
|
||||
"bpe_simple_vocab_16e6.txt.gz",
|
||||
)
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
@ -417,11 +499,13 @@ def get_pairs(word):
|
|||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
|
@ -432,7 +516,11 @@ def bytes_to_unicode():
|
|||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
|
@ -443,33 +531,40 @@ def bytes_to_unicode():
|
|||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
class ClipTokenizer:
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
|
||||
merges = merges[1 : 49152 - 256 - 2 + 1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
vocab = vocab + [v + "</w>" for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
vocab.append("".join(merge))
|
||||
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
||||
self.cache = {
|
||||
"<|startoftext|>": "<|startoftext|>",
|
||||
"<|endoftext|>": "<|endoftext|>",
|
||||
}
|
||||
self.pat = re.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
return token + "</w>"
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
|
@ -495,7 +590,7 @@ class ClipTokenizer:
|
|||
if len(word) == 1:
|
||||
break
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
|
@ -503,19 +598,28 @@ class ClipTokenizer:
|
|||
bpe_tokens = []
|
||||
text = whitespace_clean(text.strip()).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(
|
||||
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
||||
)
|
||||
# Truncation, keeping two slots for start and end tokens.
|
||||
if len(bpe_tokens) > 75:
|
||||
bpe_tokens = bpe_tokens[:75]
|
||||
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self):
|
||||
self.alphas_cumprod = Tensor.empty(1000)
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
|
||||
diffusion_model=UNetModel()
|
||||
)
|
||||
self.first_stage_model = AutoencoderKL()
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
|
||||
transformer=namedtuple("Transformer", ["text_model"])(
|
||||
text_model=CLIPTextTransformer()
|
||||
)
|
||||
)
|
||||
|
||||
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
|
||||
temperature = 1
|
||||
|
@ -526,17 +630,30 @@ class StableDiffusion:
|
|||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||
return x_prev, pred_x0
|
||||
|
||||
def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
|
||||
def get_model_output(
|
||||
self,
|
||||
unconditional_context,
|
||||
context,
|
||||
latent,
|
||||
timestep,
|
||||
unconditional_guidance_scale,
|
||||
):
|
||||
# put into diffuser
|
||||
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
|
||||
latents = self.model.diffusion_model(
|
||||
latent.expand(2, *latent.shape[1:]),
|
||||
timestep,
|
||||
unconditional_context.cat(context, dim=0),
|
||||
)
|
||||
unconditional_latent, latent = latents[0:1], latents[1:2]
|
||||
|
||||
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
||||
e_t = unconditional_latent + unconditional_guidance_scale * (
|
||||
latent - unconditional_latent
|
||||
)
|
||||
return e_t
|
||||
|
||||
def decode(self, x):
|
||||
|
@ -548,14 +665,26 @@ class StableDiffusion:
|
|||
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
|
||||
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
|
||||
|
||||
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
|
||||
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
|
||||
def __call__(
|
||||
self,
|
||||
unconditional_context,
|
||||
context,
|
||||
latent,
|
||||
timestep,
|
||||
alphas,
|
||||
alphas_prev,
|
||||
guidance,
|
||||
):
|
||||
e_t = self.get_model_output(
|
||||
unconditional_context, context, latent, timestep, guidance
|
||||
)
|
||||
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
|
||||
# e_t_next = get_model_output(x_prev)
|
||||
# e_t_prime = (e_t + e_t_next) / 2
|
||||
# x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
|
||||
return x_prev.realize()
|
||||
|
||||
|
||||
# ** ldm.models.autoencoder.AutoencoderKL (done!)
|
||||
# 3x512x512 <--> 4x64x64 (16384)
|
||||
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
|
||||
|
@ -573,22 +702,48 @@ class StableDiffusion:
|
|||
# cond_stage_model.transformer.text_model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
|
||||
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
|
||||
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Stable Diffusion",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps", type=int, default=5, help="Number of steps in diffusion"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a horse sized cat eating a bagel",
|
||||
help="Phrase to render",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out",
|
||||
type=str,
|
||||
default=Path(tempfile.gettempdir()) / "rendered.png",
|
||||
help="Output filename",
|
||||
)
|
||||
parser.add_argument("--noshow", action="store_true", help="Don't show the image")
|
||||
parser.add_argument(
|
||||
"--fp16", action="store_true", help="Cast the weights to float16"
|
||||
)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per step")
|
||||
parser.add_argument("--seed", type=int, help="Set the random latent seed")
|
||||
parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
|
||||
args = parser.parse_args()
|
||||
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
load_state_dict(
|
||||
model,
|
||||
torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
|
||||
"sd-v1-4.ckpt",
|
||||
)
|
||||
)["state_dict"],
|
||||
strict=False,
|
||||
)
|
||||
|
||||
if args.fp16:
|
||||
for l in get_state_dict(model).values():
|
||||
|
@ -601,7 +756,9 @@ if __name__ == "__main__":
|
|||
print("got CLIP context", context.shape)
|
||||
|
||||
prompt = Tensor([tokenizer.encode("")])
|
||||
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
|
||||
unconditional_context = model.cond_stage_model.transformer.text_model(
|
||||
prompt
|
||||
).realize()
|
||||
print("got unconditional CLIP context", unconditional_context.shape)
|
||||
|
||||
# done with clip model
|
||||
|
@ -613,21 +770,37 @@ if __name__ == "__main__":
|
|||
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
|
||||
|
||||
# start with random noise
|
||||
if args.seed is not None: Tensor._seed = args.seed
|
||||
if args.seed is not None:
|
||||
Tensor._seed = args.seed
|
||||
latent = Tensor.randn(1, 4, 64, 64)
|
||||
|
||||
@TinyJit
|
||||
def run(model, *x): return model(*x).realize()
|
||||
def run(model, *x):
|
||||
return model(*x).realize()
|
||||
|
||||
# this is diffusion
|
||||
with Context(BEAM=getenv("LATEBEAM")):
|
||||
for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
|
||||
GlobalCounters.reset()
|
||||
t.set_description("%3d %3d" % (index, timestep))
|
||||
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
|
||||
with Timing(
|
||||
"step in ",
|
||||
enabled=args.timing,
|
||||
on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB",
|
||||
):
|
||||
tid = Tensor([index])
|
||||
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
|
||||
if args.timing: Device[Device.DEFAULT].synchronize()
|
||||
latent = run(
|
||||
model,
|
||||
unconditional_context,
|
||||
context,
|
||||
latent,
|
||||
Tensor([timestep]),
|
||||
alphas[tid],
|
||||
alphas_prev[tid],
|
||||
Tensor([args.guidance]),
|
||||
)
|
||||
if args.timing:
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
del run
|
||||
|
||||
# upsample latent space to image with autoencoder
|
||||
|
@ -637,8 +810,10 @@ if __name__ == "__main__":
|
|||
# save image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
# Open image.
|
||||
if not args.noshow: im.show()
|
||||
if not args.noshow:
|
||||
im.show()
|
||||
|
|
|
@ -10,6 +10,7 @@ from tinygrad.tensor import Tensor
|
|||
from extra.datasets import fetch_cifar
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
|
||||
|
||||
class TinyConvNet:
|
||||
def __init__(self, classes=10):
|
||||
conv = 3
|
||||
|
@ -24,6 +25,7 @@ class TinyConvNet:
|
|||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
IMAGENET = getenv("IMAGENET")
|
||||
classes = 1000 if IMAGENET else 10
|
||||
|
@ -47,12 +49,14 @@ if __name__ == "__main__":
|
|||
|
||||
if IMAGENET:
|
||||
from extra.datasets.imagenet import fetch_batch
|
||||
|
||||
def loader(q):
|
||||
while 1:
|
||||
try:
|
||||
q.put(fetch_batch(BS))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
q = Queue(16)
|
||||
for i in range(2):
|
||||
p = Process(target=loader, args=(q,))
|
||||
|
@ -97,9 +101,17 @@ if __name__ == "__main__":
|
|||
finish_time = (time.time() - st) * 1000.0
|
||||
|
||||
# printing
|
||||
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
|
||||
(loss, accuracy,
|
||||
fp_time, bp_time, opt_time, finish_time,
|
||||
fp_time + bp_time + opt_time + finish_time))
|
||||
t.set_description(
|
||||
"loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f"
|
||||
% (
|
||||
loss,
|
||||
accuracy,
|
||||
fp_time,
|
||||
bp_time,
|
||||
opt_time,
|
||||
finish_time,
|
||||
fp_time + bp_time + opt_time + finish_time,
|
||||
)
|
||||
)
|
||||
|
||||
del out, y, loss
|
||||
|
|
|
@ -19,27 +19,30 @@ class ComposeTransforms:
|
|||
x = t(x)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
classes = 10
|
||||
|
||||
TRANSFER = getenv('TRANSFER')
|
||||
model = ResNet(getenv('NUM', 18), num_classes=classes)
|
||||
TRANSFER = getenv("TRANSFER")
|
||||
model = ResNet(getenv("NUM", 18), num_classes=classes)
|
||||
if TRANSFER:
|
||||
model.load_from_pretrained()
|
||||
|
||||
lr = 5e-3
|
||||
transform = ComposeTransforms([
|
||||
lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x],
|
||||
transform = ComposeTransforms(
|
||||
[
|
||||
lambda x: [Image.fromarray(xx, mode="L").resize((64, 64)) for xx in x],
|
||||
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
|
||||
lambda x: x / 255.0,
|
||||
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
||||
])
|
||||
]
|
||||
)
|
||||
for _ in range(5):
|
||||
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
||||
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
||||
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.7f}')
|
||||
print(f"reducing lr to {lr:.7f}")
|
||||
|
|
|
@ -7,13 +7,16 @@ from tinygrad.nn.optim import Adam
|
|||
from extra.training import train, evaluate
|
||||
from extra.models.transformer import Transformer
|
||||
|
||||
|
||||
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
|
||||
def make_dataset():
|
||||
ds = []
|
||||
for i in range(100):
|
||||
for j in range(100):
|
||||
s = i + j
|
||||
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
|
||||
ds.append(
|
||||
[i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10]
|
||||
)
|
||||
random.shuffle(ds)
|
||||
ds = np.array(ds).astype(np.float32)
|
||||
ds_X = ds[:, 0:6]
|
||||
|
@ -22,6 +25,7 @@ def make_dataset():
|
|||
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4, 32)
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
|
@ -29,14 +33,23 @@ if __name__ == "__main__":
|
|||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
||||
acc, Y_test_preds = evaluate(
|
||||
model, X_test, Y_test, num_classes=10, return_predict=True
|
||||
)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.4f}')
|
||||
print(f"reducing lr to {lr:.4f}")
|
||||
if acc > 0.998:
|
||||
wrong = 0
|
||||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong += 1
|
||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
||||
a, b, c, x = (
|
||||
X_test[k, :2],
|
||||
X_test[k, 2:4],
|
||||
Y_test[k, -3:],
|
||||
Y_test_preds[k, -3:],
|
||||
)
|
||||
print(
|
||||
f"{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})"
|
||||
)
|
||||
print(f"Wrong predictions: {wrong}, acc = {acc:.4f}")
|
||||
|
|
|
@ -12,6 +12,7 @@ from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
|
|||
# amount of context erased by model
|
||||
CONTEXT = 7
|
||||
|
||||
|
||||
def get_sample_count(samples_dir):
|
||||
try:
|
||||
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
|
||||
|
@ -21,18 +22,24 @@ def get_sample_count(samples_dir):
|
|||
except:
|
||||
return 0
|
||||
|
||||
|
||||
def set_sample_count(samples_dir, sc):
|
||||
with open(samples_dir + "/sample_count.txt", "w") as file:
|
||||
file.write(str(sc) + "\n")
|
||||
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("python3 -m examples.vgg7 import MODELJSON MODEL")
|
||||
print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
|
||||
print(
|
||||
" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json"
|
||||
)
|
||||
print(" into a safetensors file")
|
||||
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
|
||||
print(" *this format is used by most other commands in this program*")
|
||||
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
|
||||
print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
|
||||
print(
|
||||
" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors"
|
||||
)
|
||||
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
|
||||
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
|
||||
print(" output image has 7 pixels removed on all edges")
|
||||
|
@ -53,7 +60,9 @@ if len(sys.argv) < 2:
|
|||
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
|
||||
print(" won't pad or tile, so keep image sizes sane")
|
||||
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
|
||||
print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
|
||||
print(
|
||||
" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training"
|
||||
)
|
||||
print(" maintains/creates samples_count.txt automatically")
|
||||
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
|
||||
sys.exit(1)
|
||||
|
@ -61,9 +70,13 @@ if len(sys.argv) < 2:
|
|||
cmd = sys.argv[1]
|
||||
vgg7 = Vgg7()
|
||||
|
||||
|
||||
def nansbane(p):
|
||||
if numpy.isnan(numpy.min(p.numpy())):
|
||||
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
|
||||
raise Exception(
|
||||
"A NaN in the model has been detected. This model will not be interacted with to prevent further damage."
|
||||
)
|
||||
|
||||
|
||||
def load_and_save(path, save):
|
||||
if save:
|
||||
|
@ -77,6 +90,7 @@ def load_and_save(path, save):
|
|||
for v in vgg7.get_parameters():
|
||||
nansbane(v)
|
||||
|
||||
|
||||
if cmd == "import":
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
|
@ -158,7 +172,9 @@ elif cmd == "train":
|
|||
|
||||
sample_idx = 0
|
||||
try:
|
||||
sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
|
||||
sample_idx = numpy.random.choice(
|
||||
samples_count, p=sample_probs / sample_probs.sum()
|
||||
)
|
||||
except:
|
||||
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
||||
sample_idx = random.randint(0, samples_count - 1)
|
||||
|
@ -204,7 +220,7 @@ elif cmd == "train":
|
|||
rnum = rnum + 1
|
||||
# Probability management
|
||||
# there must always be a probability, no matter how slim, even if loss goes to 0
|
||||
sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
|
||||
sample_probs[sample_idx] = max(loss_indicator, 1.0e-10)
|
||||
|
||||
# if we were told to save every round, we already saved
|
||||
if rounds_per_save != 1:
|
||||
|
@ -237,8 +253,12 @@ elif cmd == "samplify":
|
|||
samples_added = 0
|
||||
|
||||
# actual patch extraction
|
||||
for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
|
||||
for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
|
||||
for posy in range(
|
||||
CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size
|
||||
):
|
||||
for posx in range(
|
||||
CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size
|
||||
):
|
||||
# this is a viable patch location, add it
|
||||
# note the ranges here:
|
||||
# + there are always CONTEXT pixels *before* the point
|
||||
|
@ -247,7 +267,12 @@ elif cmd == "samplify":
|
|||
# + additionally, there are sample_size - 1 additional sample pixels
|
||||
# + additionally, there are CONTEXT additional pixels
|
||||
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
|
||||
patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
|
||||
patch_x = a_img[
|
||||
:,
|
||||
:,
|
||||
posy - CONTEXT : posy + CONTEXT + sample_size,
|
||||
posx - CONTEXT : posx + CONTEXT + sample_size,
|
||||
]
|
||||
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
|
||||
|
||||
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
|
||||
|
|
|
@ -11,6 +11,7 @@ from tinygrad.helpers import fetch
|
|||
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
|
||||
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
|
||||
|
||||
|
||||
def image_load(path) -> numpy.ndarray:
|
||||
"""
|
||||
Loads an image in the shape expected by other functions in this module.
|
||||
|
@ -29,6 +30,7 @@ def image_load(path) -> numpy.ndarray:
|
|||
na = na.astype("float32") / 255.0
|
||||
return na
|
||||
|
||||
|
||||
def image_save(path, na: numpy.ndarray):
|
||||
"""
|
||||
Saves an image of the shape expected by other functions in this module.
|
||||
|
@ -44,12 +46,15 @@ def image_save(path, na: numpy.ndarray):
|
|||
# file
|
||||
Image.fromarray(na).save(path)
|
||||
|
||||
|
||||
# The Model
|
||||
|
||||
|
||||
class Conv3x3Biased:
|
||||
"""
|
||||
A 3x3 convolution layer with some utility functions.
|
||||
"""
|
||||
|
||||
def __init__(self, inC, outC, last=False):
|
||||
# The properties must be named as "W" and "b".
|
||||
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
|
||||
|
@ -80,9 +85,12 @@ class Conv3x3Biased:
|
|||
# Not outChannel,inChannel,Y,X.
|
||||
# Therefore, transpose it before assignment.
|
||||
# I have long since forgotten how I worked this out.
|
||||
self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3))
|
||||
self.W.assign(
|
||||
Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3)
|
||||
)
|
||||
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
|
||||
|
||||
|
||||
class Vgg7:
|
||||
"""
|
||||
The 'vgg7' waifu2x network.
|
||||
|
@ -115,14 +123,31 @@ class Vgg7:
|
|||
return x
|
||||
|
||||
def get_parameters(self) -> list:
|
||||
return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
|
||||
return (
|
||||
self.conv1.get_parameters()
|
||||
+ self.conv2.get_parameters()
|
||||
+ self.conv3.get_parameters()
|
||||
+ self.conv4.get_parameters()
|
||||
+ self.conv5.get_parameters()
|
||||
+ self.conv6.get_parameters()
|
||||
+ self.conv7.get_parameters()
|
||||
)
|
||||
|
||||
def load_from_pretrained(self, intent="art", subtype="scale2.0x"):
|
||||
"""
|
||||
Downloads a nagadomi/waifu2x JSON weight file and loads it.
|
||||
"""
|
||||
import json
|
||||
data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes())
|
||||
|
||||
data = json.loads(
|
||||
fetch(
|
||||
"https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/"
|
||||
+ intent
|
||||
+ "/"
|
||||
+ subtype
|
||||
+ "_model.json"
|
||||
).read_bytes()
|
||||
)
|
||||
self.load_waifu2x_json(data)
|
||||
|
||||
def load_waifu2x_json(self, data: list):
|
||||
|
@ -157,7 +182,9 @@ class Vgg7:
|
|||
|
||||
# Padding next. Note that this padding is done on the whole image.
|
||||
# Padding the tiles would lose critical context, cause seams, etc.
|
||||
image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
|
||||
image = numpy.pad(
|
||||
image, [[0, 0], [0, 0], [context, context], [context, context]], mode="edge"
|
||||
)
|
||||
|
||||
# Now for tiling.
|
||||
# The output tile size is the usable output from an input tile (tile_size).
|
||||
|
@ -187,7 +214,8 @@ class Vgg7:
|
|||
tile_t = Tensor(tile)
|
||||
tile_fwd_t = self.forward(tile_t)
|
||||
# Replace tile.
|
||||
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
|
||||
image_out[
|
||||
:, :, out_y : out_y + out_h, out_x : out_x + out_w
|
||||
] = tile_fwd_t.numpy()
|
||||
|
||||
return image_out
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from PIL import Image
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv, fetch
|
||||
from extra.models.vit import ViT
|
||||
|
||||
"""
|
||||
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
|
||||
import tensorflow as tf
|
||||
|
@ -22,7 +23,11 @@ else:
|
|||
m.load_from_pretrained()
|
||||
|
||||
# category labels
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
|
||||
# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
|
||||
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
|
||||
|
@ -30,7 +35,9 @@ url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-1
|
|||
# junk
|
||||
img = Image.open(fetch(url))
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
||||
img = img.resize(
|
||||
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||
)
|
||||
img = np.array(img)
|
||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||
img = img[y0 : y0 + 224, x0 : x0 + 224]
|
||||
|
|
1857
examples/vits.py
1857
examples/vits.py
File diff suppressed because it is too large
Load Diff
|
@ -1,7 +1,13 @@
|
|||
import os
|
||||
from extra.export_model import compile_net, jit_model
|
||||
from examples.stable_diffusion import StableDiffusion
|
||||
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
|
||||
from tinygrad.nn.state import (
|
||||
get_state_dict,
|
||||
safe_save,
|
||||
safe_load_metadata,
|
||||
torch_load,
|
||||
load_state_dict,
|
||||
)
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import fetch
|
||||
|
@ -10,10 +16,13 @@ from pathlib import Path
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_f32_to_f16(input_file, output_file):
|
||||
with open(input_file, 'rb') as f:
|
||||
with open(input_file, "rb") as f:
|
||||
metadata_length_bytes = f.read(8)
|
||||
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
|
||||
metadata_length = int.from_bytes(
|
||||
metadata_length_bytes, byteorder="little", signed=False
|
||||
)
|
||||
metadata_json_bytes = f.read(metadata_length)
|
||||
float32_values = np.fromfile(f, dtype=np.float32)
|
||||
|
||||
|
@ -22,12 +31,13 @@ def convert_f32_to_f16(input_file, output_file):
|
|||
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
||||
rest_float32_values = float32_values[num_elements:]
|
||||
|
||||
with open(output_file, 'wb') as f:
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(metadata_length_bytes)
|
||||
f.write(metadata_json_bytes)
|
||||
front_float16_values.tofile(f)
|
||||
rest_float32_values.tofile(f)
|
||||
|
||||
|
||||
def split_safetensor(fn):
|
||||
_, json_len, metadata = safe_load_metadata(fn)
|
||||
text_model_offset = 3772703308
|
||||
|
@ -35,7 +45,7 @@ def split_safetensor(fn):
|
|||
|
||||
for k in metadata:
|
||||
# safetensor is in fp16, except for text moel
|
||||
if (metadata[k]["data_offsets"][0] < text_model_offset):
|
||||
if metadata[k]["data_offsets"][0] < text_model_offset:
|
||||
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2)
|
||||
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2)
|
||||
|
||||
|
@ -43,35 +53,43 @@ def split_safetensor(fn):
|
|||
part_end_offsets = []
|
||||
|
||||
for k in metadata:
|
||||
offset = metadata[k]['data_offsets'][0]
|
||||
offset = metadata[k]["data_offsets"][0]
|
||||
|
||||
if offset == text_model_offset:
|
||||
break
|
||||
|
||||
part_offset = offset - last_offset
|
||||
|
||||
if (part_offset >= chunk_size):
|
||||
if part_offset >= chunk_size:
|
||||
part_end_offsets.append(8 + json_len + offset)
|
||||
last_offset = offset
|
||||
|
||||
text_model_start = int(text_model_offset / 2)
|
||||
net_bytes = bytes(open(fn, 'rb').read())
|
||||
net_bytes = bytes(open(fn, "rb").read())
|
||||
part_end_offsets.append(text_model_start + 8 + json_len)
|
||||
cur_pos = 0
|
||||
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
with open(f'./net_part{i}.safetensors', "wb+") as f:
|
||||
with open(f"./net_part{i}.safetensors", "wb+") as f:
|
||||
f.write(net_bytes[cur_pos:end_pos])
|
||||
cur_pos = end_pos
|
||||
|
||||
with open(f'./net_textmodel.safetensors', "wb+") as f:
|
||||
with open(f"./net_textmodel.safetensors", "wb+") as f:
|
||||
f.write(net_bytes[text_model_start + 8 + json_len :])
|
||||
|
||||
return part_end_offsets
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Stable Diffusion",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remoteweights",
|
||||
action="store_true",
|
||||
help="Use safetensors from Huggingface, or from local",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
Device.DEFAULT = "WEBGPU"
|
||||
|
||||
|
@ -79,7 +97,16 @@ if __name__ == "__main__":
|
|||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
load_state_dict(
|
||||
model,
|
||||
torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
|
||||
"sd-v1-4.ckpt",
|
||||
)
|
||||
)["state_dict"],
|
||||
strict=False,
|
||||
)
|
||||
|
||||
class Step(NamedTuple):
|
||||
name: str = ""
|
||||
|
@ -87,9 +114,25 @@ if __name__ == "__main__":
|
|||
forward: Any = None
|
||||
|
||||
sub_steps = [
|
||||
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
|
||||
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
|
||||
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
|
||||
Step(
|
||||
name="textModel",
|
||||
input=[Tensor.randn(1, 77)],
|
||||
forward=model.cond_stage_model.transformer.text_model,
|
||||
),
|
||||
Step(
|
||||
name="diffusor",
|
||||
input=[
|
||||
Tensor.randn(1, 77, 768),
|
||||
Tensor.randn(1, 77, 768),
|
||||
Tensor.randn(1, 4, 64, 64),
|
||||
Tensor.rand(1),
|
||||
Tensor.randn(1),
|
||||
Tensor.randn(1),
|
||||
Tensor.randn(1),
|
||||
],
|
||||
forward=model,
|
||||
),
|
||||
Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode),
|
||||
]
|
||||
|
||||
prg = ""
|
||||
|
@ -99,12 +142,47 @@ if __name__ == "__main__":
|
|||
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
|
||||
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
|
||||
kernel_code = "\n\n".join(
|
||||
[
|
||||
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||
for key, code in functions.items()
|
||||
]
|
||||
)
|
||||
kernel_names = ", ".join([name for (name, _, _, _) in statements])
|
||||
kernel_calls = "\n ".join(
|
||||
[
|
||||
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||
]
|
||||
)
|
||||
bufs = "\n ".join(
|
||||
[
|
||||
f"const {name} = "
|
||||
+ (
|
||||
f"createEmptyBuf(device, {size});"
|
||||
if _key not in weights
|
||||
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))"
|
||||
)
|
||||
+ ";"
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
]
|
||||
)
|
||||
gpu_write_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||
for i, (_, value) in enumerate(special_names.items())
|
||||
if "output" not in value
|
||||
]
|
||||
)
|
||||
input_writer = "\n ".join(
|
||||
[
|
||||
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||
+ f"data{i});"
|
||||
+ f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);"
|
||||
for i, (_, value) in enumerate(special_names.items())
|
||||
if value != "output0"
|
||||
]
|
||||
)
|
||||
return f"""\n var {step.name} = function() {{
|
||||
|
||||
{kernel_code}
|
||||
|
@ -143,7 +221,7 @@ if __name__ == "__main__":
|
|||
"""
|
||||
|
||||
for step in sub_steps:
|
||||
print(f'Executing step={step.name}')
|
||||
print(f"Executing step={step.name}")
|
||||
prg += compile_step(model, step)
|
||||
|
||||
if step.name == "diffusor":
|
||||
|
@ -151,7 +229,9 @@ if __name__ == "__main__":
|
|||
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
||||
else:
|
||||
state = get_state_dict(model)
|
||||
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
safe_save(
|
||||
state, os.path.join(os.path.dirname(__file__), "net.safetensors")
|
||||
)
|
||||
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
||||
split_safetensor("./net_conv.safetensors")
|
||||
os.remove("net.safetensors")
|
||||
|
|
|
@ -15,8 +15,15 @@ from tinygrad.tensor import Tensor
|
|||
import itertools
|
||||
import librosa
|
||||
|
||||
|
||||
class MultiHeadAttention:
|
||||
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
|
||||
def __init__(
|
||||
self,
|
||||
n_state,
|
||||
n_head,
|
||||
kv_caching: Literal["cross", "self"] = None,
|
||||
max_self_attn_cache_len=None,
|
||||
):
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
|
@ -26,11 +33,17 @@ class MultiHeadAttention:
|
|||
self.kv_caching = kv_caching
|
||||
self.max_self_attn_cache_len = max_self_attn_cache_len
|
||||
|
||||
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
|
||||
if self.kv_caching == 'cross':
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
len: Union[Variable, int] = None,
|
||||
):
|
||||
if self.kv_caching == "cross":
|
||||
if xa is not None:
|
||||
k, v = self.key(xa), self.value(xa)
|
||||
if not hasattr(self, 'cache_k'):
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = k, v
|
||||
else:
|
||||
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
||||
|
@ -40,50 +53,84 @@ class MultiHeadAttention:
|
|||
k, v = self.cache_k, self.cache_v
|
||||
else:
|
||||
k, v = self.key(x), self.value(x)
|
||||
if self.kv_caching == 'self':
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
if self.kv_caching == "self":
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k = Tensor.zeros(
|
||||
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||
)
|
||||
self.cache_v = Tensor.zeros(
|
||||
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||
)
|
||||
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
||||
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
||||
padding = self.max_self_attn_cache_len - len - x.shape[1]
|
||||
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
self.cache_k.assign(
|
||||
k.pad((None, (0, padding), None)).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
v.pad((None, (0, padding), None)).contiguous()
|
||||
).realize()
|
||||
|
||||
q = self.query(x)
|
||||
n_ctx = q.shape[1]
|
||||
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
||||
head_dim = q.shape[-1] // self.n_head
|
||||
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
|
||||
attn = Tensor.scaled_dot_product_attention(
|
||||
q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
|
||||
)
|
||||
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
return self.out(wv)
|
||||
|
||||
|
||||
class ResidualAttentionBlock:
|
||||
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
|
||||
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
|
||||
def __init__(
|
||||
self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None
|
||||
):
|
||||
self.attn = MultiHeadAttention(
|
||||
n_state,
|
||||
n_head,
|
||||
kv_caching="self" if is_decoder_block else None,
|
||||
max_self_attn_cache_len=max_self_attn_cache_len,
|
||||
)
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head, kv_caching="cross")
|
||||
if is_decoder_block
|
||||
else None
|
||||
)
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
||||
|
||||
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
|
||||
self.mlp = [
|
||||
nn.Linear(n_state, n_state * 4),
|
||||
Tensor.gelu,
|
||||
nn.Linear(n_state * 4, n_state),
|
||||
]
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
||||
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||
x = x + self.mlp_ln(x).sequential(self.mlp)
|
||||
return x.realize()
|
||||
|
||||
|
||||
class AudioEncoder:
|
||||
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
|
||||
def __init__(
|
||||
self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_
|
||||
):
|
||||
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
|
||||
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
|
||||
self.conv2 = nn.Conv1d(
|
||||
n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.blocks = [
|
||||
ResidualAttentionBlock(n_audio_state, n_audio_head)
|
||||
for _ in range(n_audio_layer)
|
||||
]
|
||||
self.ln_post = nn.LayerNorm(n_audio_state)
|
||||
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
||||
self.encode = TinyJit(self.__call__)
|
||||
|
@ -97,14 +144,27 @@ class AudioEncoder:
|
|||
x = self.ln_post(x)
|
||||
return x.realize()
|
||||
|
||||
|
||||
class TextDecoder:
|
||||
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
|
||||
def __init__(
|
||||
self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_
|
||||
):
|
||||
self.max_tokens_to_sample = n_text_ctx // 2
|
||||
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
|
||||
self.max_self_attn_cache_len = (
|
||||
self.max_tokens_to_sample * 2 + 5
|
||||
) # roughly prompt + start toks + max_tokens_to_sample
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
|
||||
self.blocks = [
|
||||
ResidualAttentionBlock(
|
||||
n_text_state,
|
||||
n_text_head,
|
||||
is_decoder_block=True,
|
||||
max_self_attn_cache_len=self.max_self_attn_cache_len,
|
||||
)
|
||||
for _ in range(n_text_layer)
|
||||
]
|
||||
self.ln = nn.LayerNorm(n_text_state)
|
||||
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
||||
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||
|
@ -117,18 +177,23 @@ class TextDecoder:
|
|||
seqlen = x.shape[-1]
|
||||
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
|
||||
if pos == 0:
|
||||
for block in (self.blocks if streaming else self.blocks_start_tok):
|
||||
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
|
||||
for block in self.blocks if streaming else self.blocks_start_tok:
|
||||
x = block(
|
||||
x, xa=encoded_audio, mask=self.mask, len=0
|
||||
) # pass xa for cross attn kv caching
|
||||
return self.output_tok(x) if streaming else self.start_output_tok(x)
|
||||
else:
|
||||
for block in self.blocks_after_start_tok:
|
||||
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
|
||||
len_v = Variable(
|
||||
"self_attn_cache_len", 1, self.max_self_attn_cache_len
|
||||
).bind(pos)
|
||||
x = block(x, mask=self.mask, len=len_v)
|
||||
return self.after_start_output_tok(x)
|
||||
|
||||
def output_tok(self, x):
|
||||
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
||||
|
||||
|
||||
class Whisper:
|
||||
def __init__(self, dims, batch_size=1):
|
||||
self.encoder = AudioEncoder(**dims)
|
||||
|
@ -145,7 +210,10 @@ HOP_LENGTH = 160
|
|||
N_MELS = 80
|
||||
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
|
||||
|
||||
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
|
||||
|
||||
def prep_audio(
|
||||
waveforms: List[np.ndarray], batch_size: int, truncate=False
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
:param waveforms: A list of possibly variable length 16000Hz audio samples
|
||||
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
|
||||
|
@ -153,24 +221,30 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
|
|||
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
|
||||
:return: mel spectrogram of the given waveforms
|
||||
"""
|
||||
|
||||
def pad_or_trim(arr, target_len):
|
||||
curr_len = len(arr)
|
||||
if curr_len == target_len:
|
||||
return arr
|
||||
elif curr_len < target_len:
|
||||
return np.pad(arr, (0, target_len - curr_len), 'constant')
|
||||
return np.pad(arr, (0, target_len - curr_len), "constant")
|
||||
else:
|
||||
return arr[:target_len]
|
||||
|
||||
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
|
||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
|
||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
|
||||
max_len += SAMPLES_PER_SEGMENT - r
|
||||
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
||||
assert waveforms.shape[0] <= batch_size
|
||||
if waveforms.shape[0] < batch_size:
|
||||
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
|
||||
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
|
||||
waveforms = np.pad(
|
||||
waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))
|
||||
)
|
||||
|
||||
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
|
||||
stft = librosa.stft(
|
||||
waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
|
||||
)
|
||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
|
||||
|
@ -180,22 +254,118 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
|
|||
|
||||
return log_spec
|
||||
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
|
||||
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
|
||||
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
|
||||
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
|
||||
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
|
||||
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
|
||||
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
|
||||
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
|
||||
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
|
||||
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
|
||||
def get_encoding(encoding_name):
|
||||
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
|
||||
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
|
||||
with fetch(
|
||||
f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
|
||||
).open() as f:
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in f if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
|
@ -212,12 +382,15 @@ def get_encoding(encoding_name):
|
|||
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
||||
n_vocab += len(specials)
|
||||
import tiktoken
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=encoding_name,
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens)
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
MODEL_URLS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
|
@ -232,23 +405,28 @@ MODEL_URLS = {
|
|||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
|
||||
|
||||
def init_whisper(model_name="tiny.en", batch_size=1):
|
||||
assert MODEL_URLS[model_name] is not None
|
||||
|
||||
filename = fetch(MODEL_URLS[model_name])
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state['dims'], batch_size)
|
||||
load_state_dict(model, state['model_state_dict'], strict=False)
|
||||
model = Whisper(state["dims"], batch_size)
|
||||
load_state_dict(model, state["model_state_dict"], strict=False)
|
||||
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
||||
return model, enc
|
||||
|
||||
|
||||
def load_file_waveform(filename):
|
||||
waveform, _ = librosa.load(filename, sr=RATE)
|
||||
return waveform
|
||||
|
||||
|
||||
def transcribe_file(model, enc, filename):
|
||||
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
|
||||
|
||||
|
||||
def transcribe_waveform(model, enc, waveforms, truncate=False):
|
||||
"""
|
||||
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
|
||||
|
@ -260,12 +438,18 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
|
|||
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
|
||||
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
|
||||
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
|
||||
raise Exception("Multi-segment transcription not supported with batch audio input")
|
||||
raise Exception(
|
||||
"Multi-segment transcription not supported with batch audio input"
|
||||
)
|
||||
|
||||
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
if model.is_multilingual:
|
||||
# TODO detect language
|
||||
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
|
||||
language_token = (
|
||||
enc._special_tokens["<|startoftranscript|>"]
|
||||
+ 1
|
||||
+ tuple(LANGUAGES.keys()).index("en")
|
||||
)
|
||||
start_tokens.append(language_token)
|
||||
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
||||
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
||||
|
@ -274,52 +458,83 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
|
|||
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
|
||||
|
||||
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
|
||||
encoded_audio = model.encoder.encode(
|
||||
Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
|
||||
)
|
||||
pos = 0
|
||||
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
|
||||
if curr_frame > 0:
|
||||
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt = np.concatenate((
|
||||
prompt = np.concatenate(
|
||||
(
|
||||
[enc._special_tokens["<|startofprev|>"]],
|
||||
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
|
||||
start_tokens))
|
||||
start_tokens,
|
||||
)
|
||||
)
|
||||
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
|
||||
transcription_start_index = len(curr_segment_tokens[0])
|
||||
|
||||
for i in range(model.decoder.max_tokens_to_sample):
|
||||
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
|
||||
out = model.decoder(
|
||||
Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
|
||||
pos,
|
||||
encoded_audio,
|
||||
streaming=curr_frame > 0,
|
||||
)
|
||||
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
|
||||
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
|
||||
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
|
||||
curr_segment_tokens = np.concatenate(
|
||||
(curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
|
||||
)
|
||||
pos = curr_segment_tokens.shape[-1] - 1
|
||||
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
|
||||
if DEBUG >= 1:
|
||||
print(
|
||||
i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))
|
||||
)
|
||||
if (curr_segment_tokens[:, -1] == eot).all():
|
||||
break
|
||||
|
||||
for i, t in enumerate(curr_segment_tokens):
|
||||
eot_index = np.where(t == eot)[0]
|
||||
eot_index = None if len(eot_index) == 0 else eot_index[0]
|
||||
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
|
||||
transcription_tokens[i] = np.concatenate(
|
||||
(transcription_tokens[i], t[transcription_start_index:eot_index])
|
||||
)
|
||||
|
||||
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
|
||||
transcriptions = list(
|
||||
map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)
|
||||
)
|
||||
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
|
||||
|
||||
|
||||
CHUNK = 1600
|
||||
RECORD_SECONDS = 10
|
||||
|
||||
|
||||
def listener(q):
|
||||
import pyaudio
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK,
|
||||
)
|
||||
print("listening")
|
||||
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
data = stream.read(CHUNK)
|
||||
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
|
||||
waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
|
||||
q.put(waveform)
|
||||
print("done listening")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
|
||||
model, enc = init_whisper(
|
||||
"small.en" if getenv("SMALL") else "tiny.en", batch_size=1
|
||||
)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
print(transcribe_file(model, enc, sys.argv[1]))
|
||||
|
@ -330,20 +545,29 @@ if __name__ == "__main__":
|
|||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
lst = [
|
||||
enc._special_tokens["<|startoftranscript|>"],
|
||||
enc._special_tokens["<|notimestamps|>"],
|
||||
]
|
||||
total = None
|
||||
did_read = False
|
||||
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
while not q.empty() or total is None:
|
||||
waveform = q.get()
|
||||
if total is None: total = waveform
|
||||
else: total = np.concatenate([total, waveform])
|
||||
if total is None:
|
||||
total = waveform
|
||||
else:
|
||||
total = np.concatenate([total, waveform])
|
||||
did_read = True
|
||||
if did_read:
|
||||
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
|
||||
log_spec = prep_audio(
|
||||
total.reshape(1, -1), model.batch_size, truncate=True
|
||||
)
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec))
|
||||
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
|
||||
out = model.decoder(
|
||||
Tensor([lst]), 0, encoded_audio, streaming=True
|
||||
).realize()
|
||||
idx = int(out[0, -1].argmax().numpy().item())
|
||||
lst.append(idx)
|
||||
dec = enc.decode(lst)
|
||||
|
|
|
@ -10,11 +10,14 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import BatchNorm2d, Conv2d
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
|
||||
def show_labels(prediction, confidence=0.5, num_classes=80):
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes()
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
coco_labels = fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
).read_bytes()
|
||||
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
conf_mask = prediction[:, :, 4] > confidence
|
||||
prediction *= np.expand_dims(conf_mask, 2)
|
||||
labels = []
|
||||
# Iterate over batches
|
||||
|
@ -30,16 +33,22 @@ def show_labels(prediction, confidence=0.5, num_classes=80):
|
|||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
|
||||
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
|
||||
for index, coco_class in enumerate(classes):
|
||||
label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100
|
||||
label, probability = (
|
||||
coco_labels[int(coco_class)],
|
||||
image_pred_[indexes[index]][4] * 100,
|
||||
)
|
||||
print(f"Detected {label} {probability:.2f}")
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
||||
|
||||
def add_boxes(img, prediction):
|
||||
if isinstance(prediction, int): # no predictions
|
||||
return img
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
coco_labels = fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
)
|
||||
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||
height, width = img.shape[0:2]
|
||||
scale_factor = 608 / width
|
||||
prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2
|
||||
|
@ -55,9 +64,18 @@ def add_boxes(img, prediction):
|
|||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
|
||||
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
|
||||
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
|
||||
img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
|
||||
img = cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(corner1[0], corner1[1] + t_size[1] + 4),
|
||||
cv2.FONT_HERSHEY_PLAIN,
|
||||
1,
|
||||
[225, 255, 255],
|
||||
1,
|
||||
)
|
||||
return img
|
||||
|
||||
|
||||
def bbox_iou(box1, box2):
|
||||
"""
|
||||
Returns the IoU of two bounding boxes
|
||||
|
@ -74,24 +92,27 @@ def bbox_iou(box1, box2):
|
|||
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
||||
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
||||
# Intersection area
|
||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999)
|
||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(
|
||||
inter_rect_y2 - inter_rect_y1 + 1, 0, 99999
|
||||
)
|
||||
# Union Area
|
||||
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
|
||||
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
|
||||
iou = inter_area / (b1_area + b2_area - inter_area)
|
||||
return iou
|
||||
|
||||
|
||||
def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
conf_mask = prediction[:, :, 4] > confidence
|
||||
conf_mask = np.expand_dims(conf_mask, 2)
|
||||
prediction = prediction * conf_mask
|
||||
# Non max suppression
|
||||
box_corner = prediction
|
||||
box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
|
||||
box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
|
||||
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
|
||||
box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
|
||||
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
|
||||
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
|
||||
prediction[:, :, :4] = box_corner[:, :, :4]
|
||||
write = False
|
||||
# Process img
|
||||
|
@ -121,7 +142,10 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
|||
for i in range(image_pred_class.shape[0]):
|
||||
# Get the IOUs of all boxes that come after the one we are looking at in the loop
|
||||
try:
|
||||
ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:])
|
||||
ious = bbox_iou(
|
||||
np.expand_dims(image_pred_class[i], axis=0),
|
||||
image_pred_class[i + 1 :],
|
||||
)
|
||||
except:
|
||||
break
|
||||
# Zero out all the detections that have IoU > threshold
|
||||
|
@ -139,6 +163,7 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
|||
output = np.concatenate((output, out))
|
||||
return output
|
||||
|
||||
|
||||
def infer(model, img):
|
||||
img = np.array(Image.fromarray(img).resize((608, 608)))
|
||||
img = img[:, :, ::-1].transpose((2, 0, 1))
|
||||
|
@ -149,9 +174,9 @@ def infer(model, img):
|
|||
|
||||
def parse_cfg(cfg):
|
||||
# Return a list of blocks
|
||||
lines = cfg.decode("utf-8").split('\n')
|
||||
lines = cfg.decode("utf-8").split("\n")
|
||||
lines = [x for x in lines if len(x) > 0]
|
||||
lines = [x for x in lines if x[0] != '#']
|
||||
lines = [x for x in lines if x[0] != "#"]
|
||||
lines = [x.rstrip().lstrip() for x in lines]
|
||||
block, blocks = {}, []
|
||||
for line in lines:
|
||||
|
@ -166,6 +191,7 @@ def parse_cfg(cfg):
|
|||
blocks.append(block)
|
||||
return blocks
|
||||
|
||||
|
||||
# TODO: Speed up this function, avoid copying stuff from GPU to CPU
|
||||
def predict_transform(prediction, inp_dim, anchors, num_classes):
|
||||
batch_size = prediction.shape[0]
|
||||
|
@ -173,9 +199,13 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
|
|||
grid_size = inp_dim // stride
|
||||
bbox_attrs = 5 + num_classes
|
||||
num_anchors = len(anchors)
|
||||
prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size))
|
||||
prediction = prediction.reshape(
|
||||
shape=(batch_size, bbox_attrs * num_anchors, grid_size * grid_size)
|
||||
)
|
||||
prediction = prediction.transpose(1, 2)
|
||||
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
|
||||
prediction = prediction.reshape(
|
||||
shape=(batch_size, grid_size * grid_size * num_anchors, bbox_attrs)
|
||||
)
|
||||
prediction_cpu = prediction.numpy()
|
||||
for i in (0, 1, 4):
|
||||
prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i]))
|
||||
|
@ -193,7 +223,9 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
|
|||
anchors = np.expand_dims(anchors, 0)
|
||||
prediction_cpu[:, :, :2] += x_y_offset
|
||||
prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors
|
||||
prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes]))
|
||||
prediction_cpu[:, :, 5 : 5 + num_classes] = 1 / (
|
||||
1 + np.exp(-prediction_cpu[:, :, 5 : 5 + num_classes])
|
||||
)
|
||||
prediction_cpu[:, :, :4] *= stride
|
||||
return Tensor(prediction_cpu)
|
||||
|
||||
|
@ -222,18 +254,33 @@ class Darknet:
|
|||
filters = int(x["filters"])
|
||||
padding = int(x["pad"])
|
||||
pad = (int(x["size"]) - 1) // 2 if padding else 0
|
||||
module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias))
|
||||
module.append(
|
||||
Conv2d(
|
||||
prev_filters,
|
||||
filters,
|
||||
int(x["size"]),
|
||||
int(x["stride"]),
|
||||
pad,
|
||||
bias=bias,
|
||||
)
|
||||
)
|
||||
# BatchNorm2d
|
||||
if batch_normalize:
|
||||
module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
|
||||
module.append(
|
||||
BatchNorm2d(filters, eps=1e-05, track_running_stats=True)
|
||||
)
|
||||
# LeakyReLU activation
|
||||
if activation == "leaky":
|
||||
module.append(lambda x: x.leakyrelu(0.1))
|
||||
elif module_type == "maxpool":
|
||||
size, stride = int(x["size"]), int(x["stride"])
|
||||
module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))
|
||||
module.append(
|
||||
lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)
|
||||
)
|
||||
elif module_type == "upsample":
|
||||
module.append(lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1)))
|
||||
module.append(
|
||||
lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1))
|
||||
)
|
||||
elif module_type == "route":
|
||||
x["layers"] = x["layers"].split(",")
|
||||
# Start of route
|
||||
|
@ -243,11 +290,15 @@ class Darknet:
|
|||
end = int(x["layers"][1])
|
||||
except:
|
||||
end = 0
|
||||
if start > 0: start -= index
|
||||
if end > 0: end -= index
|
||||
if start > 0:
|
||||
start -= index
|
||||
if end > 0:
|
||||
end -= index
|
||||
module.append(lambda x: x)
|
||||
if end < 0:
|
||||
filters = output_filters[index + start] + output_filters[index + end]
|
||||
filters = (
|
||||
output_filters[index + start] + output_filters[index + end]
|
||||
)
|
||||
else:
|
||||
filters = output_filters[index + start]
|
||||
# Shortcut corresponds to skip connection
|
||||
|
@ -256,7 +307,9 @@ class Darknet:
|
|||
elif module_type == "yolo":
|
||||
mask = list(map(int, x["mask"].split(",")))
|
||||
anchors = [int(a) for a in x["anchors"].split(",")]
|
||||
anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
|
||||
anchors = [
|
||||
(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)
|
||||
]
|
||||
module.append([anchors[i] for i in mask])
|
||||
# Append to module_list
|
||||
module_list.append(module)
|
||||
|
@ -308,8 +361,12 @@ class Darknet:
|
|||
# Cast the loaded weights into dims of model weights
|
||||
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
||||
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
|
||||
bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape))
|
||||
bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape))
|
||||
bn_running_mean = bn_running_mean.reshape(
|
||||
shape=tuple(bn.running_mean.shape)
|
||||
)
|
||||
bn_running_var = bn_running_var.reshape(
|
||||
shape=tuple(bn.running_var.shape)
|
||||
)
|
||||
# Copy data
|
||||
bn.bias = bn_biases
|
||||
bn.weight = bn_weights
|
||||
|
@ -337,7 +394,7 @@ class Darknet:
|
|||
outputs = {} # Cached outputs for route layer
|
||||
detections, write = None, False
|
||||
for i, module in enumerate(modules):
|
||||
module_type = (module["type"])
|
||||
module_type = module["type"]
|
||||
if module_type == "convolutional" or module_type == "upsample":
|
||||
for layer in self.module_list[i]:
|
||||
x = layer(x)
|
||||
|
@ -349,7 +406,8 @@ class Darknet:
|
|||
if len(layers) == 1:
|
||||
x = outputs[i + (layers[0])]
|
||||
else:
|
||||
if (layers[1]) > 0: layers[1] = layers[1] - i
|
||||
if (layers[1]) > 0:
|
||||
layers[1] = layers[1] - i
|
||||
map1 = outputs[i + layers[0]]
|
||||
map2 = outputs[i + layers[1]]
|
||||
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
|
||||
|
@ -364,19 +422,26 @@ class Darknet:
|
|||
if not write:
|
||||
detections, write = x, True
|
||||
else:
|
||||
detections = Tensor(np.concatenate((detections.numpy(), x.numpy()), axis=1))
|
||||
detections = Tensor(
|
||||
np.concatenate((detections.numpy(), x.numpy()), axis=1)
|
||||
)
|
||||
outputs[i] = x
|
||||
return detections
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg'))
|
||||
model = Darknet(
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg"
|
||||
)
|
||||
)
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
|
||||
model.load_weights("https://pjreddie.com/media/files/yolov3.weights")
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||
if url == 'webcam':
|
||||
if url == "webcam":
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
|
@ -386,21 +451,21 @@ if __name__ == "__main__":
|
|||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
|
||||
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow('yolo', boxes)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
cv2.imshow("yolo", boxes)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
elif url.startswith('http'):
|
||||
elif url.startswith("http"):
|
||||
img_stream = io.BytesIO(fetch(url))
|
||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||
else:
|
||||
img = cv2.imread(url)
|
||||
st = time.time()
|
||||
print('running inference…')
|
||||
print("running inference…")
|
||||
prediction = infer(model, img)
|
||||
print(f'did inference in {(time.time() - st):2f}s')
|
||||
print(f"did inference in {(time.time() - st):2f}s")
|
||||
show_labels(prediction)
|
||||
prediction = process_results(prediction)
|
||||
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
|
||||
cv2.imwrite('boxes.jpg', boxes)
|
||||
cv2.imwrite("boxes.jpg", boxes)
|
||||
|
|
|
@ -12,7 +12,10 @@ if not Path("yolov8n-seg.onnx").is_file():
|
|||
model.export(format="onnx", imgsz=[480, 640])
|
||||
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
|
||||
# TODO: move get example inputs to onnx
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
input_shapes = {
|
||||
inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
|
||||
for inp in onnx_model.graph.input
|
||||
}
|
||||
print(input_shapes)
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)
|
||||
|
|
|
@ -12,8 +12,11 @@ from tinygrad.nn.state import safe_load, load_state_dict
|
|||
# Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
|
||||
|
||||
|
||||
# Pre processing image functions.
|
||||
def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
|
||||
def compute_transform(
|
||||
image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
|
||||
):
|
||||
shape = image.shape[:2] # current shape [height, width]
|
||||
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
|
@ -24,25 +27,39 @@ def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False,
|
|||
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
|
||||
dw /= 2
|
||||
dh /= 2
|
||||
image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image
|
||||
image = (
|
||||
cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
if shape[::-1] != new_unpad
|
||||
else image
|
||||
)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
image = cv2.copyMakeBorder(
|
||||
image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||
)
|
||||
return image
|
||||
|
||||
|
||||
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
auto = same_shapes and model_pt
|
||||
im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im])
|
||||
im = Tensor(
|
||||
[
|
||||
compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride)
|
||||
for x in im
|
||||
]
|
||||
)
|
||||
im = Tensor.stack(im) if im.shape[0] > 1 else im
|
||||
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return im
|
||||
|
||||
|
||||
# Post Processing functions
|
||||
def box_area(box):
|
||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||
|
||||
|
||||
def box_iou(box1, box2):
|
||||
lt = np.maximum(box1[:, None, :2], box2[:, :2])
|
||||
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
|
||||
|
@ -53,6 +70,7 @@ def box_iou(box1, box2):
|
|||
iou = inter / (area1 + area2 - inter)
|
||||
return iou
|
||||
|
||||
|
||||
def compute_nms(boxes, scores, iou_threshold):
|
||||
order, keep = scores.argsort()[::-1], []
|
||||
while order.size > 0:
|
||||
|
@ -65,7 +83,16 @@ def compute_nms(boxes, scores, iou_threshold):
|
|||
order = order[inds + 1]
|
||||
return np.array(keep)
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
|
||||
|
||||
def non_max_suppression(
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
agnostic=False,
|
||||
max_det=300,
|
||||
nc=0,
|
||||
max_wh=7680,
|
||||
):
|
||||
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
|
||||
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
|
||||
xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres
|
||||
|
@ -74,12 +101,16 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
|
|||
|
||||
for xi, x in enumerate(prediction):
|
||||
x = x.swapaxes(0, -1)[xc[xi]]
|
||||
if not x.shape[0]: continue
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
|
||||
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
|
||||
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(
|
||||
cls, axis=1, keepdims=True
|
||||
)
|
||||
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
|
||||
x = x[conf.ravel() > conf_thres]
|
||||
if not x.shape[0]: continue
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
x = x[np.argsort(-x[:, 4])]
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh)
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
|
@ -87,12 +118,15 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
|
|||
output[xi] = x[i]
|
||||
return output
|
||||
|
||||
|
||||
def postprocess(preds, img, orig_imgs):
|
||||
print('copying to CPU now for post processing')
|
||||
print("copying to CPU now for post processing")
|
||||
# if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
|
||||
# TODO: make non_max_suppression in tinygrad - to make this faster
|
||||
preds = preds.numpy() if isinstance(preds, Tensor) else preds
|
||||
preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
|
||||
preds = non_max_suppression(
|
||||
prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300
|
||||
)
|
||||
all_preds = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
|
@ -101,8 +135,16 @@ def postprocess(preds, img, orig_imgs):
|
|||
all_preds.append(pred)
|
||||
return all_preds
|
||||
|
||||
def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
|
||||
color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
|
||||
|
||||
def draw_bounding_boxes_and_save(
|
||||
orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5
|
||||
):
|
||||
color_dict = {
|
||||
label: tuple(
|
||||
(((i + 1) * 50) % 256, ((i + 1) * 100) % 256, ((i + 1) * 150) % 256)
|
||||
)
|
||||
for i, label in enumerate(class_labels)
|
||||
}
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
|
||||
def is_bright_color(color):
|
||||
|
@ -110,9 +152,15 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
|||
brightness = (r * 299 + g * 587 + b * 114) / 1000
|
||||
return brightness > 127
|
||||
|
||||
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
|
||||
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(
|
||||
zip(orig_img_paths, output_img_paths, all_predictions)
|
||||
):
|
||||
predictions = np.array(predictions)
|
||||
orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
|
||||
orig_img = (
|
||||
cv2.imread(orig_img_path)
|
||||
if not isinstance(orig_img_path, np.ndarray)
|
||||
else cv2.imdecode(orig_img_path, 1)
|
||||
)
|
||||
height, width, _ = orig_img.shape
|
||||
box_thickness = int((height + width) / 400)
|
||||
font_scale = (height + width) / 2500
|
||||
|
@ -129,10 +177,29 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
|||
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
|
||||
label = f"{class_labels[class_id]} {conf:.2f}"
|
||||
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
|
||||
label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
|
||||
cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
|
||||
label_y, bg_y = (
|
||||
(y1 - 4, y1 - text_size[1] - 4)
|
||||
if y1 - text_size[1] - 4 > 0
|
||||
else (y1 + text_size[1], y1)
|
||||
)
|
||||
cv2.rectangle(
|
||||
orig_img,
|
||||
(x1, bg_y),
|
||||
(x1 + text_size[0], bg_y + text_size[1]),
|
||||
color,
|
||||
-1,
|
||||
)
|
||||
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
|
||||
cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
|
||||
cv2.putText(
|
||||
orig_img,
|
||||
label,
|
||||
(x1, label_y),
|
||||
font,
|
||||
font_scale,
|
||||
font_color,
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
for class_id, pred_list in grouped_preds.items():
|
||||
pred_list = np.array(pred_list)
|
||||
|
@ -155,7 +222,8 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
|||
print(f"- {obj}: {count}")
|
||||
|
||||
cv2.imwrite(output_img_path, orig_img)
|
||||
print(f'saved detections at {output_img_path}')
|
||||
print(f"saved detections at {output_img_path}")
|
||||
|
||||
|
||||
# utility functions for forward pass.
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
|
@ -168,6 +236,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
|||
return c_xy.cat(wh, dim=1)
|
||||
return x1y1.cat(x2y2, dim=1)
|
||||
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
|
@ -183,25 +252,39 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
|||
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
|
||||
stride_tensor.append(Tensor.full((h * w), stride))
|
||||
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
|
||||
stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
||||
stride_tensor = (
|
||||
stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
||||
)
|
||||
return anchor_points, stride_tensor
|
||||
|
||||
|
||||
# this function is from the original implementation
|
||||
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
||||
if d > 1:
|
||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
||||
k = (
|
||||
d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
|
||||
) # actual kernel-size
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
||||
return p
|
||||
|
||||
|
||||
def clip_boxes(boxes, shape):
|
||||
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
|
||||
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
|
||||
return boxes
|
||||
|
||||
|
||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
||||
gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
||||
pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
|
||||
gain = (
|
||||
ratio_pad
|
||||
if ratio_pad
|
||||
else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
||||
)
|
||||
pad = (
|
||||
(img1_shape[1] - img0_shape[1] * gain) / 2,
|
||||
(img1_shape[0] - img0_shape[0] * gain) / 2,
|
||||
)
|
||||
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
|
||||
boxes_np[..., [0, 2]] -= pad[0]
|
||||
boxes_np[..., [1, 3]] -= pad[1]
|
||||
|
@ -209,6 +292,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
|||
boxes_np = clip_boxes(boxes_np, img0_shape)
|
||||
return boxes_np
|
||||
|
||||
|
||||
def xywh2xyxy(x):
|
||||
xy = x[..., :2] # center x, y
|
||||
wh = x[..., 2:4] # width, height
|
||||
|
@ -217,8 +301,16 @@ def xywh2xyxy(x):
|
|||
result = np.concatenate((xy1, xy2), axis=-1)
|
||||
return Tensor(result) if isinstance(x, Tensor) else result
|
||||
|
||||
|
||||
def get_variant_multiples(variant):
|
||||
return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
|
||||
return {
|
||||
"n": (0.33, 0.25, 2.0),
|
||||
"s": (0.33, 0.50, 2.0),
|
||||
"m": (0.67, 0.75, 1.5),
|
||||
"l": (1.0, 1.0, 1.0),
|
||||
"x": (1, 1.25, 1.0),
|
||||
}.get(variant, None)
|
||||
|
||||
|
||||
def label_predictions(all_predictions):
|
||||
class_index_count = defaultdict(int)
|
||||
|
@ -230,6 +322,7 @@ def label_predictions(all_predictions):
|
|||
|
||||
return dict(class_index_count)
|
||||
|
||||
|
||||
# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
|
||||
class Upsample:
|
||||
def __init__(self, scale_factor: int, mode: str = "nearest") -> None:
|
||||
|
@ -240,41 +333,86 @@ class Upsample:
|
|||
def __call__(self, x: Tensor) -> Tensor:
|
||||
assert len(x.shape) > 2 and len(x.shape) <= 5
|
||||
(b, c), _lens = x.shape[:2], len(x.shape[2:])
|
||||
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens)
|
||||
return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
||||
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(
|
||||
*[1, 1, 1] + [self.scale_factor] * _lens
|
||||
)
|
||||
return (
|
||||
tmp.reshape(list(x.shape) + [self.scale_factor] * _lens)
|
||||
.permute(
|
||||
[0, 1]
|
||||
+ list(
|
||||
chain.from_iterable([[y + 2, y + 2 + _lens] for y in range(_lens)])
|
||||
)
|
||||
)
|
||||
.reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
||||
)
|
||||
|
||||
|
||||
class Conv_Block:
|
||||
def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None):
|
||||
self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation)
|
||||
def __init__(
|
||||
self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None
|
||||
):
|
||||
self.conv = Conv2d(
|
||||
c1,
|
||||
c2,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=autopad(kernel_size, padding, dilation),
|
||||
bias=False,
|
||||
groups=groups,
|
||||
dilation=dilation,
|
||||
)
|
||||
self.bn = BatchNorm2d(c2, eps=0.001)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.bn(self.conv(x)).silu()
|
||||
|
||||
|
||||
class Bottleneck:
|
||||
def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5):
|
||||
def __init__(
|
||||
self, c1, c2, shortcut: bool, g=1, kernels: list = (3, 3), channel_factor=0.5
|
||||
):
|
||||
c_ = int(c2 * channel_factor)
|
||||
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
|
||||
self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g)
|
||||
self.cv2 = Conv_Block(
|
||||
c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g
|
||||
)
|
||||
self.residual = c1 == c2 and shortcut
|
||||
|
||||
def __call__(self, x):
|
||||
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class C2f:
|
||||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
|
||||
self.c = int(c2 * e)
|
||||
self.cv1 = Conv_Block(c1, 2 * self.c, 1,)
|
||||
self.cv1 = Conv_Block(
|
||||
c1,
|
||||
2 * self.c,
|
||||
1,
|
||||
)
|
||||
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
|
||||
self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)]
|
||||
self.bottleneck = [
|
||||
Bottleneck(
|
||||
self.c,
|
||||
self.c,
|
||||
shortcut,
|
||||
g,
|
||||
kernels=[(3, 3), (3, 3)],
|
||||
channel_factor=1.0,
|
||||
)
|
||||
for _ in range(n)
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
y = list(self.cv1(x).chunk(2, 1))
|
||||
y.extend(m(y[-1]) for m in self.bottleneck)
|
||||
z = y[0]
|
||||
for i in y[1:]: z = z.cat(i, dim=1)
|
||||
for i in y[1:]:
|
||||
z = z.cat(i, dim=1)
|
||||
return self.cv2(z)
|
||||
|
||||
|
||||
class SPPF:
|
||||
def __init__(self, c1, c2, k=5):
|
||||
c_ = c1 // 2 # hidden channels
|
||||
|
@ -282,7 +420,9 @@ class SPPF:
|
|||
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
|
||||
|
||||
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
|
||||
self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1)
|
||||
self.maxpool = lambda x: x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(
|
||||
kernel_size=k, stride=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.cv1(x)
|
||||
|
@ -291,6 +431,7 @@ class SPPF:
|
|||
x4 = self.maxpool(x3)
|
||||
return self.cv2(x.cat(x2, x3, x4, dim=1))
|
||||
|
||||
|
||||
class DFL:
|
||||
def __init__(self, c1=16):
|
||||
self.conv = Conv2d(c1, 1, 1, bias=False)
|
||||
|
@ -300,15 +441,33 @@ class DFL:
|
|||
|
||||
def __call__(self, x):
|
||||
b, c, a = x.shape # batch, channels, anchors
|
||||
return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a)
|
||||
return self.conv(
|
||||
x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)
|
||||
).reshape(b, 4, a)
|
||||
|
||||
|
||||
# backbone
|
||||
class Darknet:
|
||||
def __init__(self, w, r, d):
|
||||
self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)]
|
||||
self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)]
|
||||
self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)]
|
||||
self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)]
|
||||
self.b1 = [
|
||||
Conv_Block(c1=3, c2=int(64 * w), kernel_size=3, stride=2, padding=1),
|
||||
Conv_Block(int(64 * w), int(128 * w), kernel_size=3, stride=2, padding=1),
|
||||
]
|
||||
self.b2 = [
|
||||
C2f(c1=int(128 * w), c2=int(128 * w), n=round(3 * d), shortcut=True),
|
||||
Conv_Block(int(128 * w), int(256 * w), 3, 2, 1),
|
||||
C2f(int(256 * w), int(256 * w), round(6 * d), True),
|
||||
]
|
||||
self.b3 = [
|
||||
Conv_Block(int(256 * w), int(512 * w), kernel_size=3, stride=2, padding=1),
|
||||
C2f(int(512 * w), int(512 * w), round(6 * d), True),
|
||||
]
|
||||
self.b4 = [
|
||||
Conv_Block(
|
||||
int(512 * w), int(512 * w * r), kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
C2f(int(512 * w * r), int(512 * w * r), round(3 * d), True),
|
||||
]
|
||||
self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)]
|
||||
|
||||
def return_modules(self):
|
||||
|
@ -322,16 +481,28 @@ class Darknet:
|
|||
x5 = x4.sequential(self.b5)
|
||||
return (x2, x3, x5)
|
||||
|
||||
|
||||
# yolo fpn (neck)
|
||||
class Yolov8NECK:
|
||||
def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple
|
||||
self.up = Upsample(2, mode='nearest')
|
||||
self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False)
|
||||
self.up = Upsample(2, mode="nearest")
|
||||
self.n1 = C2f(
|
||||
c1=int(512 * w * (1 + r)), c2=int(512 * w), n=round(3 * d), shortcut=False
|
||||
)
|
||||
self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False)
|
||||
self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1)
|
||||
self.n3 = Conv_Block(
|
||||
c1=int(256 * w), c2=int(256 * w), kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False)
|
||||
self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1)
|
||||
self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False)
|
||||
self.n5 = Conv_Block(
|
||||
c1=int(512 * w), c2=int(512 * w), kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.n6 = C2f(
|
||||
c1=int(512 * w * (1 + r)),
|
||||
c2=int(512 * w * r),
|
||||
n=round(3 * d),
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
def return_modules(self):
|
||||
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
|
||||
|
@ -343,6 +514,7 @@ class Yolov8NECK:
|
|||
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
|
||||
return [head_1, head_2, head_3]
|
||||
|
||||
|
||||
# task specific head.
|
||||
class DetectionHead:
|
||||
def __init__(self, nc=80, filters=()):
|
||||
|
@ -354,25 +526,41 @@ class DetectionHead:
|
|||
c1 = max(filters[0], self.nc)
|
||||
c2 = max((filters[0] // 4, self.ch * 4))
|
||||
self.dfl = DFL(self.ch)
|
||||
self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters]
|
||||
self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters]
|
||||
self.cv3 = [
|
||||
[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)]
|
||||
for x in filters
|
||||
]
|
||||
self.cv2 = [
|
||||
[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)]
|
||||
for x in filters
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for i in range(self.nl):
|
||||
x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1))
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
x[i] = x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1)
|
||||
self.anchors, self.strides = (
|
||||
x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)
|
||||
)
|
||||
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
|
||||
x_cat = y[0].cat(y[1], y[2], dim=2)
|
||||
box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :]
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
dbox = (
|
||||
dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1)
|
||||
* self.strides
|
||||
)
|
||||
z = dbox.cat(cls.sigmoid(), dim=1)
|
||||
return z
|
||||
|
||||
|
||||
class YOLOv8:
|
||||
def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple
|
||||
def __init__(
|
||||
self, w, r, d, num_classes
|
||||
): # width_multiple, ratio_multiple, depth_multiple
|
||||
self.net = Darknet(w, r, d)
|
||||
self.fpn = Yolov8NECK(w, r, d)
|
||||
self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r)))
|
||||
self.head = DetectionHead(
|
||||
num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r))
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.net(x)
|
||||
|
@ -383,27 +571,44 @@ class YOLOv8:
|
|||
backbone_modules = [*range(10)]
|
||||
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
|
||||
yolov8_head_weights = [(22, self.head)]
|
||||
return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
|
||||
return [
|
||||
*zip(backbone_modules, self.net.return_modules()),
|
||||
*zip(yolov8neck_modules, self.fpn.return_modules()),
|
||||
*yolov8_head_weights,
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
|
||||
if len(sys.argv) < 2:
|
||||
print("Error: Image URL or path not provided.")
|
||||
sys.exit(1)
|
||||
|
||||
img_path = sys.argv[1]
|
||||
yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
|
||||
print(f'running inference for YOLO version {yolo_variant}')
|
||||
yolo_variant = (
|
||||
sys.argv[2]
|
||||
if len(sys.argv) >= 3
|
||||
else (
|
||||
print(
|
||||
"No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']"
|
||||
)
|
||||
or "n"
|
||||
)
|
||||
)
|
||||
print(f"running inference for YOLO version {yolo_variant}")
|
||||
|
||||
output_folder_path = Path('./outputs_yolov8')
|
||||
output_folder_path = Path("./outputs_yolov8")
|
||||
output_folder_path.mkdir(parents=True, exist_ok=True)
|
||||
# absolute image path or URL
|
||||
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
|
||||
image = [cv2.imdecode(image_location[0], 1)]
|
||||
out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()]
|
||||
out_paths = [
|
||||
(
|
||||
output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}"
|
||||
).as_posix()
|
||||
]
|
||||
if not isinstance(image[0], np.ndarray):
|
||||
print('Error in image loading. Check your image file.')
|
||||
print("Error in image loading. Check your image file.")
|
||||
sys.exit(1)
|
||||
pre_processed_image = preprocess(image)
|
||||
|
||||
|
@ -411,19 +616,36 @@ if __name__ == '__main__':
|
|||
depth, width, ratio = get_variant_multiples(yolo_variant)
|
||||
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
|
||||
state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
|
||||
state_dict = safe_load(
|
||||
fetch(
|
||||
f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors"
|
||||
)
|
||||
)
|
||||
load_state_dict(yolo_infer, state_dict)
|
||||
|
||||
st = time.time()
|
||||
predictions = yolo_infer(pre_processed_image)
|
||||
print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
|
||||
print(f"did inference in {int(round(((time.time() - st) * 1000)))}ms")
|
||||
|
||||
post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
|
||||
post_predictions = postprocess(
|
||||
preds=predictions, img=pre_processed_image, orig_imgs=image
|
||||
)
|
||||
|
||||
# v8 and v3 have same 80 class names for Object Detection
|
||||
class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
|
||||
class_labels = (
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
)
|
||||
.read_text()
|
||||
.split("\n")
|
||||
)
|
||||
|
||||
draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
|
||||
draw_bounding_boxes_and_save(
|
||||
orig_img_paths=image_location,
|
||||
output_img_paths=out_paths,
|
||||
all_predictions=post_predictions,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# TODO for later:
|
||||
# 1. Fix SPPF minor difference due to maxpool
|
||||
|
|
|
@ -6,9 +6,9 @@ from coremltools.models.neural_network import datatypes, NeuralNetworkBuilder
|
|||
# KxK GEMM with bias
|
||||
K = 64
|
||||
|
||||
input_features = [('image', datatypes.Array(K))]
|
||||
input_features2 = [('image2', datatypes.Array(K))]
|
||||
output_features = [('probs', datatypes.Array(K))]
|
||||
input_features = [("image", datatypes.Array(K))]
|
||||
input_features2 = [("image2", datatypes.Array(K))]
|
||||
output_features = [("probs", datatypes.Array(K))]
|
||||
|
||||
weights = np.zeros((K, K)) + 3
|
||||
bias = np.ones(K)
|
||||
|
@ -17,7 +17,9 @@ builder = NeuralNetworkBuilder(input_features+input_features2, output_features)
|
|||
|
||||
# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
|
||||
# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
|
||||
builder.add_elementwise(name='element', input_names=['image', 'image2'], output_name='probs', mode='ADD')
|
||||
builder.add_elementwise(
|
||||
name="element", input_names=["image", "image2"], output_name="probs", mode="ADD"
|
||||
)
|
||||
# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
|
||||
# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
|
||||
|
||||
|
@ -25,6 +27,11 @@ builder.add_elementwise(name='element', input_names=['image', 'image2'], output_
|
|||
mlmodel = ct.models.MLModel(builder.spec)
|
||||
|
||||
# trigger the ANE!
|
||||
out = mlmodel.predict({"image": np.zeros(K, dtype=np.float32)+1, "image2": np.zeros(K, dtype=np.float32)+2})
|
||||
out = mlmodel.predict(
|
||||
{
|
||||
"image": np.zeros(K, dtype=np.float32) + 1,
|
||||
"image2": np.zeros(K, dtype=np.float32) + 2,
|
||||
}
|
||||
)
|
||||
print(out)
|
||||
mlmodel.save('test.mlmodel')
|
||||
mlmodel.save("test.mlmodel")
|
||||
|
|
|
@ -6,7 +6,7 @@ import pylab as plt
|
|||
from networkx.drawing.nx_pydot import read_dot
|
||||
|
||||
ret = os.system("./a.out " + sys.argv[1] + " debug")
|
||||
assert(ret == 0)
|
||||
assert ret == 0
|
||||
|
||||
df = "debug/model.hwx.zinir_graph_after_reg_spill.dot"
|
||||
|
||||
|
|
|
@ -3,17 +3,21 @@ import sys
|
|||
from hexdump import hexdump
|
||||
from macholib import MachO
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
def get_macho(fn):
|
||||
# mod to make the header okay
|
||||
# MH_CIGAM_64 is good
|
||||
dat = open(fn, "rb").read()
|
||||
dat = b"\xcf\xfa\xed\xfe" + dat[4:]
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile(delete=False) as f:
|
||||
f.write(dat)
|
||||
f.close()
|
||||
return MachO.MachO(f.name)
|
||||
|
||||
|
||||
a = get_macho("model.hwx.golden")
|
||||
|
||||
# load commands
|
||||
|
@ -23,12 +27,19 @@ for c in a.headers[0].commands:
|
|||
hexdump(c[2])
|
||||
pass
|
||||
if c[0].cmd == 6:
|
||||
print("name:", c[2].decode('utf-8'))
|
||||
print("name:", c[2].decode("utf-8"))
|
||||
if c[0].cmd == 8:
|
||||
print(c[2].decode('utf-8'))
|
||||
print(c[2].decode("utf-8"))
|
||||
if c[0].cmd == 25:
|
||||
for section in c[2]:
|
||||
print(section.segname.strip(b'\0'), section.sectname.strip(b'\0'), hex(section.addr), hex(section.size), "@", hex(c[1].fileoff))
|
||||
print(
|
||||
section.segname.strip(b"\0"),
|
||||
section.sectname.strip(b"\0"),
|
||||
hex(section.addr),
|
||||
hex(section.size),
|
||||
"@",
|
||||
hex(c[1].fileoff),
|
||||
)
|
||||
# print(dir(section))
|
||||
if c[1].filesize > 0:
|
||||
if len(section.section_data) < 0x100:
|
||||
|
@ -38,6 +49,7 @@ for c in a.headers[0].commands:
|
|||
|
||||
# this parser is wrong (fixed with 64-bit one)
|
||||
from macholib import SymbolTable
|
||||
|
||||
sym = SymbolTable.SymbolTable(a)
|
||||
|
||||
syms = {}
|
||||
|
@ -52,6 +64,7 @@ for k,v in syms.items():
|
|||
|
||||
# **** document what we know ***
|
||||
from ane import ANE_Struct, ANE
|
||||
|
||||
ane = ANE()
|
||||
|
||||
aneb = set()
|
||||
|
@ -65,6 +78,8 @@ for l in range(0x34, 0xF4):
|
|||
aneb.add(l)
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def compare(x, y):
|
||||
ss = []
|
||||
ln = []
|
||||
|
@ -73,7 +88,7 @@ def compare(x, y):
|
|||
ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10
|
||||
|
||||
highlight = False
|
||||
next_highlight = 0x2b
|
||||
next_highlight = 0x2B
|
||||
for i in range(ll + 1):
|
||||
if i == next_highlight:
|
||||
highlight = True
|
||||
|
@ -83,35 +98,37 @@ def compare(x, y):
|
|||
next_highlight = None
|
||||
else:
|
||||
highlight = False
|
||||
a = "%02X" % x[i] if i < len(x) else "--", \
|
||||
"%02X" % y[i] if i < len(y) else "--"
|
||||
a = "%02X" % x[i] if i < len(x) else "--", "%02X" % y[i] if i < len(y) else "--"
|
||||
|
||||
def fj(x):
|
||||
ss = []
|
||||
for i in range(0, 0x10, 4):
|
||||
ss.append(' '.join(x[i:i+4]))
|
||||
return ' '.join(ss)
|
||||
ss.append(" ".join(x[i : i + 4]))
|
||||
return " ".join(ss)
|
||||
|
||||
if i != 0 and i % 0x10 == 0:
|
||||
ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n")
|
||||
ln = []
|
||||
ln2 = []
|
||||
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
|
||||
ln.append(colored(a[0], 'green'))
|
||||
ln2.append(colored(a[1], 'red'))
|
||||
ln.append(colored(a[0], "green"))
|
||||
ln2.append(colored(a[1], "red"))
|
||||
else:
|
||||
if highlight:
|
||||
ln.append(colored(a[0], 'yellow'))
|
||||
ln2.append(colored(a[1], 'yellow'))
|
||||
ln.append(colored(a[0], "yellow"))
|
||||
ln2.append(colored(a[1], "yellow"))
|
||||
else:
|
||||
if i in aneb:
|
||||
ln.append(colored(a[0], 'white'))
|
||||
ln2.append(colored(a[1], 'white'))
|
||||
ln.append(colored(a[0], "white"))
|
||||
ln2.append(colored(a[1], "white"))
|
||||
else:
|
||||
ln.append(a[0])
|
||||
ln2.append(a[1])
|
||||
return ''.join(ss)
|
||||
return "".join(ss)
|
||||
|
||||
|
||||
import json
|
||||
|
||||
aneregs = dict(json.load(open("aneregs.json")))
|
||||
g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1])
|
||||
f1 = g.headers[0].commands[1][2][0].section_data
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
from ane import ANE
|
||||
|
||||
ane = ANE()
|
||||
|
||||
lens = {}
|
||||
|
@ -30,7 +31,7 @@ for i in range(0x300):
|
|||
pos.append((k, (i, j, lens[k])))
|
||||
|
||||
import json
|
||||
|
||||
jpos = json.dumps(pos, indent=2)
|
||||
with open("aneregs.json", "w") as f:
|
||||
f.write(jpos)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import ctypes
|
|||
from subprocess import check_output
|
||||
from hexdump import hexdump
|
||||
|
||||
|
||||
def get_pid(name):
|
||||
try:
|
||||
output = check_output(["pgrep", name])
|
||||
|
@ -9,8 +10,10 @@ def get_pid(name):
|
|||
except:
|
||||
return None
|
||||
|
||||
|
||||
from ctypes.util import find_library
|
||||
libc = ctypes.CDLL(find_library('c'))
|
||||
|
||||
libc = ctypes.CDLL(find_library("c"))
|
||||
|
||||
amfid_pid = get_pid("amfid")
|
||||
|
||||
|
@ -21,6 +24,7 @@ print(amfid_pid, ret, task, mytask)
|
|||
|
||||
# myport = libc.mach_task_self()
|
||||
|
||||
|
||||
class vm_region_submap_short_info_data_64(ctypes.Structure):
|
||||
_pack_ = 1
|
||||
_fields_ = [
|
||||
|
@ -38,6 +42,8 @@ class vm_region_submap_short_info_data_64(ctypes.Structure):
|
|||
("object_id", ctypes.c_uint32),
|
||||
("user_wired_count", ctypes.c_uint32),
|
||||
]
|
||||
|
||||
|
||||
submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4
|
||||
|
||||
address = ctypes.c_ulong(0)
|
||||
|
@ -48,21 +54,31 @@ depth = 0
|
|||
|
||||
c_depth = ctypes.c_uint32(depth)
|
||||
for i in range(1):
|
||||
ret = libc.mach_vm_region_recurse(task,
|
||||
ctypes.pointer(address), ctypes.pointer(mapsize),
|
||||
ctypes.pointer(c_depth), ctypes.pointer(sub_info),
|
||||
ctypes.pointer(count))
|
||||
ret = libc.mach_vm_region_recurse(
|
||||
task,
|
||||
ctypes.pointer(address),
|
||||
ctypes.pointer(mapsize),
|
||||
ctypes.pointer(c_depth),
|
||||
ctypes.pointer(sub_info),
|
||||
ctypes.pointer(count),
|
||||
)
|
||||
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
|
||||
# address.value += mapsize.value
|
||||
# exit(0)
|
||||
|
||||
patch_address = address.value + 0x8e38
|
||||
patch_address = address.value + 0x8E38
|
||||
patch = b"\x00\x00\x80\xd2"
|
||||
|
||||
pdata = ctypes.c_void_p(0)
|
||||
data_cnt = ctypes.c_uint32(0)
|
||||
|
||||
ret = libc.mach_vm_read(task, ctypes.c_ulong(patch_address), 4, ctypes.pointer(pdata), ctypes.pointer(data_cnt))
|
||||
ret = libc.mach_vm_read(
|
||||
task,
|
||||
ctypes.c_ulong(patch_address),
|
||||
4,
|
||||
ctypes.pointer(pdata),
|
||||
ctypes.pointer(data_cnt),
|
||||
)
|
||||
buf = ctypes.string_at(pdata.value, data_cnt.value)
|
||||
hexdump(buf)
|
||||
|
||||
|
|
|
@ -6,12 +6,15 @@ import collections
|
|||
import numpy as np
|
||||
import faulthandler
|
||||
import struct
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
basedir = Path(__file__).resolve().parent
|
||||
|
||||
libane = None
|
||||
aneregs = None
|
||||
|
||||
|
||||
def init_libane():
|
||||
global libane, aneregs
|
||||
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
|
||||
|
@ -32,71 +35,56 @@ def init_libane():
|
|||
with open(basedir / "aneregs.json") as f:
|
||||
aneregs = json.load(f)
|
||||
|
||||
|
||||
ANE_Struct = [
|
||||
# aneTD.Header
|
||||
("u32", 0x1C, "NextCommandOffset"),
|
||||
|
||||
# KernelDMASrc @ section @ 0x2C len 0xF4
|
||||
# reloc 0x2c-0x34?? = weights
|
||||
# u32[16] 0x34-0x74 = 0x80 | 1 if used
|
||||
# u32[16] 0x74-0xB4 = <channel data offset>
|
||||
# u32[16] 0xB4-0xF4 = <channel data length>
|
||||
|
||||
# Common @ section @ 0x128 len 0x3C (conv)
|
||||
("u16", 0x128, "InputWidth"),
|
||||
("u16", 0x12A, "InputHeight"),
|
||||
("u16", 0x12C, "InputDepth"),
|
||||
|
||||
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
|
||||
# UInt8 = 0, Int8 = 1, Float16 = 2
|
||||
|
||||
("u32", 0x134, "InputChannels"),
|
||||
("u32", 0x138, "OutputChannels"),
|
||||
|
||||
("u16", 0x13C, "OutputWidth"),
|
||||
("u16", 0x13E, "OutputHeight"),
|
||||
("u16", 0x140, "OutputDepth"),
|
||||
|
||||
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
|
||||
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
|
||||
|
||||
("u16", 0x14C, "BatchSize"),
|
||||
|
||||
# TileDMASrc @ section @ 0x16C len 0x6C (input)
|
||||
# reloc 0x16c-0x174 = image
|
||||
("u32", 0x178, "InputRowStride"),
|
||||
("u32", 0x17C, "InputPlaneStride"),
|
||||
("u32", 0x180, "InputDepthStride"),
|
||||
("u32", 0x184, "InputBatchStride"),
|
||||
|
||||
("u8", 0x1A7, "InputInterleave"),
|
||||
|
||||
# L2 @ section @ 0x1E0 len 0x44
|
||||
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
|
||||
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
|
||||
# [0x21c, 0x220, 0x224] = engines for outconv?
|
||||
|
||||
# NE @ section @ 0x22c len 0xC (scaling)
|
||||
("u16", 0x230, "BiasScalar"),
|
||||
("u16", 0x232, "ScaleScalar"),
|
||||
|
||||
# section @ 0x240 len 0x10
|
||||
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
|
||||
("u32", 0x250, "PostScale"),
|
||||
|
||||
# TileDMADst @ section @ 0x258 len 0x18
|
||||
|
||||
# HandleTileDmaDstConfig
|
||||
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
|
||||
# (GetCacheHintRegisterValue & 0xf) << 6;
|
||||
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
|
||||
|
||||
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
|
||||
("u32", 0x260, "OutputRowStride"),
|
||||
("u32", 0x264, "OutputPlaneStride"),
|
||||
("u32", 0x268, "OutputDepthStride"),
|
||||
("u32", 0x26C, "OutputBatchStride"),
|
||||
|
||||
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
|
||||
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
|
||||
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
|
||||
|
@ -108,24 +96,26 @@ for typ, num, nam in ANE_Struct:
|
|||
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
|
||||
ANE_Struct_Dict[nam] = (styp, num)
|
||||
|
||||
|
||||
class ANETensor:
|
||||
def __init__(self, *shape):
|
||||
self.shape = shape
|
||||
self.dtype = np.float16
|
||||
self.sz = int(np.prod(shape))
|
||||
assert(self.sz <= 0x4000)
|
||||
assert self.sz <= 0x4000
|
||||
self.tt = libane.ANE_TensorCreate(self.sz, 1)
|
||||
assert(self.tt is not None)
|
||||
assert self.tt is not None
|
||||
|
||||
def data(self):
|
||||
data = libane.ANE_TensorData(self.tt)
|
||||
assert(data is not None)
|
||||
assert data is not None
|
||||
# print(hex(addressof(data.contents)))
|
||||
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
|
||||
ret = np.frombuffer(buf, dtype=self.dtype)
|
||||
# print(ret.data)
|
||||
return ret
|
||||
|
||||
|
||||
class ANE:
|
||||
def __init__(self):
|
||||
init_libane()
|
||||
|
@ -133,11 +123,13 @@ class ANE:
|
|||
|
||||
def compile(self, dat):
|
||||
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
|
||||
assert(ret is not None)
|
||||
assert ret is not None
|
||||
return ret
|
||||
|
||||
def run(self, prog, tin, tout, tweights=None):
|
||||
libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0)
|
||||
libane.ANE_Run(
|
||||
prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0
|
||||
)
|
||||
|
||||
def tensor(self, shape):
|
||||
return ANETensor(shape)
|
||||
|
@ -165,9 +157,9 @@ class ANE:
|
|||
return dat
|
||||
|
||||
def debug(self, dat, mems=0):
|
||||
add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c]
|
||||
add = [0x30, 0x1D4, 0x220, 0x29C, 0x2F0, 0x30C, 0x32C]
|
||||
lens = [244, 60, 108, 68, 12, 16, 24]
|
||||
ptr = 0x2b
|
||||
ptr = 0x2B
|
||||
ddat = dat[0:0x28]
|
||||
for a, pm in zip(add, lens):
|
||||
# assert pm == dat[ptr]
|
||||
|
@ -176,7 +168,12 @@ class ANE:
|
|||
ptr += pm + 8
|
||||
ddat += b"\x00" * 0x100
|
||||
ret = collections.OrderedDict()
|
||||
for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"):
|
||||
for ln in (
|
||||
libane.ANE_RegDebug(0, create_string_buffer(ddat), mems)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
.split("\n")
|
||||
):
|
||||
lnn = ln.split(" = ")
|
||||
if len(lnn) == 2:
|
||||
ret[lnn[0]] = int(lnn[1])
|
||||
|
@ -194,6 +191,7 @@ class ANE:
|
|||
dat[base + a : base + a + len(x)] = x
|
||||
return dat
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ane = ANE()
|
||||
|
||||
|
@ -212,11 +210,10 @@ if __name__ == "__main__":
|
|||
md = dat[0x4000:0x4300]
|
||||
dd = ane.unpack(md)
|
||||
mdf = ane.pack(dd, md)
|
||||
assert(md == mdf)
|
||||
assert md == mdf
|
||||
|
||||
comp = ane.compile(dat)
|
||||
ret = ane.run(comp, tin, tout)
|
||||
print("** after **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import time
|
||||
from ane import ANE, ANETensor
|
||||
|
||||
|
||||
def benchmark(ane):
|
||||
tin = ANETensor(512 * 0x20)
|
||||
tout = ANETensor(512 * 0x20)
|
||||
|
@ -14,7 +15,7 @@ def benchmark(ane):
|
|||
for i in range(1000):
|
||||
ret = ane.run(comp, tin, tout)
|
||||
et = time.time()
|
||||
ts = (et-st)
|
||||
ts = et - st
|
||||
ops = 1000 * 512 * 512 * 2
|
||||
|
||||
print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts))
|
||||
|
@ -72,7 +73,7 @@ if __name__ == "__main__":
|
|||
dd = ane.unpack(dat[0x4000:0x4300])
|
||||
# use the 3rd arg as the weights
|
||||
dd["aneTD.Header[9].KBase0"] = 6
|
||||
dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00
|
||||
dd["aneRegs.NE.PostScale.PostScale"] = 0x3C00
|
||||
# dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
|
||||
# dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
|
||||
# dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
from functools import lru_cache
|
||||
from .tensor import Device, Function, register
|
||||
|
||||
|
||||
@lru_cache
|
||||
def compile_wrapper(ane, dat):
|
||||
return ane.compile(dat)
|
||||
|
||||
|
||||
def roundup(x, v):
|
||||
return x + (v - x) % v
|
||||
|
||||
|
||||
@lru_cache
|
||||
def compile_relu(ane, sz):
|
||||
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
|
||||
|
@ -17,16 +20,25 @@ def compile_relu(ane, sz):
|
|||
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
|
||||
# 0x1f4, 0x1f8?
|
||||
# 0x214 = L2.ResultBase.Addr
|
||||
dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride)
|
||||
dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride)
|
||||
stride = roundup(sz * 2, 0x40)
|
||||
dat = ane.filln(dat, {
|
||||
dat = ane.filln(
|
||||
dat,
|
||||
{
|
||||
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
|
||||
"InputWidth": sz, "OutputWidth": sz,
|
||||
"InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride,
|
||||
"OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride,
|
||||
})
|
||||
"InputWidth": sz,
|
||||
"OutputWidth": sz,
|
||||
"InputRowStride": stride,
|
||||
"InputPlaneStride": stride,
|
||||
"InputDepthStride": stride,
|
||||
"OutputRowStride": stride,
|
||||
"OutputPlaneStride": stride,
|
||||
"OutputDepthStride": stride,
|
||||
},
|
||||
)
|
||||
return compile_wrapper(ane, bytes(dat))
|
||||
|
||||
|
||||
class ReLU(Function):
|
||||
def forward(ctx, input):
|
||||
ret = ctx.ane.tensor(input.shape)
|
||||
|
@ -36,4 +48,5 @@ class ReLU(Function):
|
|||
def backward(ctx, grad_output):
|
||||
return 0
|
||||
|
||||
register('relu', ReLU, device=Device.ANE)
|
||||
|
||||
register("relu", ReLU, device=Device.ANE)
|
||||
|
|
|
@ -31,13 +31,14 @@ for x in out.values(): x.realize()
|
|||
"""
|
||||
|
||||
from openvino.runtime import Core
|
||||
|
||||
core = Core()
|
||||
devices = core.available_devices
|
||||
for device in devices:
|
||||
device_name = core.get_property(device, "FULL_DEVICE_NAME")
|
||||
print(f"{device}: {device_name}")
|
||||
model = core.read_model(onnx_path)
|
||||
compiled_model = core.compile_model(model, device_name='GPU.0')
|
||||
compiled_model = core.compile_model(model, device_name="GPU.0")
|
||||
print(compiled_model)
|
||||
ireq = compiled_model.create_infer_request()
|
||||
for model_input in compiled_model.inputs:
|
||||
|
@ -51,7 +52,7 @@ print("did one")
|
|||
|
||||
REPS = 20
|
||||
st = time.perf_counter()
|
||||
for i in range(REPS): ireq.infer()
|
||||
for i in range(REPS):
|
||||
ireq.infer()
|
||||
et = time.perf_counter() - st
|
||||
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")
|
||||
|
||||
|
|
|
@ -7,9 +7,12 @@ from tqdm import trange, tqdm
|
|||
from matplotlib import pyplot as plt
|
||||
|
||||
tests = {}
|
||||
|
||||
|
||||
def register_test(fxn):
|
||||
tests[fxn.__name__] = fxn
|
||||
|
||||
|
||||
def warp_size2(nthread):
|
||||
prg = """__kernel void warp_size2(
|
||||
__global float* src,
|
||||
|
@ -27,16 +30,36 @@ def warp_size2(nthread):
|
|||
src_buf = CLBuffer(1, dtypes.float32)
|
||||
dst_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
|
||||
return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl(
|
||||
[nthread, 1024, 1],
|
||||
[nthread, 1, 1],
|
||||
src_buf,
|
||||
dst_buf,
|
||||
10,
|
||||
3,
|
||||
wait=True,
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_warp_size():
|
||||
return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)]
|
||||
|
||||
|
||||
def reg_count(nthread, ngrp, nreg):
|
||||
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
|
||||
reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||
reg_declr = "".join(
|
||||
[f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]
|
||||
)
|
||||
reg_comp = "".join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||
reg_reduce = "".join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||
prg = f"""__kernel void reg_count(
|
||||
__global float* out_buf,
|
||||
__private const int niter
|
||||
|
@ -51,12 +74,25 @@ def reg_count(nthread, ngrp, nreg):
|
|||
}}"""
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
||||
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True)
|
||||
for _ in range(10)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_reg_count(nthread=1, ngrp=1):
|
||||
base = reg_count(nthread, ngrp, 1)
|
||||
return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)]
|
||||
return [
|
||||
(nreg, (reg_count(nthread, ngrp, nreg) - base) / nreg)
|
||||
for nreg in trange(4, 513, 4)
|
||||
]
|
||||
|
||||
|
||||
def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
||||
ndata //= NCOMP * 4 # ptr size
|
||||
|
@ -72,22 +108,40 @@ def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
|||
*dst = idx;
|
||||
}}"""
|
||||
idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32)
|
||||
for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata
|
||||
for i in range(ndata):
|
||||
idx_buf[i * NCOMP] = (i + stride) % ndata
|
||||
in_buf = CLBuffer.fromCPU(idx_buf)
|
||||
out_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
||||
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True) / steps
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_memory_latency():
|
||||
# requires cacheline < 16
|
||||
szs = [int(1.3**x) for x in range(20, 70)]
|
||||
return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)]
|
||||
return [
|
||||
(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128 * 1024))
|
||||
for ndata in tqdm(szs)
|
||||
]
|
||||
|
||||
|
||||
@register_test
|
||||
def test_cacheline_size():
|
||||
# TODO: this buffer must be at least 2x the L1 cache for this test to work
|
||||
return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)]
|
||||
return [
|
||||
(stride, buf_cache_hierarchy_pchase(4 * 65536, stride, steps=65536))
|
||||
for stride in trange(1, 64)
|
||||
]
|
||||
|
||||
|
||||
def cl_read(sz, niter=1):
|
||||
prg = f"""__kernel void copy(
|
||||
|
@ -101,7 +155,16 @@ def cl_read(sz, niter=1):
|
|||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("copy", prg)
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([sz // 16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True)
|
||||
for _ in range(10)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_read_bandwidth():
|
||||
|
@ -129,12 +192,17 @@ def gflops(niter=4, nroll=4, ngroups=4096):
|
|||
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
|
||||
FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9)
|
||||
return FLOPS / (
|
||||
min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_gflops():
|
||||
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cache = {}
|
||||
# cache = pickle.load(open("/tmp/cache.pkl", "rb"))
|
||||
|
@ -144,8 +212,10 @@ if __name__ == "__main__":
|
|||
print(f"running {k}")
|
||||
plt.subplot(2, (len(tests) + 1) // 2, i + 1)
|
||||
plt.title(k)
|
||||
if k == "test_memory_latency": plt.xscale('log')
|
||||
if k not in cache: cache[k] = test()
|
||||
if k == "test_memory_latency":
|
||||
plt.xscale("log")
|
||||
if k not in cache:
|
||||
cache[k] = test()
|
||||
plt.plot(*zip(*cache[k]))
|
||||
# pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
|
||||
|
||||
|
|
|
@ -1,32 +1,69 @@
|
|||
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
|
||||
from typing import (
|
||||
Tuple,
|
||||
List,
|
||||
NamedTuple,
|
||||
Any,
|
||||
Dict,
|
||||
Optional,
|
||||
Union,
|
||||
DefaultDict,
|
||||
cast,
|
||||
)
|
||||
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps
|
||||
from tinygrad.helpers import DType, dtypes, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
from tinygrad.shape.symbolic import (
|
||||
Variable,
|
||||
NumNode,
|
||||
MulNode,
|
||||
DivNode,
|
||||
ModNode,
|
||||
LtNode,
|
||||
SumNode,
|
||||
AndNode,
|
||||
)
|
||||
import functools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
||||
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
||||
_type_to_letter = {
|
||||
dtypes.float32: "f",
|
||||
dtypes.bool: "p",
|
||||
dtypes.int32: "i",
|
||||
dtypes.int64: "a",
|
||||
dtypes.uint32: "u",
|
||||
dtypes.uint64: "b",
|
||||
dtypes.float.vec(4): "x",
|
||||
dtypes.uint8: "uc",
|
||||
dtypes.float16: "h",
|
||||
dtypes.int8: "c",
|
||||
dtypes.uint16: "us",
|
||||
dtypes.float64: "d",
|
||||
}
|
||||
|
||||
|
||||
class Register(NamedTuple):
|
||||
nm: str
|
||||
dtype: DType
|
||||
scalar: bool
|
||||
off: Optional[int] = None
|
||||
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||
|
||||
def __repr__(self):
|
||||
return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||
|
||||
def subregs(self):
|
||||
if self.dtype == dtypes.float.vec(4):
|
||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||
return []
|
||||
|
||||
|
||||
class AssemblyInstruction(NamedTuple):
|
||||
op: UOps
|
||||
out: Optional[Register]
|
||||
vin: List[Union[Register, int, float]]
|
||||
arg: Any = None
|
||||
|
||||
|
||||
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
||||
class AssemblyLanguage:
|
||||
supports_load3: bool = False
|
||||
|
@ -37,9 +74,15 @@ class AssemblyLanguage:
|
|||
tor: Dict[Any, Register] = {}
|
||||
ins: List[AssemblyInstruction] = []
|
||||
|
||||
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
def type_to_letter(self, x):
|
||||
return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
|
||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
||||
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||
self.tor[tok] = ret = Register(
|
||||
f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}",
|
||||
dtype,
|
||||
scalar,
|
||||
)
|
||||
if dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||
|
@ -48,30 +91,72 @@ class AssemblyLanguage:
|
|||
|
||||
def render_numnode(self, b) -> Register:
|
||||
key = ("num", b)
|
||||
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
||||
if key not in self.tor:
|
||||
self.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b
|
||||
)
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
||||
def render_alu(
|
||||
self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32
|
||||
) -> Register:
|
||||
key = (op, a, b)
|
||||
if key not in self.tor:
|
||||
# if not isinstance(b, Register): b = render_numnode(b)
|
||||
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
||||
self.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
self.newreg(
|
||||
key,
|
||||
dtype=dtype,
|
||||
scalar=a.scalar and (not isinstance(b, Register) or b.scalar),
|
||||
),
|
||||
[a, b],
|
||||
op,
|
||||
)
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
def render_cast(self, a: Register, new_dtype: DType) -> Register:
|
||||
if a.dtype == new_dtype: return a
|
||||
if a.dtype == new_dtype:
|
||||
return a
|
||||
key = (a, new_dtype)
|
||||
if key not in self.tor:
|
||||
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
|
||||
self.ins.append(
|
||||
AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
||||
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
||||
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
||||
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
render_ops: Any = {
|
||||
Variable: lambda self, ops, ctx: ctx.tor[self],
|
||||
NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.MUL, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
DivNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.DIV, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
ModNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.MOD, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
LtNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool
|
||||
),
|
||||
SumNode: lambda self, ops, ctx: functools.reduce(
|
||||
lambda a, b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops, ctx)),
|
||||
self.nodes[1:],
|
||||
self.nodes[0].render(ops, ctx),
|
||||
),
|
||||
AndNode: lambda self, ops, ctx: functools.reduce(
|
||||
lambda a, b: ctx.render_alu(
|
||||
BinaryOps.MUL, a, b.render(ops, ctx), dtype=dtypes.bool
|
||||
),
|
||||
self.nodes[1:],
|
||||
self.nodes[0].render(ops, ctx),
|
||||
),
|
||||
}
|
||||
|
||||
def addr_w_offset(self, args):
|
||||
assert isinstance(args, MemOp)
|
||||
|
@ -79,110 +164,264 @@ class AssemblyLanguage:
|
|||
off = 0 # TODO: should this be None?
|
||||
if isinstance(idx, SumNode):
|
||||
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
||||
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
||||
if (
|
||||
nums and nums[0] < 4096 and (idx - nums[0]).min >= 0
|
||||
): # TODO: different for each GPU?
|
||||
idx -= nums[0]
|
||||
off = cast(int, nums[0])
|
||||
reg = idx.render(self.render_ops, self)
|
||||
if self.supports_load3:
|
||||
if reg.scalar:
|
||||
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
||||
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
||||
new_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype)
|
||||
self.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)
|
||||
)
|
||||
reg = new_reg
|
||||
return self.tor[args.name], reg, off
|
||||
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
|
||||
reg = self.render_alu(
|
||||
BinaryOps.ADD,
|
||||
self.render_cast(reg, dtypes.uint64),
|
||||
self.tor[args.name],
|
||||
dtype=dtypes.uint64,
|
||||
)
|
||||
return reg, None, off
|
||||
|
||||
|
||||
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
|
||||
# TODO: Do not use clear()
|
||||
lang.ins.clear()
|
||||
lang.tor.clear()
|
||||
lang.cnts.clear()
|
||||
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
||||
buf_to_dtype = {
|
||||
args[0]: args[1] for uop, _, _, args, _ in uops if uop == UOps.DEFINE_GLOBAL
|
||||
}
|
||||
global_size, local_size = [], []
|
||||
skipload_branch = 0
|
||||
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||
lang.ins += [
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf
|
||||
)
|
||||
for buf in buf_to_dtype
|
||||
]
|
||||
for u in uops:
|
||||
uop, dtype, vin, args, _ = u
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
lang.newreg(args[0], dtype=dtypes.uint64),
|
||||
[args[0]],
|
||||
UnaryOps.NOOP,
|
||||
)
|
||||
)
|
||||
elif uop == UOps.LOOP:
|
||||
if args[1] == "global":
|
||||
for i, var in enumerate(args[0]):
|
||||
global_size.append(var.max + 1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL,
|
||||
lang.newreg(var, dtype=dtypes.int32),
|
||||
[],
|
||||
f"gid{len(args[0])-1-i}",
|
||||
)
|
||||
)
|
||||
elif args[1] == "local":
|
||||
for i, var in enumerate(args[0]):
|
||||
local_size.append(var.max + 1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL,
|
||||
lang.newreg(var, dtype=dtypes.int32),
|
||||
[],
|
||||
f"lid{len(args[0])-1-i}",
|
||||
)
|
||||
)
|
||||
else:
|
||||
for var in args[0]:
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
||||
if not isinstance(
|
||||
var, NumNode
|
||||
): # TODO: why is this coming through?
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD,
|
||||
lang.newreg(var, dtype=dtypes.int32, scalar=True),
|
||||
[],
|
||||
0,
|
||||
)
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LABEL, None, [], "$loop_" + var.expr
|
||||
)
|
||||
)
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] not in ["global", "local", "global+local"]:
|
||||
for var in reversed(args[0]):
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
||||
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
||||
if not isinstance(
|
||||
var, NumNode
|
||||
): # TODO: why is this coming through?
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
lang.tor[var],
|
||||
[lang.tor[var], 1],
|
||||
BinaryOps.ADD,
|
||||
)
|
||||
)
|
||||
pred = lang.render_alu(
|
||||
BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.COND_BRANCH,
|
||||
None,
|
||||
[pred],
|
||||
("$loop_" + var.expr, True),
|
||||
)
|
||||
)
|
||||
elif args[1] == "global+local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||
elif args[1] == 'local':
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ENDLOOP,
|
||||
None,
|
||||
[lang.tor[var]],
|
||||
(var.max + 1, f"gid{i}"),
|
||||
)
|
||||
)
|
||||
elif args[1] == "local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ENDLOOP,
|
||||
None,
|
||||
[lang.tor[var]],
|
||||
(var.max + 1, f"lid{i}"),
|
||||
)
|
||||
)
|
||||
elif uop == UOps.CAST:
|
||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||
out = lang.newreg(u, dtype)
|
||||
for i, sr in enumerate(out.subregs()):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)
|
||||
)
|
||||
elif uop == UOps.ALU:
|
||||
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
||||
# this is the only thing that can violate SSA
|
||||
if args in [BinaryOps.CMPLT]:
|
||||
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
||||
pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args
|
||||
)
|
||||
)
|
||||
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||
elif args == BinaryOps.DIV and lang.no_div:
|
||||
tmp = lang.newreg((u, "rcp"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP
|
||||
)
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL
|
||||
)
|
||||
)
|
||||
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||
tmp = lang.newreg((u, "2pi"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
tmp,
|
||||
[lang.tor[vin[0]], 1 / (math.pi * 2)],
|
||||
BinaryOps.MUL,
|
||||
)
|
||||
)
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||
else:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)
|
||||
)
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
reg = lang.newreg(u, dtype=dtype)
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
||||
elif uop == UOps.SPECIAL:
|
||||
lang.tor[u] = lang.tor[args]
|
||||
elif uop == UOps.CONST:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
|
||||
reg = lang.newreg(
|
||||
u,
|
||||
dtype=dtype,
|
||||
scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)),
|
||||
)
|
||||
if args.valid.min == 0:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||
if args.valid.max == 1:
|
||||
pred = args.valid.render(lang.render_ops, lang)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.COND_BRANCH,
|
||||
None,
|
||||
[pred],
|
||||
(f"$skipload_{skipload_branch}", False),
|
||||
)
|
||||
)
|
||||
if args.valid.max == 1:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD,
|
||||
reg,
|
||||
[idx] + ([treg] if treg is not None else []),
|
||||
(
|
||||
off,
|
||||
"global" if not args.local else "shared",
|
||||
args.memory_dtype
|
||||
if args.memory_dtype != dtypes.float
|
||||
else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LABEL, None, [], f"$skipload_{skipload_branch}"
|
||||
)
|
||||
)
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
if args is None:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP
|
||||
)
|
||||
)
|
||||
else:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.STORE,
|
||||
None,
|
||||
[idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []),
|
||||
(
|
||||
off,
|
||||
"global" if not args.local else "shared",
|
||||
args.memory_dtype
|
||||
if args.memory_dtype != dtypes.float
|
||||
else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if DEBUG >= 4:
|
||||
for tins in lang.ins: print(tins)
|
||||
for tins in lang.ins:
|
||||
print(tins)
|
||||
return global_size, local_size
|
||||
|
|
|
@ -6,28 +6,60 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
|||
from tinygrad.helpers import dtypes, CI
|
||||
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
||||
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
|
||||
def float_to_hex(x):
|
||||
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||
|
||||
|
||||
def compute_offsets(total):
|
||||
quotient, remainder = divmod(total, 4096)
|
||||
return [4096] * quotient + [remainder] if remainder else [4096] * quotient
|
||||
|
||||
#NOTE: Darwin needs names to start with a "_"
|
||||
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
|
||||
|
||||
class ARM64Language(AssemblyLanguage): pass
|
||||
# NOTE: Darwin needs names to start with a "_"
|
||||
def get_name(name):
|
||||
return ("_" if system() == "Darwin" else "") + name
|
||||
|
||||
|
||||
class ARM64Language(AssemblyLanguage):
|
||||
pass
|
||||
|
||||
|
||||
def specialize_to_arm64(fn_nm, asm):
|
||||
var_size = 16
|
||||
prev_uop: Optional[UOps] = None
|
||||
ins = []
|
||||
x_regs = ['x' + str(i) for i in reversed(range(12))]
|
||||
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
|
||||
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
|
||||
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
|
||||
x_regs = ["x" + str(i) for i in reversed(range(12))]
|
||||
s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
|
||||
type_to_reg = {
|
||||
dtypes.double: "d",
|
||||
dtypes.half: "h",
|
||||
dtypes.float32: "s",
|
||||
dtypes.bool: "w",
|
||||
dtypes.int8: "w",
|
||||
dtypes.int32: "w",
|
||||
dtypes.int64: "x",
|
||||
dtypes.uint8: "w",
|
||||
dtypes.uint32: "w",
|
||||
dtypes.uint64: "x",
|
||||
}
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
BinaryOps.DIV: "div",
|
||||
BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "",
|
||||
BinaryOps.CMPLT: "subs",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "bl " + get_name("sinf"),
|
||||
UnaryOps.LOG2: "bl " + get_name("log2f"),
|
||||
UnaryOps.EXP2: "bl " + get_name("exp2f"),
|
||||
UnaryOps.SQRT: "bl " + get_name("sqrtf"),
|
||||
TernaryOps.MULACC: "madd",
|
||||
TernaryOps.WHERE: "fcsel",
|
||||
}
|
||||
|
||||
def mov_imm(value, reg):
|
||||
# Manually move value into reg if value can't fit
|
||||
|
@ -35,7 +67,7 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
ins.append(f"movz w15, #{value & 0xffff}")
|
||||
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||
ins.append(f"sxtw {reg}, w15")
|
||||
elif reg[0] == 's':
|
||||
elif reg[0] == "s":
|
||||
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
||||
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
||||
ins.append("str x15, [sp, 16]")
|
||||
|
@ -46,42 +78,51 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
# Get variables intervals
|
||||
live_range: Dict[str, List[int]] = {}
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
|
||||
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||
for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]:
|
||||
live_range[var.nm] = (
|
||||
[i, i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||
)
|
||||
|
||||
mem_vars: Dict[str, int] = {}
|
||||
rtor: Dict[str, str] = {}
|
||||
|
||||
def allocate_regs(mvars):
|
||||
nonlocal var_size
|
||||
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
|
||||
for v in [
|
||||
v
|
||||
for v in mvars
|
||||
if v is not None and v.__class__ is not int and v.nm not in rtor
|
||||
]:
|
||||
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||
# NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
||||
if not available_regs:
|
||||
# ARM needs the stack 16-byte aligned
|
||||
var_size += 16
|
||||
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
|
||||
available_regs.append("s0" if dtypes.is_float(out[1]) else "x12")
|
||||
mem_vars[v.nm] = var_size
|
||||
rtor[v.nm] = available_regs.pop()
|
||||
|
||||
temp_floats = ['s0', 's1', 's2']
|
||||
temp_ints = ['x12', 'x13', 'x16']
|
||||
temp_floats = ["s0", "s1", "s2"]
|
||||
temp_ints = ["x12", "x13", "x16"]
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
# Clear regs out of interval
|
||||
for var, reg in list(rtor.items()):
|
||||
available_regs = s_regs if reg[0] == 's' else x_regs
|
||||
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
|
||||
available_regs = s_regs if reg[0] == "s" else x_regs
|
||||
if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]:
|
||||
available_regs.append(rtor.pop(var))
|
||||
# Assign a registers to the variables using live ranges.
|
||||
allocate_regs([out] + vin)
|
||||
# Assign temp regs to vin and load them before direct use
|
||||
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
|
||||
for i, v in enumerate(
|
||||
[v for v in vin if v.__class__ is not int and v.nm in mem_vars]
|
||||
):
|
||||
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
||||
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
||||
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||
|
||||
if uop == UOps.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
if arg.startswith("data"):
|
||||
# data 8 to n into the stack
|
||||
if int(arg[4:]) >= 8:
|
||||
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
|
||||
|
@ -91,28 +132,40 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
ins.append(f"loop_{arg}:")
|
||||
elif uop == UOps.CAST:
|
||||
if arg == BinaryOps.CMPLT:
|
||||
if rtor[out.nm][0] == 's':
|
||||
mov_imm(0.0, 's0')
|
||||
mov_imm(1.0, 's1')
|
||||
if rtor[out.nm][0] == "s":
|
||||
mov_imm(0.0, "s0")
|
||||
mov_imm(1.0, "s1")
|
||||
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
||||
if rtor[out.nm][0] == 'x':
|
||||
mov_imm(0, 'x14')
|
||||
mov_imm(1, 'x15')
|
||||
if rtor[out.nm][0] == "x":
|
||||
mov_imm(0, "x14")
|
||||
mov_imm(1, "x15")
|
||||
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
||||
else:
|
||||
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||
elif uop == UOps.ALU:
|
||||
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
|
||||
if len(vin) == 2 and vin[1].__class__ is int:
|
||||
mov_imm(vin[1], "x15")
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
ins.append(
|
||||
f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
)
|
||||
elif arg == TernaryOps.WHERE:
|
||||
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
|
||||
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
|
||||
ins.append(
|
||||
f"fcmp {rtor[vin[0].nm]}, #0.0"
|
||||
if rtor[vin[0].nm][0] == "s"
|
||||
else f"cmp {rtor[vin[0].nm]}, #0"
|
||||
)
|
||||
ins.append(
|
||||
f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne"
|
||||
)
|
||||
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||
# NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
||||
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||
if CI:
|
||||
ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||
else:
|
||||
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
|
||||
save_regs = [
|
||||
k for k in rtor.keys() if k != out.nm and k not in mem_vars
|
||||
]
|
||||
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
||||
# Save the registers before they are cleared by func call
|
||||
for i, k in enumerate(save_regs, 1):
|
||||
|
@ -128,27 +181,49 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
||||
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||
elif arg == BinaryOps.CMPLT:
|
||||
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
|
||||
ins.append(
|
||||
f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
if not dtypes.is_float(vin[0][1])
|
||||
else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}"
|
||||
)
|
||||
elif arg == BinaryOps.MOD:
|
||||
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
|
||||
rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm]
|
||||
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
|
||||
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
||||
else:
|
||||
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
ins.append(
|
||||
f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
mov_imm(arg, rtor[out.nm])
|
||||
else:
|
||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
|
||||
reg_in = (
|
||||
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||
if arg[2] is not None
|
||||
else rtor[out.nm]
|
||||
)
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
ins.append(
|
||||
f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]"
|
||||
)
|
||||
if arg[2] is not None:
|
||||
ins.append(
|
||||
f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}"
|
||||
)
|
||||
elif uop == UOps.STORE:
|
||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
reg_out = (
|
||||
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||
if arg[2] is not None
|
||||
else rtor[vin[1].nm]
|
||||
)
|
||||
if arg[2] is not None:
|
||||
ins.append(
|
||||
f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}"
|
||||
)
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
|
@ -168,9 +243,31 @@ def specialize_to_arm64(fn_nm, asm):
|
|||
if out is not None and out.nm in mem_vars:
|
||||
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
||||
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
||||
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
|
||||
return "\n".join(
|
||||
[
|
||||
f"//varsize {var_size}",
|
||||
".arch armv8-a",
|
||||
".text",
|
||||
f".global {get_name(fn_nm)}",
|
||||
".p2align 2",
|
||||
f"{get_name(fn_nm)}:",
|
||||
"mov x17, sp",
|
||||
]
|
||||
+ [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||
+ ins
|
||||
+ [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||
+ ["ret", "\n"]
|
||||
)
|
||||
|
||||
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
|
||||
|
||||
def uops_to_arm64_asm(
|
||||
fn_nm: str, uops: List[UOp]
|
||||
) -> Tuple[str, List[int], List[int], bool]:
|
||||
lang = ARM64Language()
|
||||
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
|
||||
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
|
||||
return (
|
||||
specialize_to_arm64(fn_nm, lang.ins),
|
||||
global_size[::-1],
|
||||
local_size[::-1],
|
||||
True,
|
||||
)
|
||||
|
|
|
@ -6,50 +6,113 @@ from tinygrad.helpers import dtypes
|
|||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
||||
from tinygrad.runtime.ops_cuda import arch
|
||||
|
||||
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
dtype_to_nvtype = {
|
||||
dtypes.float32: "f32",
|
||||
dtypes.float16: "f16",
|
||||
dtypes.int64: "s64",
|
||||
dtypes.int32: "s32",
|
||||
dtypes.int8: "s8",
|
||||
dtypes.bool: "pred",
|
||||
dtypes.uint64: "u64",
|
||||
dtypes.uint32: "u32",
|
||||
dtypes.uint16: "u16",
|
||||
dtypes.uint8: "u8",
|
||||
"bits16": "b16",
|
||||
dtypes.float64: "f64",
|
||||
}
|
||||
|
||||
|
||||
def float_to_hex(x):
|
||||
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||
|
||||
|
||||
def ptx_needs_cast(dest_dtype, src_dtype):
|
||||
return (
|
||||
dtypes.is_float(dest_dtype)
|
||||
and dtypes.is_int(src_dtype)
|
||||
or dtypes.is_int(dest_dtype)
|
||||
and dtypes.is_float(src_dtype)
|
||||
or (
|
||||
dtypes.is_float(src_dtype)
|
||||
and dtypes.is_float(dest_dtype)
|
||||
and dest_dtype.itemsize != src_dtype.itemsize
|
||||
)
|
||||
)
|
||||
|
||||
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
|
||||
|
||||
def render_cast(ins, inp, out):
|
||||
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
|
||||
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
|
||||
if inp.dtype == dtypes.bool and (
|
||||
dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)
|
||||
):
|
||||
ins.append(
|
||||
f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};"
|
||||
)
|
||||
elif out.dtype == dtypes.bool:
|
||||
if inp.dtype == dtypes.bool:
|
||||
ins.append(f"mov.pred {out}, {inp};")
|
||||
else:
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
|
||||
ins.append(
|
||||
f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};"
|
||||
)
|
||||
else:
|
||||
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
|
||||
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
|
||||
round_mod = (
|
||||
".rzi"
|
||||
if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype)
|
||||
else ".rz"
|
||||
if dtypes.is_float(out.dtype)
|
||||
and (
|
||||
dtypes.is_int(inp.dtype)
|
||||
or dtypes.is_float(inp.dtype)
|
||||
and inp.dtype.itemsize > out.dtype.itemsize
|
||||
)
|
||||
else ""
|
||||
)
|
||||
ins.append(
|
||||
f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};"
|
||||
)
|
||||
|
||||
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
|
||||
|
||||
|
||||
class PTXLanguage(AssemblyLanguage):
|
||||
supports_constant_folding: bool = True
|
||||
|
||||
|
||||
def specialize_to_ptx(lang, function_name):
|
||||
param_cnt = 0
|
||||
ins = []
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
BinaryOps.DIV: "div",
|
||||
BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "rem",
|
||||
BinaryOps.CMPLT: "setp.lt",
|
||||
UnaryOps.SQRT: "sqrt.approx",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "sin.approx",
|
||||
UnaryOps.LOG2: "lg2.approx",
|
||||
UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn",
|
||||
TernaryOps.WHERE: "selp",
|
||||
}
|
||||
for uop, out, vin, arg in lang.ins:
|
||||
if uop == UOps.ENDLOOP:
|
||||
ins.append("bar.sync 0;")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
if arg.startswith("data"):
|
||||
param_cnt += 1
|
||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
|
||||
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
||||
elif arg.startswith('gid'):
|
||||
elif arg.startswith("gid"):
|
||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||
elif arg.startswith('lid'):
|
||||
elif arg.startswith("lid"):
|
||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||
elif uop == UOps.ALU:
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
|
@ -60,31 +123,64 @@ def specialize_to_ptx(lang, function_name):
|
|||
if vin[0].dtype == dtypes.bool:
|
||||
reg = vin[0]
|
||||
else:
|
||||
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
|
||||
reg = lang.newreg((vin[0], "bool"), dtypes.bool)
|
||||
ins.append(
|
||||
f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};"
|
||||
)
|
||||
vin = vin[1:] + [reg]
|
||||
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
|
||||
ins.append(
|
||||
f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
|
||||
ins.append(
|
||||
f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};"
|
||||
)
|
||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
|
||||
dt = (
|
||||
("u16", dtypes.uint16)
|
||||
if arg[2] == dtypes.bool == out.dtype
|
||||
else ("u8", dtypes.uint8)
|
||||
if arg[2] == dtypes.bool
|
||||
else ("b16", dtypes.float16)
|
||||
if arg[2] == dtypes.half
|
||||
else (dtype_to_nvtype[arg[2]], arg[2])
|
||||
)
|
||||
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
||||
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
ins.append(
|
||||
f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||
)
|
||||
render_cast(ins, reg, out)
|
||||
else:
|
||||
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
ins.append(
|
||||
f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||
)
|
||||
elif uop == UOps.STORE:
|
||||
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
|
||||
if (
|
||||
ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype)
|
||||
or arg[2] == dtypes.bool
|
||||
):
|
||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
|
||||
prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool)
|
||||
render_cast(ins, vin[1], prereg)
|
||||
else: prereg = vin[1]
|
||||
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
|
||||
render_cast(ins, prereg, reg)
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
|
||||
else:
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
|
||||
prereg = vin[1]
|
||||
reg = lang.newreg(
|
||||
(prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
|
||||
dtype=dtypes.uint16
|
||||
if arg[2] == dtypes.bool
|
||||
else dtypes.float
|
||||
if arg[2] is None
|
||||
else arg[2],
|
||||
)
|
||||
render_cast(ins, prereg, reg)
|
||||
ins.append(
|
||||
f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};"
|
||||
)
|
||||
else:
|
||||
ins.append(
|
||||
f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};"
|
||||
)
|
||||
elif uop == UOps.CAST:
|
||||
render_cast(ins, vin[0], out)
|
||||
elif uop == UOps.LABEL:
|
||||
|
@ -92,14 +188,29 @@ def specialize_to_ptx(lang, function_name):
|
|||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||
|
||||
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
|
||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
|
||||
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
|
||||
ins_prefix = [
|
||||
".version 7.8",
|
||||
".target " + arch(),
|
||||
".address_size 64",
|
||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{",
|
||||
]
|
||||
for arg in [
|
||||
(dtype, lang.type_to_letter(dtype), c) for dtype, c in lang.cnts.items()
|
||||
]:
|
||||
ins_prefix.append(
|
||||
f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",
|
||||
)
|
||||
ins = ins_prefix + ins
|
||||
ins += ["ret;", "}"]
|
||||
return '\n'.join(ins)
|
||||
return "\n".join(ins)
|
||||
|
||||
|
||||
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
|
||||
lang = PTXLanguage()
|
||||
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
||||
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
|
||||
return (
|
||||
specialize_to_ptx(lang, function_name),
|
||||
global_size[::-1],
|
||||
local_size[::-1],
|
||||
True,
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
|
|||
|
||||
# ugh, is this really needed?
|
||||
from extra.helpers import enable_early_exec
|
||||
|
||||
early_exec = enable_early_exec()
|
||||
|
||||
boilerplate_start = """
|
||||
|
@ -24,6 +25,7 @@ code_start = """.end_amdhsa_kernel
|
|||
code:
|
||||
"""
|
||||
|
||||
|
||||
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
|
||||
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
|
||||
# RDNA3 is actually a SIMD machine!
|
||||
|
@ -36,107 +38,202 @@ class RDNACodegen(AssemblyCodegen):
|
|||
|
||||
def specialize(self, asm) -> Tuple[str, str]:
|
||||
args = []
|
||||
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
|
||||
for i, b in enumerate(self.bufs):
|
||||
args.append(
|
||||
{
|
||||
".address_space": "global",
|
||||
".name": f"buf_{i}",
|
||||
".offset": i * 8,
|
||||
".size": 8,
|
||||
".type_name": b.dtype.name + "*",
|
||||
".value_kind": "global_buffer",
|
||||
}
|
||||
)
|
||||
ins = []
|
||||
|
||||
v_cnt = 3 # v[0:2] is local_xyz
|
||||
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
||||
|
||||
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
|
||||
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
|
||||
BinaryOps.CMPLT: "cmp_lt"}
|
||||
dtype_to_rdnatype = {
|
||||
dtypes.float32: "f32",
|
||||
dtypes.int64: "i64",
|
||||
dtypes.int32: "i32",
|
||||
dtypes.uint64: "u64",
|
||||
dtypes.bool: "i32",
|
||||
}
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
TernaryOps.MULACC: "fma",
|
||||
BinaryOps.MAX: "max",
|
||||
UnaryOps.RECIP: "rcp",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.SIN: "sin",
|
||||
UnaryOps.LOG2: "log",
|
||||
UnaryOps.EXP2: "exp",
|
||||
BinaryOps.CMPLT: "cmp_lt",
|
||||
}
|
||||
|
||||
pend_regs: Set[Register] = set()
|
||||
rtor: Dict[Register, str] = {}
|
||||
|
||||
def reg_in(x):
|
||||
nonlocal pend_regs
|
||||
# print("reg_in", x, rtor[x], pend_regs)
|
||||
if x in pend_regs:
|
||||
# print("clear")
|
||||
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
|
||||
ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)")
|
||||
pend_regs.clear()
|
||||
return rtor[x]
|
||||
|
||||
def reg_out(x):
|
||||
return rtor[x]
|
||||
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
||||
if arg[0][0] in [
|
||||
dtypes.uint32,
|
||||
dtypes.uint64,
|
||||
dtypes.int64,
|
||||
dtypes.int32,
|
||||
dtypes.float32,
|
||||
dtypes.float.vec(4),
|
||||
]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
if arg[0][1]:
|
||||
s_cnt += s_cnt % align
|
||||
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
|
||||
reg_name = (
|
||||
f"s[{s_cnt}:{s_cnt + align - 1}]"
|
||||
if align > 1
|
||||
else f"s{s_cnt}"
|
||||
)
|
||||
s_cnt += align
|
||||
else:
|
||||
v_cnt += v_cnt % align
|
||||
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
|
||||
reg_name = (
|
||||
f"v[{v_cnt}:{v_cnt + align - 1}]"
|
||||
if align > 1
|
||||
else f"v{v_cnt}"
|
||||
)
|
||||
v_cnt += align
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
|
||||
if arg[0][0] == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
||||
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
||||
reg_name = (
|
||||
f"s{s_cnt-align+off}"
|
||||
if arg[0][1]
|
||||
else f"v{v_cnt-align+off}"
|
||||
)
|
||||
rtor[
|
||||
Register(
|
||||
f"%{arg[1]}{i}", dtypes.float, False, off=off
|
||||
)
|
||||
] = reg_name
|
||||
elif arg[0][0] == dtypes.bool:
|
||||
for i in range(arg[2]):
|
||||
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
|
||||
reg_name = (
|
||||
"scc" if arg[0][1] else "vcc_lo"
|
||||
) # `_lo` suffix since we're running wavefront_size=32
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
else:
|
||||
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
||||
raise NotImplementedError(
|
||||
"DEFINE_REGISTER not implemented for arg: ", arg
|
||||
)
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('buf'):
|
||||
if arg.startswith("buf"):
|
||||
i = int(arg[3:])
|
||||
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
|
||||
ins.append(f"s_load_b64 {reg_out(out)}, s[0:1], {i*8}")
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
elif arg.startswith('gid'):
|
||||
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
|
||||
for r in out.subregs():
|
||||
pend_regs.add(r)
|
||||
elif arg.startswith("gid"):
|
||||
ins.append(f"v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}")
|
||||
# the docs lied, this is actually y
|
||||
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
||||
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
|
||||
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
||||
if int(arg[3]) == 2:
|
||||
ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
||||
if int(arg[3]) == 1:
|
||||
ins.append("v_bfe_u32 v1, v0, 10, 10")
|
||||
elif int(arg[3]) == 0:
|
||||
ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
||||
# get local size
|
||||
offset = len(args) * 8
|
||||
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
|
||||
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
|
||||
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
|
||||
args.append(
|
||||
{
|
||||
".offset": offset,
|
||||
".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}",
|
||||
".size": 8,
|
||||
}
|
||||
)
|
||||
ins.append(f"s_load_b32 s{2+int(arg[3])}, s[0:1], {offset}")
|
||||
ins.append("s_waitcnt vmcnt(0) lgkmcnt(0)")
|
||||
pend_regs.clear()
|
||||
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
|
||||
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
|
||||
ins.append(
|
||||
f"v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}"
|
||||
)
|
||||
ins.append(
|
||||
f"v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}"
|
||||
)
|
||||
elif uop == UOps.CONST:
|
||||
if arg == float('inf'): arg = "0x7f800000"
|
||||
elif arg == float('-inf'): arg = "0xff800000"
|
||||
if arg == float("inf"):
|
||||
arg = "0x7f800000"
|
||||
elif arg == float("-inf"):
|
||||
arg = "0xff800000"
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}"
|
||||
)
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}"
|
||||
)
|
||||
elif uop == UOps.ALU:
|
||||
if arg in [BinaryOps.CMPLT]:
|
||||
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
ins.append(
|
||||
f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||
)
|
||||
else:
|
||||
alu_arg = alu[arg]
|
||||
if arg == TernaryOps.MULACC and out == vin[2]:
|
||||
alu_arg = "fmac"
|
||||
vin = vin[0:2]
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
|
||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
for rr in zip(
|
||||
*[
|
||||
x.subregs()
|
||||
if x.dtype == dtypes.float.vec(4)
|
||||
else [x, x, x, x]
|
||||
for x in [out] + vin
|
||||
]
|
||||
):
|
||||
ins.append(
|
||||
f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}"
|
||||
)
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if out.scalar:
|
||||
# swap arg order
|
||||
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
||||
ins.append(
|
||||
f"s_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}"
|
||||
)
|
||||
else:
|
||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
ins.append(
|
||||
f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||
)
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
for r in out.subregs():
|
||||
pend_regs.add(r)
|
||||
elif uop == UOps.STORE:
|
||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
ins.append(
|
||||
f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||
)
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
|
@ -144,29 +241,40 @@ class RDNACodegen(AssemblyCodegen):
|
|||
elif uop == UOps.CAST:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
if out.dtype == dtypes.float32:
|
||||
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
||||
ins.append(
|
||||
f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
||||
else:
|
||||
raise NotImplementedError(uop)
|
||||
|
||||
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
|
||||
ins += ["s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", "s_code_end"]
|
||||
|
||||
# dual alu group
|
||||
seen = set()
|
||||
new_ins = []
|
||||
for i, tins in enumerate(ins):
|
||||
if tins in seen: continue
|
||||
if tins in seen:
|
||||
continue
|
||||
if tins.startswith("v_fmac_f32"):
|
||||
for gins in reversed(ins[i + 1 :]):
|
||||
if gins in seen: continue
|
||||
if gins in seen:
|
||||
continue
|
||||
if gins.startswith("v_fmac_f32"):
|
||||
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
|
||||
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
|
||||
if r0[0]%2 == r1[0]%2: continue
|
||||
if r0[1]%2 == r1[1]%2: continue
|
||||
if r0[2]%2 == r1[2]%2: continue
|
||||
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
|
||||
r0 = [int(x[1:].strip(",")) for x in tins.split(" ")[1:]]
|
||||
r1 = [int(x[1:].strip(",")) for x in gins.split(" ")[1:]]
|
||||
if r0[0] % 2 == r1[0] % 2:
|
||||
continue
|
||||
if r0[1] % 2 == r1[1] % 2:
|
||||
continue
|
||||
if r0[2] % 2 == r1[2] % 2:
|
||||
continue
|
||||
new_ins.append(
|
||||
tins.replace("v_", "v_dual_")
|
||||
+ " :: "
|
||||
+ gins.replace("v_", "v_dual_")
|
||||
)
|
||||
seen.add(tins)
|
||||
seen.add(gins)
|
||||
break
|
||||
|
@ -174,30 +282,102 @@ class RDNACodegen(AssemblyCodegen):
|
|||
new_ins.append(tins)
|
||||
ins = new_ins
|
||||
|
||||
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
|
||||
return "code", self.assemble(args, ins, v_cnt, s_cnt)
|
||||
|
||||
def assemble(self, args, ins, v_cnt, s_cnt):
|
||||
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
|
||||
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
|
||||
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
|
||||
'.amdhsa_next_free_sgpr': s_cnt,
|
||||
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
|
||||
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
|
||||
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
|
||||
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
|
||||
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
|
||||
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
|
||||
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
|
||||
kernel_desc = {
|
||||
".amdhsa_group_segment_fixed_size": 0,
|
||||
".amdhsa_private_segment_fixed_size": 0,
|
||||
".amdhsa_kernarg_size": 0,
|
||||
".amdhsa_next_free_vgpr": v_cnt, # this matters!
|
||||
".amdhsa_reserve_vcc": 0,
|
||||
".amdhsa_reserve_xnack_mask": 0,
|
||||
".amdhsa_next_free_sgpr": s_cnt,
|
||||
".amdhsa_float_round_mode_32": 0,
|
||||
".amdhsa_float_round_mode_16_64": 0,
|
||||
".amdhsa_float_denorm_mode_32": 3,
|
||||
".amdhsa_float_denorm_mode_16_64": 3,
|
||||
".amdhsa_dx10_clamp": 1,
|
||||
".amdhsa_ieee_mode": 1,
|
||||
".amdhsa_fp16_overflow": 0,
|
||||
".amdhsa_workgroup_processor_mode": 1,
|
||||
".amdhsa_memory_ordered": 1,
|
||||
".amdhsa_forward_progress": 0,
|
||||
".amdhsa_enable_private_segment": 0,
|
||||
".amdhsa_system_sgpr_workgroup_id_x": 1,
|
||||
".amdhsa_system_sgpr_workgroup_id_y": 1,
|
||||
".amdhsa_system_sgpr_workgroup_id_z": 1,
|
||||
".amdhsa_system_sgpr_workgroup_info": 0,
|
||||
".amdhsa_system_vgpr_workitem_id": 2, # is amdhsa_system_vgpr_workitem_id real?
|
||||
".amdhsa_exception_fp_ieee_invalid_op": 0,
|
||||
".amdhsa_exception_fp_denorm_src": 0,
|
||||
".amdhsa_exception_fp_ieee_div_zero": 0,
|
||||
".amdhsa_exception_fp_ieee_overflow": 0,
|
||||
".amdhsa_exception_fp_ieee_underflow": 0,
|
||||
".amdhsa_exception_fp_ieee_inexact": 0,
|
||||
".amdhsa_exception_int_div_zero": 0,
|
||||
".amdhsa_user_sgpr_dispatch_ptr": 0,
|
||||
".amdhsa_user_sgpr_queue_ptr": 0,
|
||||
".amdhsa_user_sgpr_kernarg_segment_ptr": 1,
|
||||
".amdhsa_user_sgpr_dispatch_id": 0,
|
||||
".amdhsa_user_sgpr_private_segment_size": 0,
|
||||
".amdhsa_wavefront_size32": 1,
|
||||
".amdhsa_uses_dynamic_stack": 0,
|
||||
}
|
||||
|
||||
metadata = {'amdhsa.kernels': [{'.args': args,
|
||||
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
|
||||
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
|
||||
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
|
||||
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
|
||||
'.wavefront_size': 32}],
|
||||
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
|
||||
metadata = {
|
||||
"amdhsa.kernels": [
|
||||
{
|
||||
".args": args,
|
||||
".group_segment_fixed_size": 0,
|
||||
".kernarg_segment_align": 8,
|
||||
".kernarg_segment_size": args[-1][".offset"] + args[-1][".size"],
|
||||
".language": "OpenCL C",
|
||||
".language_version": [1, 2],
|
||||
".max_flat_workgroup_size": 256,
|
||||
".name": "code",
|
||||
".private_segment_fixed_size": 0,
|
||||
".sgpr_count": s_cnt,
|
||||
".sgpr_spill_count": 0,
|
||||
".symbol": "code.kd",
|
||||
".uses_dynamic_stack": False,
|
||||
".vgpr_count": v_cnt,
|
||||
".vgpr_spill_count": 0,
|
||||
".wavefront_size": 32,
|
||||
}
|
||||
],
|
||||
"amdhsa.target": "amdgcn-amd-amdhsa--gfx1100",
|
||||
"amdhsa.version": [1, 2],
|
||||
}
|
||||
|
||||
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
|
||||
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
|
||||
code = (
|
||||
boilerplate_start
|
||||
+ "\n"
|
||||
+ "\n".join("%s %d" % x for x in kernel_desc.items())
|
||||
+ "\n"
|
||||
+ code_start
|
||||
+ "\n".join(ins)
|
||||
+ "\n.amdgpu_metadata\n"
|
||||
+ yaml.dump(metadata)
|
||||
+ ".end_amdgpu_metadata"
|
||||
)
|
||||
obj = early_exec(
|
||||
(
|
||||
[
|
||||
ROCM_LLVM_PATH / "llvm-mc",
|
||||
"--arch=amdgcn",
|
||||
"--mcpu=gfx1100",
|
||||
"--triple=amdgcn-amd-amdhsa",
|
||||
"--filetype=obj",
|
||||
"-",
|
||||
],
|
||||
code.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
asm = early_exec(
|
||||
(
|
||||
[ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"],
|
||||
obj,
|
||||
)
|
||||
)
|
||||
return asm
|
||||
|
|
|
@ -4,7 +4,9 @@ from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
|
|||
|
||||
if __name__ == "__main__":
|
||||
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
|
||||
prg = CUDAProgram("test", """
|
||||
prg = CUDAProgram(
|
||||
"test",
|
||||
"""
|
||||
.version 7.8
|
||||
.target sm_86
|
||||
.address_size 64
|
||||
|
@ -17,7 +19,8 @@ if __name__ == "__main__":
|
|||
mov.u32 %r1, 0x40000000; // 2.0 in float
|
||||
st.global.u32 [%rd2], %r1;
|
||||
ret;
|
||||
}""", binary=True)
|
||||
}""",
|
||||
binary=True,
|
||||
)
|
||||
prg([1], [1], test)
|
||||
print(test.toCPU())
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import pathlib
|
|||
from hexdump import hexdump
|
||||
from tinygrad.helpers import colored
|
||||
from extra.helpers import enable_early_exec
|
||||
|
||||
early_exec = enable_early_exec()
|
||||
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
|
||||
|
@ -37,29 +38,49 @@ for j in range(1):
|
|||
c = (y * KX + x) * 8
|
||||
a = (KY * KX * 8) + y * 8
|
||||
b = (KY * KX * 8) + (KY * 8) + x * 8
|
||||
gen.append(f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]")
|
||||
gen.append(
|
||||
f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]"
|
||||
)
|
||||
FLOPS += 16 * 8 * 2
|
||||
else:
|
||||
for i in range(0, MAX_REG, 6):
|
||||
if DUAL_ALU:
|
||||
if F32:
|
||||
gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||
gen.append(
|
||||
f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}"
|
||||
)
|
||||
FLOPS += 4
|
||||
else:
|
||||
gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
|
||||
gen.append(
|
||||
f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}"
|
||||
)
|
||||
FLOPS += 8
|
||||
else:
|
||||
assert F32
|
||||
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
|
||||
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||
code = code.replace("// FLOPS", '\n'.join(gen))
|
||||
code = code.replace("// FLOPS", "\n".join(gen))
|
||||
print(code)
|
||||
|
||||
|
||||
# fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa'
|
||||
|
||||
object = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object))
|
||||
object = early_exec(
|
||||
(
|
||||
[
|
||||
ROCM_LLVM_PATH / "llvm-mc",
|
||||
"--arch=amdgcn",
|
||||
"--mcpu=gfx1100",
|
||||
"--triple=amdgcn-amd-amdhsa",
|
||||
"--filetype=obj",
|
||||
"-",
|
||||
],
|
||||
code.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
asm = early_exec(
|
||||
([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object)
|
||||
)
|
||||
|
||||
with open("/tmp/cc2.o", "wb") as f:
|
||||
f.write(object)
|
||||
|
|
|
@ -2,12 +2,14 @@ import numpy as np
|
|||
from PIL import Image
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
cwd = Path.cwd()
|
||||
sys.path.append(cwd.as_posix())
|
||||
sys.path.append((cwd / 'test').as_posix())
|
||||
sys.path.append((cwd / "test").as_posix())
|
||||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
|
||||
def augment_img(X, rotate=10, px=3):
|
||||
Xaug = np.zeros_like(X)
|
||||
for i in trange(len(X)):
|
||||
|
@ -20,8 +22,10 @@ def augment_img(X, rotate=10, px=3):
|
|||
Xaug[i] = im
|
||||
return Xaug
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
|
@ -29,14 +33,18 @@ if __name__ == "__main__":
|
|||
fig, a = plt.subplots(2, len(X))
|
||||
Xaug = augment_img(X)
|
||||
for i in range(len(X)):
|
||||
a[0][i].imshow(X[i], cmap='gray')
|
||||
a[1][i].imshow(Xaug[i],cmap='gray')
|
||||
a[0][i].axis('off')
|
||||
a[1][i].axis('off')
|
||||
a[0][i].imshow(X[i], cmap="gray")
|
||||
a[1][i].imshow(Xaug[i], cmap="gray")
|
||||
a[0][i].axis("off")
|
||||
a[1][i].axis("off")
|
||||
plt.show()
|
||||
|
||||
# create some nice gifs for doc?!
|
||||
for i in range(10):
|
||||
im = Image.fromarray(X_train[7353 + i])
|
||||
im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
|
||||
im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
|
||||
im_aug = [
|
||||
Image.fromarray(x) for x in augment_img(np.array([X_train[7353 + i]] * 100))
|
||||
]
|
||||
im.save(
|
||||
f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0
|
||||
)
|
||||
|
|
|
@ -3,30 +3,53 @@ import numpy as np
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, fetch
|
||||
|
||||
|
||||
def fetch_mnist(tensors=False):
|
||||
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
||||
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
|
||||
X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
X_train = (
|
||||
parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:]
|
||||
.reshape((-1, 28 * 28))
|
||||
.astype(np.float32)
|
||||
)
|
||||
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
X_test = (
|
||||
parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:]
|
||||
.reshape((-1, 28 * 28))
|
||||
.astype(np.float32)
|
||||
)
|
||||
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
|
||||
else: return X_train, Y_train, X_test, Y_test
|
||||
if tensors:
|
||||
return (
|
||||
Tensor(X_train).reshape(-1, 1, 28, 28),
|
||||
Tensor(Y_train),
|
||||
Tensor(X_test).reshape(-1, 1, 28, 28),
|
||||
Tensor(Y_test),
|
||||
)
|
||||
else:
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
||||
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
|
||||
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
|
||||
|
||||
|
||||
def fetch_cifar():
|
||||
X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
|
||||
Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
|
||||
X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
|
||||
Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)
|
||||
X_train = Tensor.empty(
|
||||
50000, 3 * 32 * 32, device=f"disk:/tmp/cifar_train_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y_train = Tensor.empty(50000, device=f"disk:/tmp/cifar_train_y", dtype=dtypes.int64)
|
||||
X_test = Tensor.empty(
|
||||
10000, 3 * 32 * 32, device=f"disk:/tmp/cifar_test_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y_test = Tensor.empty(10000, device=f"disk:/tmp/cifar_test_y", dtype=dtypes.int64)
|
||||
|
||||
if not os.path.isfile("/tmp/cifar_extracted"):
|
||||
|
||||
def _load_disk_tensor(X, Y, db_list):
|
||||
idx = 0
|
||||
for db in db_list:
|
||||
x, y = db[b'data'], np.array(db[b'labels'])
|
||||
x, y = db[b"data"], np.array(db[b"labels"])
|
||||
assert x.shape[0] == y.shape[0]
|
||||
X[idx : idx + x.shape[0]].assign(x)
|
||||
Y[idx : idx + x.shape[0]].assign(y)
|
||||
|
@ -34,10 +57,28 @@ def fetch_cifar():
|
|||
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
||||
|
||||
print("downloading and extracting CIFAR...")
|
||||
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
|
||||
tt = tarfile.open(fn, mode='r:gz')
|
||||
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
|
||||
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
|
||||
fn = fetch("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
|
||||
tt = tarfile.open(fn, mode="r:gz")
|
||||
_load_disk_tensor(
|
||||
X_train,
|
||||
Y_train,
|
||||
[
|
||||
pickle.load(
|
||||
tt.extractfile(f"cifar-10-batches-py/data_batch_{i}"),
|
||||
encoding="bytes",
|
||||
)
|
||||
for i in range(1, 6)
|
||||
],
|
||||
)
|
||||
_load_disk_tensor(
|
||||
X_test,
|
||||
Y_test,
|
||||
[
|
||||
pickle.load(
|
||||
tt.extractfile("cifar-10-batches-py/test_batch"), encoding="bytes"
|
||||
)
|
||||
],
|
||||
)
|
||||
open("/tmp/cifar_extracted", "wb").close()
|
||||
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
|
|
@ -15,32 +15,36 @@ frPyObjects = _mask.frPyObjects
|
|||
BASEDIR = pathlib.Path(__file__).parent / "COCO"
|
||||
BASEDIR.mkdir(exist_ok=True)
|
||||
|
||||
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
|
||||
|
||||
def create_dict(key_row, val_row, rows):
|
||||
return {row[key_row]: row[val_row] for row in rows}
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'val2017').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
if not pathlib.Path(BASEDIR / "val2017").is_dir():
|
||||
fn = fetch("http://images.cocodataset.org/zips/val2017.zip")
|
||||
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'annotations').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
if not pathlib.Path(BASEDIR / "annotations").is_dir():
|
||||
fn = fetch("http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
|
||||
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
|
||||
with open(BASEDIR / "annotations/instances_val2017.json", "r") as f:
|
||||
annotations_raw = json.loads(f.read())
|
||||
images = annotations_raw['images']
|
||||
categories = annotations_raw['categories']
|
||||
annotations = annotations_raw['annotations']
|
||||
file_name_to_id = create_dict('file_name', 'id', images)
|
||||
id_to_width = create_dict('id', 'width', images)
|
||||
id_to_height = create_dict('id', 'height', images)
|
||||
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
|
||||
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
|
||||
images = annotations_raw["images"]
|
||||
categories = annotations_raw["categories"]
|
||||
annotations = annotations_raw["annotations"]
|
||||
file_name_to_id = create_dict("file_name", "id", images)
|
||||
id_to_width = create_dict("id", "width", images)
|
||||
id_to_height = create_dict("id", "height", images)
|
||||
json_category_id_to_contiguous_id = {v["id"]: i + 1 for i, v in enumerate(categories)}
|
||||
contiguous_category_id_to_json_id = {
|
||||
v: k for k, v in json_category_id_to_contiguous_id.items()
|
||||
}
|
||||
|
||||
|
||||
def encode(bimask):
|
||||
|
@ -48,7 +52,8 @@ def encode(bimask):
|
|||
return _mask.encode(bimask)
|
||||
elif len(bimask.shape) == 2:
|
||||
h, w = bimask.shape
|
||||
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
|
||||
return _mask.encode(bimask.reshape((h, w, 1), order="F"))[0]
|
||||
|
||||
|
||||
def decode(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
|
@ -56,12 +61,14 @@ def decode(rleObjs):
|
|||
else:
|
||||
return _mask.decode([rleObjs])[:, :, 0]
|
||||
|
||||
|
||||
def area(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.area(rleObjs)
|
||||
else:
|
||||
return _mask.area([rleObjs])[0]
|
||||
|
||||
|
||||
def toBbox(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.toBbox(rleObjs)
|
||||
|
@ -102,8 +109,10 @@ def convert_prediction_to_coco_bbox(file_name, prediction):
|
|||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
|
||||
masker = Masker(threshold=0.5, padding=1)
|
||||
|
||||
|
||||
def convert_prediction_to_coco_mask(file_name, prediction):
|
||||
coco_results = []
|
||||
try:
|
||||
|
@ -122,8 +131,7 @@ def convert_prediction_to_coco_mask(file_name, prediction):
|
|||
masks = masker([masks], [prediction])[0].numpy()
|
||||
|
||||
rles = [
|
||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
|
||||
for mask in masks
|
||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] for mask in masks
|
||||
]
|
||||
for rle in rles:
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
@ -146,20 +154,22 @@ def convert_prediction_to_coco_mask(file_name, prediction):
|
|||
return coco_results
|
||||
|
||||
|
||||
|
||||
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
|
||||
path = pathlib.Path(json_result_file)
|
||||
if rm and path.exists(): path.unlink()
|
||||
if rm and path.exists():
|
||||
path.unlink()
|
||||
with open(path, "a") as f:
|
||||
for s in coco_results:
|
||||
f.write(json.dumps(s))
|
||||
f.write('\n')
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def remove_dup(l):
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in l if not (x in seen or seen_add(x))]
|
||||
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
|
@ -177,23 +187,28 @@ def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
|
|||
for line in f:
|
||||
coco_results.append(json.loads(line))
|
||||
|
||||
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
|
||||
coco_gt = COCO(str(BASEDIR / "annotations/instances_val2017.json"))
|
||||
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
||||
unique_list = [json.loads(s) for s in set_of_json]
|
||||
|
||||
with open(f'{json_result_file}.flattend', "w") as f:
|
||||
with open(f"{json_result_file}.flattend", "w") as f:
|
||||
json.dump(unique_list, f)
|
||||
|
||||
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
|
||||
coco_dt = coco_gt.loadRes(str(f"{json_result_file}.flattend"))
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
return coco_eval
|
||||
|
||||
|
||||
def iterate(files, bs=1):
|
||||
batch = []
|
||||
for file in files:
|
||||
batch.append(file)
|
||||
if len(batch) >= bs: yield batch; batch = []
|
||||
if len(batch) > 0: yield batch; batch = []
|
||||
if len(batch) >= bs:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
batch = []
|
||||
|
|
|
@ -9,36 +9,45 @@ BASEDIR = pathlib.Path(__file__).parent / "imagenet"
|
|||
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
||||
cir = {v[0]: int(k) for k, v in ci.items()}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_train_files():
|
||||
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
|
||||
return [(BASEDIR / "train" / x) for x in train_files]
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_val_files():
|
||||
val_files = glob.glob(str(BASEDIR / "val/*/*"))
|
||||
return val_files
|
||||
|
||||
|
||||
# rrc = transforms.RandomResizedCrop(224)
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
def image_load(fn):
|
||||
img = Image.open(fn).convert('RGB')
|
||||
img = Image.open(fn).convert("RGB")
|
||||
img = F.resize(img, 256, Image.BILINEAR)
|
||||
img = F.center_crop(img, 224)
|
||||
ret = np.array(img)
|
||||
return ret
|
||||
|
||||
|
||||
def iterate(bs=32, val=True, shuffle=True):
|
||||
files = get_val_files() if val else get_train_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle: random.shuffle(order)
|
||||
if shuffle:
|
||||
random.shuffle(order)
|
||||
from multiprocessing import Pool
|
||||
|
||||
p = Pool(16)
|
||||
for i in range(0, len(files), bs):
|
||||
X = p.map(image_load, [files[i] for i in order[i : i + bs]])
|
||||
Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]]
|
||||
yield (np.array(X), np.array(Y))
|
||||
|
||||
|
||||
def fetch_batch(bs, val=False):
|
||||
files = get_val_files() if val else get_train_files()
|
||||
samp = np.random.randint(0, len(files), size=(bs))
|
||||
|
@ -47,7 +56,7 @@ def fetch_batch(bs, val=False):
|
|||
Y = [cir[x.split("/")[0]] for x in files]
|
||||
return np.array(X), np.array(Y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
X, Y = fetch_batch(64)
|
||||
print(X.shape, Y)
|
||||
|
||||
|
|
|
@ -4,17 +4,26 @@ from pathlib import Path
|
|||
from tqdm import tqdm
|
||||
import tarfile, os
|
||||
|
||||
|
||||
def imagenet_extract(file, path, small=False):
|
||||
with tarfile.open(name=file) as tar:
|
||||
if small: # Show progressbar only for big files
|
||||
for member in tar.getmembers(): tar.extract(path=path, member=member)
|
||||
for member in tar.getmembers():
|
||||
tar.extract(path=path, member=member)
|
||||
else:
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())):
|
||||
tar.extract(path=path, member=member)
|
||||
tar.close()
|
||||
|
||||
|
||||
def imagenet_prepare_val():
|
||||
# Read in the labels file
|
||||
with open(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt", 'r') as f:
|
||||
with open(
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt",
|
||||
"r",
|
||||
) as f:
|
||||
labels = f.read().splitlines()
|
||||
f.close()
|
||||
# Get a list of images
|
||||
|
@ -23,8 +32,16 @@ def imagenet_prepare_val():
|
|||
# Create folders and move files into those
|
||||
for co, dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
|
||||
os.replace(Path(__file__).parent / "imagenet" / "val" / images[co], Path(__file__).parent / "imagenet" / "val" / dir / images[co])
|
||||
os.remove(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt")
|
||||
os.replace(
|
||||
Path(__file__).parent / "imagenet" / "val" / images[co],
|
||||
Path(__file__).parent / "imagenet" / "val" / dir / images[co],
|
||||
)
|
||||
os.remove(
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt"
|
||||
)
|
||||
|
||||
|
||||
def imagenet_prepare_train():
|
||||
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
|
||||
|
@ -32,20 +49,47 @@ def imagenet_prepare_train():
|
|||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
|
||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True)
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "train" / tarf, Path(__file__).parent/ "imagenet" / "train" / images[co], small=True)
|
||||
os.makedirs(
|
||||
Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True
|
||||
)
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "train" / tarf,
|
||||
Path(__file__).parent / "imagenet" / "train" / images[co],
|
||||
small=True,
|
||||
)
|
||||
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
|
||||
fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json")
|
||||
fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt")
|
||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val")
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json",
|
||||
Path(__file__).parent / "imagenet" / "imagenet_class_index.json",
|
||||
)
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt",
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt",
|
||||
)
|
||||
fetch(
|
||||
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar",
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||
) # 7GB
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||
Path(__file__).parent / "imagenet" / "val",
|
||||
)
|
||||
imagenet_prepare_val()
|
||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB!
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train")
|
||||
if os.getenv("IMGNET_TRAIN", None) is not None:
|
||||
fetch(
|
||||
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar",
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||
) # 138GB!
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||
Path(__file__).parent / "imagenet" / "train",
|
||||
)
|
||||
imagenet_prepare_train()
|
||||
|
|
|
@ -23,41 +23,70 @@ mv kits extra/datasets
|
|||
```
|
||||
"""
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_val_files():
|
||||
data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
|
||||
return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
|
||||
data = fetch(
|
||||
"https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt"
|
||||
).read_text()
|
||||
return sorted(
|
||||
[x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")]
|
||||
)
|
||||
|
||||
|
||||
def load_pair(file_path):
|
||||
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
|
||||
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(
|
||||
file_path / "segmentation.nii.gz"
|
||||
)
|
||||
image_spacings = image.header["pixdim"][1:4].tolist()
|
||||
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
|
||||
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(
|
||||
np.uint8
|
||||
)
|
||||
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
|
||||
return image, label, image_spacings
|
||||
|
||||
|
||||
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
|
||||
if image_spacings != target_spacing:
|
||||
spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
|
||||
spc_arr, targ_arr, shp_arr = (
|
||||
np.array(image_spacings),
|
||||
np.array(target_spacing),
|
||||
np.array(image.shape[1:]),
|
||||
)
|
||||
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
|
||||
image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
|
||||
label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
|
||||
image = F.interpolate(
|
||||
torch.from_numpy(np.expand_dims(image, axis=0)),
|
||||
size=new_shape,
|
||||
mode="trilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
label = F.interpolate(
|
||||
torch.from_numpy(np.expand_dims(label, axis=0)),
|
||||
size=new_shape,
|
||||
mode="nearest",
|
||||
)
|
||||
image = np.squeeze(image.numpy(), axis=0)
|
||||
label = np.squeeze(label.numpy(), axis=0)
|
||||
return image, label
|
||||
|
||||
|
||||
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
|
||||
image = np.clip(image, min_clip, max_clip)
|
||||
image = (image - mean) / std
|
||||
return image
|
||||
|
||||
|
||||
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
|
||||
current_shape = image.shape[1:]
|
||||
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
|
||||
paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
|
||||
paddings = [(0, 0)] + [
|
||||
(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)
|
||||
]
|
||||
image = np.pad(image, paddings, mode="edge")
|
||||
label = np.pad(label, paddings, mode="edge")
|
||||
return image, label
|
||||
|
||||
|
||||
def preprocess(file_path):
|
||||
image, label, image_spacings = load_pair(file_path)
|
||||
image, label = resample3d(image, label, image_spacings)
|
||||
|
@ -65,16 +94,20 @@ def preprocess(file_path):
|
|||
image, label = pad_to_min_shape(image, label)
|
||||
return image, label
|
||||
|
||||
|
||||
def iterate(val=True, shuffle=False):
|
||||
if not val: raise NotImplementedError
|
||||
if not val:
|
||||
raise NotImplementedError
|
||||
files = get_val_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle: random.shuffle(order)
|
||||
if shuffle:
|
||||
random.shuffle(order)
|
||||
for file in files:
|
||||
X, Y = preprocess(file)
|
||||
X = np.expand_dims(X, axis=0)
|
||||
yield (X, Y)
|
||||
|
||||
|
||||
def gaussian_kernel(n, std):
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
|
@ -84,14 +117,44 @@ def gaussian_kernel(n, std):
|
|||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
|
||||
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
|
||||
bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
|
||||
bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
|
||||
paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
|
||||
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
|
||||
|
||||
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
|
||||
def pad_input(
|
||||
volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3
|
||||
):
|
||||
bounds = [
|
||||
(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)
|
||||
]
|
||||
bounds = [
|
||||
bounds[i]
|
||||
if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i]
|
||||
else bounds[i] + strides[i]
|
||||
for i in range(dim)
|
||||
]
|
||||
paddings = [
|
||||
bounds[2] // 2,
|
||||
bounds[2] - bounds[2] // 2,
|
||||
bounds[1] // 2,
|
||||
bounds[1] - bounds[1] // 2,
|
||||
bounds[0] // 2,
|
||||
bounds[0] - bounds[0] // 2,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
return (
|
||||
F.pad(
|
||||
torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val
|
||||
).numpy(),
|
||||
paddings,
|
||||
)
|
||||
|
||||
|
||||
def sliding_window_inference(
|
||||
model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5
|
||||
):
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
mdl_run = TinyJit(lambda x: model(x).realize())
|
||||
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
|
||||
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
|
||||
|
@ -119,13 +182,40 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
|
|||
for i in range(0, strides[0] * size[0], strides[0]):
|
||||
for j in range(0, strides[1] * size[1], strides[1]):
|
||||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
|
||||
result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
|
||||
norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
|
||||
out = mdl_run(
|
||||
Tensor(
|
||||
inputs[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
]
|
||||
)
|
||||
).numpy()
|
||||
result[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
] += (
|
||||
out * norm_patch
|
||||
)
|
||||
norm_map[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
] += norm_patch
|
||||
result /= norm_map
|
||||
result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
|
||||
result = result[
|
||||
...,
|
||||
paddings[4] : image_shape[0] + paddings[4],
|
||||
paddings[2] : image_shape[1] + paddings[2],
|
||||
paddings[0] : image_shape[2] + paddings[0],
|
||||
]
|
||||
return result, labels
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for X, Y in iterate():
|
||||
print(X.shape, Y.shape)
|
||||
|
|
|
@ -19,17 +19,30 @@ BASEDIR = pathlib.Path(__file__).parent / "librispeech"
|
|||
with open(BASEDIR / "dev-clean-wav.json") as f:
|
||||
ci = json.load(f)
|
||||
|
||||
FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
|
||||
FILTER_BANK = np.expand_dims(
|
||||
librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0
|
||||
)
|
||||
WINDOW = librosa.filters.get_window("hann", 320)
|
||||
|
||||
|
||||
def feature_extract(x, x_lens):
|
||||
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
|
||||
|
||||
# pre-emphasis
|
||||
x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1)
|
||||
x = np.concatenate(
|
||||
(np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1
|
||||
)
|
||||
|
||||
# stft
|
||||
x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")
|
||||
x = librosa.stft(
|
||||
x,
|
||||
n_fft=512,
|
||||
window=WINDOW,
|
||||
hop_length=160,
|
||||
win_length=320,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
x = np.stack((x.real, x.imag), axis=-1)
|
||||
|
||||
# power spectrum
|
||||
|
@ -56,18 +69,24 @@ def feature_extract(x, x_lens):
|
|||
features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1)
|
||||
features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1)
|
||||
features_std += 1e-5
|
||||
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2)
|
||||
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(
|
||||
features_std, 2
|
||||
)
|
||||
|
||||
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
|
||||
|
||||
|
||||
def load_wav(file):
|
||||
sample = soundfile.read(file)[0].astype(np.float32)
|
||||
return sample, sample.shape[0]
|
||||
|
||||
|
||||
def iterate(bs=1, start=0):
|
||||
print(f"there are {len(ci)} samples in the dataset")
|
||||
for i in range(start, len(ci), bs):
|
||||
samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]])
|
||||
samples, sample_lens = zip(
|
||||
*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]]
|
||||
)
|
||||
samples = list(samples)
|
||||
# pad to same length
|
||||
max_len = max(sample_lens)
|
||||
|
@ -75,7 +94,10 @@ def iterate(bs=1, start=0):
|
|||
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
|
||||
samples, sample_lens = np.array(samples), np.array(sample_lens)
|
||||
|
||||
yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]])
|
||||
yield feature_extract(samples, sample_lens), np.array(
|
||||
[v["transcript"] for v in ci[i : i + bs]]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
X, Y = next(iterate())
|
||||
|
|
|
@ -12,133 +12,441 @@ import concurrent.futures
|
|||
|
||||
BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf"
|
||||
BUCKET_NAME = "open-images-dataset"
|
||||
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
||||
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
|
||||
'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
|
||||
'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
|
||||
'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
|
||||
'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
|
||||
'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
|
||||
'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
|
||||
'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
|
||||
'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
|
||||
'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
|
||||
'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
|
||||
'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
|
||||
'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
|
||||
'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
|
||||
'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
|
||||
'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
|
||||
'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
|
||||
'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
|
||||
'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
|
||||
'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
|
||||
'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
|
||||
'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
|
||||
'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
|
||||
'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
|
||||
'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
|
||||
'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
|
||||
'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
|
||||
'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
|
||||
'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
|
||||
'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
|
||||
'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
|
||||
'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
|
||||
'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
|
||||
'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
|
||||
'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
|
||||
'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
|
||||
'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
|
||||
'Zebra', 'Zucchini',
|
||||
BBOX_ANNOTATIONS_URL = (
|
||||
"https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||
)
|
||||
MAP_CLASSES_URL = (
|
||||
"https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
||||
)
|
||||
MLPERF_CLASSES = [
|
||||
"Airplane",
|
||||
"Antelope",
|
||||
"Apple",
|
||||
"Backpack",
|
||||
"Balloon",
|
||||
"Banana",
|
||||
"Barrel",
|
||||
"Baseball bat",
|
||||
"Baseball glove",
|
||||
"Bee",
|
||||
"Beer",
|
||||
"Bench",
|
||||
"Bicycle",
|
||||
"Bicycle helmet",
|
||||
"Bicycle wheel",
|
||||
"Billboard",
|
||||
"Book",
|
||||
"Bookcase",
|
||||
"Boot",
|
||||
"Bottle",
|
||||
"Bowl",
|
||||
"Bowling equipment",
|
||||
"Box",
|
||||
"Boy",
|
||||
"Brassiere",
|
||||
"Bread",
|
||||
"Broccoli",
|
||||
"Bronze sculpture",
|
||||
"Bull",
|
||||
"Bus",
|
||||
"Bust",
|
||||
"Butterfly",
|
||||
"Cabinetry",
|
||||
"Cake",
|
||||
"Camel",
|
||||
"Camera",
|
||||
"Candle",
|
||||
"Candy",
|
||||
"Cannon",
|
||||
"Canoe",
|
||||
"Carrot",
|
||||
"Cart",
|
||||
"Castle",
|
||||
"Cat",
|
||||
"Cattle",
|
||||
"Cello",
|
||||
"Chair",
|
||||
"Cheese",
|
||||
"Chest of drawers",
|
||||
"Chicken",
|
||||
"Christmas tree",
|
||||
"Coat",
|
||||
"Cocktail",
|
||||
"Coffee",
|
||||
"Coffee cup",
|
||||
"Coffee table",
|
||||
"Coin",
|
||||
"Common sunflower",
|
||||
"Computer keyboard",
|
||||
"Computer monitor",
|
||||
"Convenience store",
|
||||
"Cookie",
|
||||
"Countertop",
|
||||
"Cowboy hat",
|
||||
"Crab",
|
||||
"Crocodile",
|
||||
"Cucumber",
|
||||
"Cupboard",
|
||||
"Curtain",
|
||||
"Deer",
|
||||
"Desk",
|
||||
"Dinosaur",
|
||||
"Dog",
|
||||
"Doll",
|
||||
"Dolphin",
|
||||
"Door",
|
||||
"Dragonfly",
|
||||
"Drawer",
|
||||
"Dress",
|
||||
"Drum",
|
||||
"Duck",
|
||||
"Eagle",
|
||||
"Earrings",
|
||||
"Egg (Food)",
|
||||
"Elephant",
|
||||
"Falcon",
|
||||
"Fedora",
|
||||
"Flag",
|
||||
"Flowerpot",
|
||||
"Football",
|
||||
"Football helmet",
|
||||
"Fork",
|
||||
"Fountain",
|
||||
"French fries",
|
||||
"French horn",
|
||||
"Frog",
|
||||
"Giraffe",
|
||||
"Girl",
|
||||
"Glasses",
|
||||
"Goat",
|
||||
"Goggles",
|
||||
"Goldfish",
|
||||
"Gondola",
|
||||
"Goose",
|
||||
"Grape",
|
||||
"Grapefruit",
|
||||
"Guitar",
|
||||
"Hamburger",
|
||||
"Handbag",
|
||||
"Harbor seal",
|
||||
"Headphones",
|
||||
"Helicopter",
|
||||
"High heels",
|
||||
"Hiking equipment",
|
||||
"Horse",
|
||||
"House",
|
||||
"Houseplant",
|
||||
"Human arm",
|
||||
"Human beard",
|
||||
"Human body",
|
||||
"Human ear",
|
||||
"Human eye",
|
||||
"Human face",
|
||||
"Human foot",
|
||||
"Human hair",
|
||||
"Human hand",
|
||||
"Human head",
|
||||
"Human leg",
|
||||
"Human mouth",
|
||||
"Human nose",
|
||||
"Ice cream",
|
||||
"Jacket",
|
||||
"Jeans",
|
||||
"Jellyfish",
|
||||
"Juice",
|
||||
"Kitchen & dining room table",
|
||||
"Kite",
|
||||
"Lamp",
|
||||
"Lantern",
|
||||
"Laptop",
|
||||
"Lavender (Plant)",
|
||||
"Lemon",
|
||||
"Light bulb",
|
||||
"Lighthouse",
|
||||
"Lily",
|
||||
"Lion",
|
||||
"Lipstick",
|
||||
"Lizard",
|
||||
"Man",
|
||||
"Maple",
|
||||
"Microphone",
|
||||
"Mirror",
|
||||
"Mixing bowl",
|
||||
"Mobile phone",
|
||||
"Monkey",
|
||||
"Motorcycle",
|
||||
"Muffin",
|
||||
"Mug",
|
||||
"Mule",
|
||||
"Mushroom",
|
||||
"Musical keyboard",
|
||||
"Necklace",
|
||||
"Nightstand",
|
||||
"Office building",
|
||||
"Orange",
|
||||
"Owl",
|
||||
"Oyster",
|
||||
"Paddle",
|
||||
"Palm tree",
|
||||
"Parachute",
|
||||
"Parrot",
|
||||
"Pen",
|
||||
"Penguin",
|
||||
"Personal flotation device",
|
||||
"Piano",
|
||||
"Picture frame",
|
||||
"Pig",
|
||||
"Pillow",
|
||||
"Pizza",
|
||||
"Plate",
|
||||
"Platter",
|
||||
"Porch",
|
||||
"Poster",
|
||||
"Pumpkin",
|
||||
"Rabbit",
|
||||
"Rifle",
|
||||
"Roller skates",
|
||||
"Rose",
|
||||
"Salad",
|
||||
"Sandal",
|
||||
"Saucer",
|
||||
"Saxophone",
|
||||
"Scarf",
|
||||
"Sea lion",
|
||||
"Sea turtle",
|
||||
"Sheep",
|
||||
"Shelf",
|
||||
"Shirt",
|
||||
"Shorts",
|
||||
"Shrimp",
|
||||
"Sink",
|
||||
"Skateboard",
|
||||
"Ski",
|
||||
"Skull",
|
||||
"Skyscraper",
|
||||
"Snake",
|
||||
"Sock",
|
||||
"Sofa bed",
|
||||
"Sparrow",
|
||||
"Spider",
|
||||
"Spoon",
|
||||
"Sports uniform",
|
||||
"Squirrel",
|
||||
"Stairs",
|
||||
"Stool",
|
||||
"Strawberry",
|
||||
"Street light",
|
||||
"Studio couch",
|
||||
"Suit",
|
||||
"Sun hat",
|
||||
"Sunglasses",
|
||||
"Surfboard",
|
||||
"Sushi",
|
||||
"Swan",
|
||||
"Swimming pool",
|
||||
"Swimwear",
|
||||
"Tank",
|
||||
"Tap",
|
||||
"Taxi",
|
||||
"Tea",
|
||||
"Teddy bear",
|
||||
"Television",
|
||||
"Tent",
|
||||
"Tie",
|
||||
"Tiger",
|
||||
"Tin can",
|
||||
"Tire",
|
||||
"Toilet",
|
||||
"Tomato",
|
||||
"Tortoise",
|
||||
"Tower",
|
||||
"Traffic light",
|
||||
"Train",
|
||||
"Tripod",
|
||||
"Truck",
|
||||
"Trumpet",
|
||||
"Umbrella",
|
||||
"Van",
|
||||
"Vase",
|
||||
"Vehicle registration plate",
|
||||
"Violin",
|
||||
"Wall clock",
|
||||
"Waste container",
|
||||
"Watch",
|
||||
"Whale",
|
||||
"Wheel",
|
||||
"Wheelchair",
|
||||
"Whiteboard",
|
||||
"Window",
|
||||
"Wine",
|
||||
"Wine glass",
|
||||
"Woman",
|
||||
"Zebra",
|
||||
"Zucchini",
|
||||
]
|
||||
|
||||
|
||||
def openimages():
|
||||
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
|
||||
if not ann_file.is_file():
|
||||
fetch_openimages(ann_file)
|
||||
return ann_file
|
||||
|
||||
|
||||
# this slows down the conversion a lot!
|
||||
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
|
||||
def extract_dims(path): return Image.open(path).size[::-1]
|
||||
def extract_dims(path):
|
||||
return Image.open(path).size[::-1]
|
||||
|
||||
def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
|
||||
|
||||
def export_to_coco(
|
||||
class_map,
|
||||
annotations,
|
||||
image_list,
|
||||
dataset_path,
|
||||
output_path,
|
||||
classes=MLPERF_CLASSES,
|
||||
):
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
|
||||
categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
|
||||
class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
|
||||
categories_map = pd.DataFrame(
|
||||
[(i, c) for i, c in enumerate(classes)],
|
||||
columns=["category_id", "category_name"],
|
||||
)
|
||||
class_map = class_map.merge(
|
||||
categories_map, left_on="DisplayName", right_on="category_name", how="inner"
|
||||
)
|
||||
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
|
||||
annotations = annotations.merge(class_map, on="LabelName", how="inner")
|
||||
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
|
||||
annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")
|
||||
annotations[["height", "width"]] = annotations.apply(
|
||||
lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"),
|
||||
axis=1,
|
||||
result_type="expand",
|
||||
)
|
||||
|
||||
# Images
|
||||
imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
|
||||
for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows())
|
||||
imgs = [
|
||||
{
|
||||
"id": int(id + 1),
|
||||
"file_name": f"{image_id}.jpg",
|
||||
"height": row["height"],
|
||||
"width": row["width"],
|
||||
"license": None,
|
||||
"coco_url": None,
|
||||
}
|
||||
for (id, image_id), row in (
|
||||
annotations.groupby(["image_id", "ImageID"]).first().iterrows()
|
||||
)
|
||||
]
|
||||
|
||||
# Annotations
|
||||
annots = []
|
||||
for i, row in annotations.iterrows():
|
||||
xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
|
||||
x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
|
||||
coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
|
||||
coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
|
||||
xmin, ymin, xmax, ymax, img_w, img_h = [
|
||||
row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]
|
||||
]
|
||||
x, y, w, h = (
|
||||
xmin * img_w,
|
||||
ymin * img_h,
|
||||
(xmax - xmin) * img_w,
|
||||
(ymax - ymin) * img_h,
|
||||
)
|
||||
coco_annot = {
|
||||
"id": int(i) + 1,
|
||||
"image_id": int(row["image_id"] + 1),
|
||||
"category_id": int(row["category_id"]),
|
||||
"bbox": [x, y, w, h],
|
||||
"area": w * h,
|
||||
}
|
||||
coco_annot.update(
|
||||
{
|
||||
k: row[k]
|
||||
for k in [
|
||||
"IsOccluded",
|
||||
"IsInside",
|
||||
"IsDepiction",
|
||||
"IsTruncated",
|
||||
"IsGroupOf",
|
||||
]
|
||||
}
|
||||
)
|
||||
coco_annot["iscrowd"] = int(row["IsGroupOf"])
|
||||
annots.append(coco_annot)
|
||||
|
||||
info = {"dataset": "openimages_mlperf", "version": "v6"}
|
||||
coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
|
||||
coco_annotations = {
|
||||
"info": info,
|
||||
"licenses": [],
|
||||
"categories": cats,
|
||||
"images": imgs,
|
||||
"annotations": annots,
|
||||
}
|
||||
with open(output_path, "w") as fp:
|
||||
json.dump(coco_annotations, fp)
|
||||
|
||||
|
||||
def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
|
||||
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
|
||||
image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
|
||||
image_ids = annotations[np.isin(annotations["LabelName"], labels)][
|
||||
"ImageID"
|
||||
].unique()
|
||||
return image_ids
|
||||
|
||||
|
||||
def download_image(bucket, image_id, data_dir):
|
||||
try:
|
||||
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")
|
||||
sys.exit(
|
||||
f"ERROR when downloading image `validation/{image_id}`: {str(exception)}"
|
||||
)
|
||||
|
||||
|
||||
def fetch_openimages(output_fn):
|
||||
bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)
|
||||
bucket = boto3.resource(
|
||||
"s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)
|
||||
).Bucket(BUCKET_NAME)
|
||||
|
||||
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
|
||||
annotations_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
|
||||
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split("/")[-1]
|
||||
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
|
||||
annotations = pd.read_csv(annotations_fn)
|
||||
|
||||
classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
|
||||
classmap_fn = annotations_dir / MAP_CLASSES_URL.split("/")[-1]
|
||||
fetch(MAP_CLASSES_URL, classmap_fn)
|
||||
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
|
||||
|
||||
image_list = get_image_list(class_map, annotations)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
|
||||
for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
|
||||
futures = [
|
||||
executor.submit(download_image, bucket, image_id, data_dir)
|
||||
for image_id in image_list
|
||||
]
|
||||
for future in (
|
||||
t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))
|
||||
):
|
||||
t.set_description(f"Downloading images")
|
||||
future.result()
|
||||
|
||||
print("Converting annotations to COCO format...")
|
||||
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
|
||||
|
||||
|
||||
def image_load(fn):
|
||||
img_folder = BASEDIR / "validation/data"
|
||||
img = Image.open(img_folder / fn).convert('RGB')
|
||||
img = Image.open(img_folder / fn).convert("RGB")
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
ret = F.resize(img, size=(800, 800))
|
||||
ret = np.array(ret)
|
||||
return ret, img.size[::-1]
|
||||
|
||||
|
||||
def prepare_target(annotations, img_id, img_size):
|
||||
boxes = [annot["bbox"] for annot in annotations]
|
||||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
|
||||
|
@ -150,7 +458,13 @@ def prepare_target(annotations, img_id, img_size):
|
|||
classes = [annot["category_id"] for annot in annotations]
|
||||
classes = np.array(classes, dtype=np.int64)
|
||||
classes = classes[keep]
|
||||
return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}
|
||||
return {
|
||||
"boxes": boxes,
|
||||
"labels": classes,
|
||||
"image_id": img_id,
|
||||
"image_size": img_size,
|
||||
}
|
||||
|
||||
|
||||
def iterate(coco, bs=8):
|
||||
image_ids = sorted(coco.imgs.keys())
|
||||
|
|
|
@ -13,10 +13,15 @@ if __name__ == "__main__":
|
|||
assert x.shape[0] == y.shape[0]
|
||||
bs = x.shape[0]
|
||||
if X is None:
|
||||
X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
|
||||
Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
|
||||
X = Tensor.empty(
|
||||
sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y = Tensor.empty(
|
||||
sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64
|
||||
)
|
||||
print(X.shape, Y.shape)
|
||||
X[idx : idx + bs].assign(x)
|
||||
Y[idx : idx + bs].assign(y)
|
||||
idx += bs
|
||||
if idx >= sz: break
|
||||
if idx >= sz:
|
||||
break
|
||||
|
|
|
@ -6,9 +6,14 @@ import numpy as np
|
|||
from tinygrad.helpers import fetch
|
||||
|
||||
BASEDIR = Path(__file__).parent / "squad"
|
||||
|
||||
|
||||
def init_dataset():
|
||||
os.makedirs(BASEDIR, exist_ok=True)
|
||||
fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
|
||||
fetch(
|
||||
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
|
||||
BASEDIR / "dev-v1.1.json",
|
||||
)
|
||||
with open(BASEDIR / "dev-v1.1.json") as f:
|
||||
data = json.load(f)["data"]
|
||||
|
||||
|
@ -32,14 +37,17 @@ def init_dataset():
|
|||
qa_id = qa["id"]
|
||||
q_text = qa["question"]
|
||||
|
||||
examples.append({
|
||||
examples.append(
|
||||
{
|
||||
"id": qa_id,
|
||||
"question": q_text,
|
||||
"context": doc_tokens,
|
||||
"answers": list(map(lambda x: x["text"], qa["answers"]))
|
||||
})
|
||||
"answers": list(map(lambda x: x["text"], qa["answers"])),
|
||||
}
|
||||
)
|
||||
return examples
|
||||
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
best_score, best_span_index = None, None
|
||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||
|
@ -56,6 +64,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
|||
best_span_index = di
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
def convert_example_to_features(example, tokenizer):
|
||||
query_tokens = tokenizer.tokenize(example["question"])
|
||||
|
||||
|
@ -101,7 +110,9 @@ def convert_example_to_features(example, tokenizer):
|
|||
for i in range(doc_length):
|
||||
split_token_index = doc_start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
|
||||
token_is_max_context[len(tokens)] = _check_is_max_context(
|
||||
doc_spans, di, split_token_index
|
||||
)
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
|
@ -119,17 +130,24 @@ def convert_example_to_features(example, tokenizer):
|
|||
assert len(input_mask) == 384
|
||||
assert len(segment_ids) == 384
|
||||
|
||||
outputs.append({
|
||||
outputs.append(
|
||||
{
|
||||
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
|
||||
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
|
||||
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
|
||||
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(
|
||||
np.float32
|
||||
),
|
||||
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(
|
||||
np.float32
|
||||
),
|
||||
"token_to_orig_map": token_to_orig_map,
|
||||
"token_is_max_context": token_is_max_context,
|
||||
"tokens": tokens,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def iterate(tokenizer, start=0):
|
||||
examples = init_dataset()
|
||||
print(f"there are {len(examples)} pairs in the dataset")
|
||||
|
@ -140,8 +158,11 @@ def iterate(tokenizer, start=0):
|
|||
# we need to yield all features here as the f1 score is the maximum over all features
|
||||
yield features, example
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))
|
||||
tokenizer = BertTokenizer(
|
||||
str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt")
|
||||
)
|
||||
|
||||
X, Y = next(iterate(tokenizer))
|
||||
print(" ".join(X[0]["tokens"]))
|
||||
|
|
|
@ -5,11 +5,13 @@ from tinygrad.helpers import DEBUG, getenv
|
|||
import multiprocessing as mp
|
||||
import os
|
||||
|
||||
|
||||
# this needs to be called before everything else if you are using distributed
|
||||
def preinit():
|
||||
os.environ["DELAYED_RUNTIME_INIT"] = "1"
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
|
||||
# out-of-band communication/synchronization
|
||||
class _OOB:
|
||||
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
|
||||
|
@ -22,14 +24,18 @@ class _OOB:
|
|||
# receive some data from a target rank, blocks until data is received
|
||||
def recv(self, target_rank: int) -> Any:
|
||||
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
||||
|
||||
|
||||
OOB: Optional[_OOB] = None
|
||||
|
||||
|
||||
def init_oob(world_size: int):
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
global OOB
|
||||
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
|
||||
|
||||
|
||||
# this runs in the spawned process so we can do all the delayed runtime initialization
|
||||
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
|
||||
# setup the rank
|
||||
|
@ -41,19 +47,27 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
|
|||
|
||||
# do specific runtime initialization for distributed
|
||||
from tinygrad import Device
|
||||
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
|
||||
|
||||
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(
|
||||
device.split(":")[-1]
|
||||
)
|
||||
if "GPU" in device:
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
|
||||
CL.post_init(device_num)
|
||||
elif "HIP" in device:
|
||||
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
|
||||
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
|
||||
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(
|
||||
device_num
|
||||
)
|
||||
if DEBUG >= 1:
|
||||
print(f"distributed process {rank} initialized runtime for device {device}")
|
||||
|
||||
# convert device to be process specific
|
||||
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
|
||||
|
||||
fn(*args)
|
||||
|
||||
|
||||
# wrapper around mp.Process that initializes the runtime
|
||||
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
|
||||
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
|
||||
|
|
|
@ -3,6 +3,7 @@ from tinygrad.helpers import getenv
|
|||
|
||||
from extra.dist import world
|
||||
|
||||
|
||||
def allreduce(t: Tensor) -> Tensor:
|
||||
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
|
||||
|
||||
|
@ -11,7 +12,9 @@ def allreduce(t:Tensor) -> Tensor:
|
|||
|
||||
# pad to evenly divide
|
||||
if flattened.shape[0] % WORLD_SIZE != 0:
|
||||
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
|
||||
flattened = Tensor.cat(
|
||||
flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))
|
||||
)
|
||||
|
||||
# chunk
|
||||
chunks = flattened.chunk(WORLD_SIZE, dim=0)
|
||||
|
|
|
@ -4,15 +4,18 @@ from multiprocessing import shared_memory
|
|||
from tinygrad.helpers import DEBUG, colored, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
|
||||
|
||||
try:
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
|
||||
except: RawHIPBuffer = None
|
||||
except:
|
||||
RawHIPBuffer = None
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.tensor import Tensor, Function
|
||||
import numpy as np
|
||||
|
||||
|
||||
# match the function signature of JITRunner so we can put it in the cache
|
||||
def __send_rb(args, variables=None, wait=False, jit=False):
|
||||
x, target_rank, y = args[:3]
|
||||
|
@ -20,19 +23,31 @@ def __send_rb(args, variables=None, wait=False, jit=False):
|
|||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
else:
|
||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else: y.fromCPU(x.toCPU())
|
||||
if isinstance(x, RawBufferCopyInOut):
|
||||
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else:
|
||||
y.fromCPU(x.toCPU())
|
||||
dist.OOB.send(None, target_rank)
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
if DEBUG >= 2:
|
||||
print(
|
||||
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}"
|
||||
)
|
||||
|
||||
|
||||
def __recv_rb(args, variables=None, wait=False, jit=False):
|
||||
x, target_rank, y = args[:3]
|
||||
dist.OOB.recv(target_rank)
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
x._transfer(y)
|
||||
elif isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
||||
else: x.fromCPU(y.toCPU())
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")
|
||||
elif isinstance(x, RawBuffer):
|
||||
x._copyin(y.toCPU())
|
||||
else:
|
||||
x.fromCPU(y.toCPU())
|
||||
if DEBUG >= 2:
|
||||
print(
|
||||
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}"
|
||||
)
|
||||
|
||||
|
||||
# send a rawbuffer from out rank to the target rank
|
||||
def _send_rb(x: RawBuffer, target_rank: int):
|
||||
|
@ -40,7 +55,11 @@ def _send_rb(x:RawBuffer, target_rank:int):
|
|||
# send ipc handle
|
||||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
check(hip.hipIpcGetMemHandle(ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf))
|
||||
check(
|
||||
hip.hipIpcGetMemHandle(
|
||||
ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf
|
||||
)
|
||||
)
|
||||
dist.OOB.send((handle, x._device), target_rank)
|
||||
|
||||
# jit support
|
||||
|
@ -48,20 +67,26 @@ def _send_rb(x:RawBuffer, target_rank:int):
|
|||
CacheCollector.add(__send_rb, [x, target_rank, None], {})
|
||||
else:
|
||||
# create shared memory
|
||||
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
|
||||
shm_name = (
|
||||
s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)
|
||||
).name
|
||||
s.close()
|
||||
|
||||
# copy the buffer into shared memory
|
||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||
# fast path when we can directly copyout
|
||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else: y.fromCPU(x.toCPU())
|
||||
if isinstance(x, RawBufferCopyInOut):
|
||||
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else:
|
||||
y.fromCPU(x.toCPU())
|
||||
|
||||
dist.OOB.send(shm_name, target_rank)
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__send_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
if DEBUG >= 2:
|
||||
print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
|
||||
|
||||
# receive a rawbuffer from the target rank
|
||||
def _recv_rb(x: RawBuffer, target_rank: int):
|
||||
|
@ -69,7 +94,9 @@ def _recv_rb(x:RawBuffer, target_rank:int):
|
|||
# open ipc handle
|
||||
handle, y_device = dist.OOB.recv(target_rank)
|
||||
check(hip.hipSetDevice(y_device))
|
||||
check(hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0))
|
||||
check(
|
||||
hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0)
|
||||
)
|
||||
|
||||
# build a new buffer
|
||||
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
|
||||
|
@ -81,34 +108,50 @@ def _recv_rb(x:RawBuffer, target_rank:int):
|
|||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||
|
||||
# fast path when we can directly copyin
|
||||
if isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
||||
else: x.fromCPU(y.toCPU())
|
||||
if isinstance(x, RawBuffer):
|
||||
x._copyin(y.toCPU())
|
||||
else:
|
||||
x.fromCPU(y.toCPU())
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
||||
if DEBUG >= 2:
|
||||
print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
||||
|
||||
|
||||
# sends a lazybuffer from our rank to the target rank
|
||||
def _send_lb(x: LazyBuffer, target_rank: int) -> None:
|
||||
assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
|
||||
assert (
|
||||
x.st.contiguous and x.realized
|
||||
), "sending buffer must be contiguous and realized"
|
||||
_send_rb(x.realized, target_rank)
|
||||
|
||||
|
||||
# receive a lazybuffer from the target rank
|
||||
def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
|
||||
assert (
|
||||
x.st.contiguous and x.realized
|
||||
), "receiving buffer must be contiguous and realized"
|
||||
_recv_rb(x.realized, target_rank)
|
||||
return x
|
||||
|
||||
|
||||
class Send(Function):
|
||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
|
||||
_send_lb(x, target_rank)
|
||||
return x
|
||||
|
||||
|
||||
class Recv(Function):
|
||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
self.target_rank = target_rank
|
||||
return _recv_lb(x, target_rank)
|
||||
|
||||
def send(x:Tensor, target_rank:int) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
def recv(x:Tensor, target_rank:int) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
|
||||
def send(x: Tensor, target_rank: int) -> Tensor:
|
||||
return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
|
||||
|
||||
def recv(x: Tensor, target_rank: int) -> Tensor:
|
||||
return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
|
|
|
@ -17,5 +17,10 @@ if __name__ == "__main__":
|
|||
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
||||
for f in cur3.fetchall():
|
||||
v = pickle.loads(f[-1])
|
||||
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
|
||||
print(
|
||||
" ",
|
||||
len(f[0]) if isinstance(f[0], str) else f[0],
|
||||
f[1:-1],
|
||||
str(v)[0:50],
|
||||
)
|
||||
# print(f"{len(k):10d}, {sk} -> {v}")
|
||||
|
|
|
@ -7,77 +7,190 @@ import json
|
|||
|
||||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
|
||||
def compile_net(
|
||||
run: TinyJit, special_names: Dict[int, str]
|
||||
) -> Tuple[
|
||||
Dict[str, str],
|
||||
List[Tuple[str, List[str], List[int]]],
|
||||
Dict[str, Tuple[int, DType, int]],
|
||||
Dict[str, Tensor],
|
||||
]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn = ji.prg
|
||||
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
functions[
|
||||
fxn.name
|
||||
] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i, arg in enumerate(ji.rawbufs):
|
||||
key = id(arg)
|
||||
if key not in bufs:
|
||||
if key in special_names:
|
||||
bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
|
||||
bufs[key] = (
|
||||
special_names[key],
|
||||
arg.size * arg.dtype.itemsize,
|
||||
arg.dtype,
|
||||
key,
|
||||
)
|
||||
else:
|
||||
bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
|
||||
bufs[key] = (
|
||||
f"buf_{bufnum}",
|
||||
arg.size * arg.dtype.itemsize,
|
||||
arg.dtype,
|
||||
key,
|
||||
)
|
||||
bufnum += 1
|
||||
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
|
||||
if i > 0:
|
||||
bufs_to_save[
|
||||
bufs[key][0]
|
||||
] = arg # if first usage of a buffer is not an output, and it's not a special name
|
||||
cargs.append(bufs[key][0])
|
||||
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
|
||||
return (
|
||||
functions,
|
||||
statements,
|
||||
{name: (size, dtype, key) for (name, size, dtype, key) in bufs.values()},
|
||||
bufs_to_save,
|
||||
)
|
||||
|
||||
|
||||
def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]:
|
||||
assert hasattr(model, "forward") or callable(model), "model needs a forward function"
|
||||
assert hasattr(model, "forward") or callable(
|
||||
model
|
||||
), "model needs a forward function"
|
||||
|
||||
@TinyJit
|
||||
def run(*x):
|
||||
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
||||
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
assert (
|
||||
isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor)
|
||||
), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
out = [out] if isinstance(out, Tensor) else out
|
||||
return [o.realize() for o in out]
|
||||
|
||||
# twice to run the JIT
|
||||
for _ in range(2): the_output = run(*args)
|
||||
for _ in range(2):
|
||||
the_output = run(*args)
|
||||
special_names = {}
|
||||
|
||||
# hack to put the inputs back
|
||||
for (j, i), idx in run.input_replace.items():
|
||||
realized_input = args[idx].lazydata.realized
|
||||
run.jit_cache[j].rawbufs[i] = realized_input
|
||||
special_names[id(realized_input)] = f'input{idx}'
|
||||
special_names[id(realized_input)] = f"input{idx}"
|
||||
|
||||
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
||||
for i, output in enumerate(the_output):
|
||||
special_names[id(output.lazydata.realized)] = f'output{i}'
|
||||
special_names[id(output.lazydata.realized)] = f"output{i}"
|
||||
return run, special_names
|
||||
|
||||
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
|
||||
|
||||
def export_model_clang(
|
||||
functions: Dict[str, str],
|
||||
statements: Dict[str, Tuple[str, int, int]],
|
||||
bufs: Dict[str, Tuple[str, int, int]],
|
||||
bufs_to_save: Dict[str, Tensor],
|
||||
input_names: List[str],
|
||||
output_names: List[str],
|
||||
) -> str:
|
||||
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
|
||||
|
||||
cprog = [CLANG_PROGRAM_HEADER]
|
||||
|
||||
for name, cl in bufs_to_save.items():
|
||||
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
|
||||
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
|
||||
weight = "".join(["\\x%02X" % x for x in bytes(cl._buf)])
|
||||
cprog.append(f'unsigned char {name}_data[] = "{weight}";')
|
||||
|
||||
inputs = ", ".join([f'float* {input}' for input in input_names])
|
||||
outputs = ", ".join([f'float* {output}' for output in output_names])
|
||||
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
|
||||
inputs = ", ".join([f"float* {input}" for input in input_names])
|
||||
outputs = ", ".join([f"float* {output}" for output in output_names])
|
||||
cprog += [
|
||||
f"float {name}[{len}];"
|
||||
if name not in bufs_to_save
|
||||
else f"float *{name} = (float *){name}_data;"
|
||||
for name, (len, dtype, _key) in bufs.items()
|
||||
if name not in ["input", "outputs"]
|
||||
]
|
||||
cprog += list(functions.values())
|
||||
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
|
||||
return '\n'.join(cprog)
|
||||
cprog += (
|
||||
[f"void net({inputs}, {outputs}) {{"]
|
||||
+ [
|
||||
f"{name}({', '.join(args)});"
|
||||
for (name, args, _global_size, _local_size) in statements
|
||||
]
|
||||
+ ["}"]
|
||||
)
|
||||
return "\n".join(cprog)
|
||||
|
||||
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
|
||||
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
|
||||
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
|
||||
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
|
||||
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
|
||||
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
|
||||
return f"""
|
||||
|
||||
def export_model_webgpu(
|
||||
functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names
|
||||
) -> Tuple[str, int, int]:
|
||||
kernel_code = "\n\n".join(
|
||||
[
|
||||
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||
for key, code in functions.items()
|
||||
]
|
||||
)
|
||||
kernel_names = ", ".join(
|
||||
[name for (name, _args, _global_size, _local_size) in statements]
|
||||
)
|
||||
kernel_calls = "\n ".join(
|
||||
[
|
||||
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||
]
|
||||
)
|
||||
_bufs = "\n ".join(
|
||||
[
|
||||
f"const {name} = "
|
||||
+ (
|
||||
f"createEmptyBuf(device, {size});"
|
||||
if _key not in weight_names
|
||||
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))"
|
||||
)
|
||||
+ ";"
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
]
|
||||
)
|
||||
gpu_write_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||
for i, input_name in enumerate(input_names)
|
||||
]
|
||||
)
|
||||
input_writers = "\n ".join(
|
||||
[
|
||||
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||
+ f"_{inp_name});"
|
||||
+ f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);"
|
||||
for i, inp_name in enumerate(input_names)
|
||||
]
|
||||
)
|
||||
gpu_read_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});"
|
||||
for i, output_name in enumerate(output_names)
|
||||
]
|
||||
)
|
||||
outbuf_copies = "\n ".join(
|
||||
[
|
||||
f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);"
|
||||
for i, output_name in enumerate(output_names)
|
||||
]
|
||||
)
|
||||
output_readers = "\n ".join(
|
||||
[
|
||||
f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();"
|
||||
for i in range(len(output_names))
|
||||
]
|
||||
)
|
||||
output_return = "[{}]".format(
|
||||
",".join([f"resultBuffer{i}" for i in range(len(output_names))])
|
||||
)
|
||||
return (
|
||||
f"""
|
||||
const getTensorMetadata = (safetensorBuffer) => {{
|
||||
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
|
||||
|
@ -134,10 +247,15 @@ const setupNet = async (device, safetensor) => {{
|
|||
return {output_return};
|
||||
}}
|
||||
}}
|
||||
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
||||
"""
|
||||
+ f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
||||
)
|
||||
|
||||
|
||||
def export_model(model, target: str, *inputs):
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||
assert (
|
||||
Device.DEFAULT in EXPORT_SUPPORTED_DEVICE
|
||||
), "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||
run, special_names = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
|
@ -146,34 +264,56 @@ def export_model(model, target:str, *inputs):
|
|||
output_names = [name for _, name in special_names.items() if "output" in name]
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
|
||||
prg = export_model_clang(
|
||||
functions, statements, bufs, bufs_to_save, input_names, output_names
|
||||
)
|
||||
elif target == "webgpu":
|
||||
prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
|
||||
prg = export_model_webgpu(
|
||||
functions,
|
||||
statements,
|
||||
bufs,
|
||||
bufs_to_save,
|
||||
weight_names,
|
||||
input_names,
|
||||
output_names,
|
||||
)
|
||||
else:
|
||||
prg = json.dumps({
|
||||
prg = json.dumps(
|
||||
{
|
||||
"backend": Device.DEFAULT,
|
||||
"inputs": [{
|
||||
"size": bufs[name][0],
|
||||
"dtype": bufs[name][1].name
|
||||
} for name in input_names],
|
||||
"outputs": [{
|
||||
"size": bufs[name][0],
|
||||
"dtype": bufs[name][1].name
|
||||
} for name in output_names],
|
||||
"inputs": [
|
||||
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||
for name in input_names
|
||||
],
|
||||
"outputs": [
|
||||
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||
for name in output_names
|
||||
],
|
||||
"functions": functions,
|
||||
"statements": [{
|
||||
"statements": [
|
||||
{
|
||||
"kernel": kernel,
|
||||
"args": args,
|
||||
"global_size": global_size,
|
||||
"local_size": local_size
|
||||
} for (kernel, args, global_size, local_size) in statements],
|
||||
"local_size": local_size,
|
||||
}
|
||||
for (kernel, args, global_size, local_size) in statements
|
||||
],
|
||||
"buffers": {
|
||||
name: {
|
||||
"size": size,
|
||||
"dtype": dtype.name,
|
||||
"id": weight_names[_key] if _key in weight_names else ""
|
||||
} for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
|
||||
"id": weight_names[_key] if _key in weight_names else "",
|
||||
}
|
||||
})
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
if name not in ["input", "outputs"]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
|
||||
return (
|
||||
prg,
|
||||
{input: bufs[input][0] for input in input_names},
|
||||
{output: bufs[output][0] for output in output_names},
|
||||
state,
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
|
||||
np.set_printoptions(linewidth=160)
|
||||
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
|
||||
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
|
||||
|
@ -11,18 +12,61 @@ from llvmlite import ir # type: ignore
|
|||
# https://github.com/corsix/amx/blob/main/Instructions.md
|
||||
# 12 lines for AMX support
|
||||
from functools import partialmethod
|
||||
|
||||
|
||||
class AMX:
|
||||
@staticmethod
|
||||
def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)
|
||||
def nop_op_imm5(op, imm5, builder):
|
||||
builder.asm(
|
||||
ir.FunctionType(ir.VoidType(), []),
|
||||
f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}",
|
||||
"",
|
||||
tuple(),
|
||||
True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)
|
||||
def op_gpr(op, builder, gpr):
|
||||
builder.asm(
|
||||
ir.FunctionType(ir.VoidType(), [ir.IntType(64)]),
|
||||
f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0",
|
||||
"r",
|
||||
(gpr,),
|
||||
True,
|
||||
)
|
||||
|
||||
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
|
||||
ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
|
||||
ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
|
||||
ldx, ldy, stx, sty = (
|
||||
partialmethod(op_gpr, 0),
|
||||
partialmethod(op_gpr, 1),
|
||||
partialmethod(op_gpr, 2),
|
||||
partialmethod(op_gpr, 3),
|
||||
)
|
||||
ldz, stz, ldzi, stzi = (
|
||||
partialmethod(op_gpr, 4),
|
||||
partialmethod(op_gpr, 5),
|
||||
partialmethod(op_gpr, 6),
|
||||
partialmethod(op_gpr, 7),
|
||||
)
|
||||
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
|
||||
fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
|
||||
mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
|
||||
vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
|
||||
fma64, fms64, fma32, fms32 = (
|
||||
partialmethod(op_gpr, 10),
|
||||
partialmethod(op_gpr, 11),
|
||||
partialmethod(op_gpr, 12),
|
||||
partialmethod(op_gpr, 13),
|
||||
)
|
||||
mac16, fma16, fms16 = (
|
||||
partialmethod(op_gpr, 14),
|
||||
partialmethod(op_gpr, 15),
|
||||
partialmethod(op_gpr, 16),
|
||||
)
|
||||
vecint, vecfp, matint, matfp, genlut = (
|
||||
partialmethod(op_gpr, 18),
|
||||
partialmethod(op_gpr, 19),
|
||||
partialmethod(op_gpr, 20),
|
||||
partialmethod(op_gpr, 21),
|
||||
partialmethod(op_gpr, 22),
|
||||
)
|
||||
|
||||
|
||||
N = 4096
|
||||
|
@ -54,7 +98,11 @@ c = LLVMBuffer.fromCPU(np.zeros(256))
|
|||
bufs = [c, a, b]
|
||||
|
||||
module = ir.Module(name=__file__)
|
||||
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
|
||||
func = ir.Function(
|
||||
module,
|
||||
ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()] * 3),
|
||||
name="exec",
|
||||
)
|
||||
|
||||
# load all
|
||||
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
|
||||
|
@ -69,7 +117,19 @@ y.add_incoming(int_const(0), entry._block)
|
|||
yp = loop_1_exit.add(y, int_const(32 * 2))
|
||||
y.add_incoming(yp, loop_1_exit._block)
|
||||
|
||||
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
|
||||
prefetch_function = ir.Function(
|
||||
module,
|
||||
ir.FunctionType(
|
||||
ir.VoidType(),
|
||||
[
|
||||
ir.PointerType(ir.FloatType()),
|
||||
ir.IntType(32),
|
||||
ir.IntType(32),
|
||||
ir.IntType(32),
|
||||
],
|
||||
),
|
||||
name="llvm.prefetch",
|
||||
)
|
||||
|
||||
xptr = y
|
||||
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||
|
@ -79,7 +139,12 @@ addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
|||
|
||||
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr))
|
||||
xptr = loop_1_exit.add(xptr, int_const(32))
|
||||
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
|
||||
AMX.ldy(
|
||||
loop_1_exit,
|
||||
loop_1_exit.add(
|
||||
int_const(1 << 62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||
),
|
||||
)
|
||||
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10))
|
||||
|
@ -93,7 +158,9 @@ AMX.clr(exit)
|
|||
|
||||
entry.branch(loop_1._block)
|
||||
loop_1.branch(loop_1_exit._block)
|
||||
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
|
||||
loop_1_exit.cbranch(
|
||||
loop_1_exit.icmp_unsigned("==", yp, int_const(N * N)), exit._block, loop_1._block
|
||||
)
|
||||
exit.ret(int_const(0))
|
||||
|
||||
cfunc = LLVM().exec(module, bufs, N**2)
|
||||
|
@ -185,4 +252,3 @@ np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4)
|
|||
print(cn.astype(np.int64))
|
||||
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
|
||||
"""
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import numpy as np
|
||||
|
||||
os.environ["CUDA"] = "1"
|
||||
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
|
||||
|
||||
|
@ -21,7 +22,10 @@ c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
|
|||
FLOPS = N * N * N * 2
|
||||
BW = N * N * 3 * 4
|
||||
|
||||
prog = CUDAProgram("wmma_example", compile_cuda(f"""
|
||||
prog = CUDAProgram(
|
||||
"wmma_example",
|
||||
compile_cuda(
|
||||
f"""
|
||||
#include <mma.h>
|
||||
using namespace nvcuda;
|
||||
|
||||
|
@ -88,10 +92,23 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
"""))
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1]
|
||||
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
tm = min(
|
||||
[
|
||||
prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)
|
||||
for _ in range(20)
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
|
||||
np.testing.assert_allclose(
|
||||
na.T.astype(np.float32) @ nb.T.astype(np.float32),
|
||||
c.toCPU().reshape((N, N)).T,
|
||||
atol=1e-2,
|
||||
)
|
||||
|
|
|
@ -15,6 +15,7 @@ from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, d
|
|||
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
|
||||
from tinygrad.ops import LoadOps, ReduceOps
|
||||
|
||||
|
||||
def single_kernel():
|
||||
# single kernel
|
||||
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
|
||||
|
@ -22,11 +23,17 @@ def single_kernel():
|
|||
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
|
||||
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
|
||||
|
||||
old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read())
|
||||
old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)]
|
||||
old = CLProgram(
|
||||
"r_32_16_16_64_4_4_4",
|
||||
open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read(),
|
||||
)
|
||||
old_tms = [
|
||||
old([1, 1, 32], [16, 16, 1], out, x, w, wait=True) * 1e6 for _ in range(5)
|
||||
]
|
||||
print(old_tms, 67.107 / min(old_tms) * 1e3)
|
||||
exit(0)
|
||||
|
||||
|
||||
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
|
||||
if __name__ == "__main__":
|
||||
# single_kernel()
|
||||
|
@ -43,7 +50,11 @@ if __name__ == "__main__":
|
|||
out = x.sequential([c1, c2, c3, c4, c5])
|
||||
schedule = out.lazydata.schedule()
|
||||
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.get_lazyops()))
|
||||
schedule, schedule_input = partition(
|
||||
schedule,
|
||||
lambda x: x.ast.op not in LoadOps
|
||||
and any(y.op in ReduceOps for y in x.ast.get_lazyops()),
|
||||
)
|
||||
run_schedule(schedule_input)
|
||||
run_schedule(schedule[: getenv("CONV")])
|
||||
print("*** init done ***")
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
|
||||
# os.environ['OMP_NUM_THREADS'] = '1'
|
||||
import time
|
||||
import numpy as np
|
||||
|
|
|
@ -78,5 +78,3 @@ if __name__ == "__main__":
|
|||
new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True))
|
||||
|
||||
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
|
||||
|
||||
|
||||
|
|
|
@ -30,12 +30,21 @@ a = hipallocator.alloc(N*N*4)
|
|||
b = hipallocator.alloc(N * N * 2)
|
||||
c = hipallocator.alloc(N * N * 2)
|
||||
na = np.empty(N * N, np.float32)
|
||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
||||
nb = (
|
||||
np.random.default_rng()
|
||||
.standard_normal(size=(N, N), dtype=np.float32)
|
||||
.astype(np.float16)
|
||||
)
|
||||
nc = (
|
||||
np.random.default_rng()
|
||||
.standard_normal(size=(N, N), dtype=np.float32)
|
||||
.astype(np.float16)
|
||||
)
|
||||
hipallocator.copyin(b, bytearray(nb))
|
||||
hipallocator.copyin(c, bytearray(nc))
|
||||
|
||||
lib = compile_hip(f"""
|
||||
lib = compile_hip(
|
||||
f"""
|
||||
#define F32
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
|
@ -92,10 +101,12 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a,
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
}}""")
|
||||
}}"""
|
||||
)
|
||||
|
||||
prog = HIPProgram(device, "test", lib)
|
||||
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
|
@ -103,11 +114,28 @@ def timeit(fxn):
|
|||
# print(f"{ret*1e6:.2f} us")
|
||||
return et
|
||||
|
||||
|
||||
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
|
||||
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
|
||||
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)])
|
||||
print(
|
||||
"global/local size",
|
||||
global_size,
|
||||
local_size,
|
||||
f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}",
|
||||
)
|
||||
tm = min(
|
||||
[
|
||||
timeit(
|
||||
lambda: prog(
|
||||
a, b, c, global_size=global_size, local_size=local_size, wait=True
|
||||
)
|
||||
)
|
||||
for _ in range(1000)
|
||||
]
|
||||
)
|
||||
hipallocator.copyout(flat_mv(na.data), a)
|
||||
na = na.reshape(N, N)
|
||||
comp = nb.astype(np.float32) @ nc.astype(np.float32)
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)
|
||||
|
|
|
@ -14,7 +14,12 @@ A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
|
|||
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
|
||||
|
||||
OPS = DEVICES * BS * N * N * N * 2
|
||||
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
|
||||
|
||||
|
||||
def matmul(A, B):
|
||||
return jnp.matmul(A, B, preferred_element_type=jnp.float32)
|
||||
|
||||
|
||||
pmatmul = jax.pmap(matmul)
|
||||
|
||||
MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
||||
|
@ -23,5 +28,6 @@ for i in range(10):
|
|||
C = pmatmul(A, B).block_until_ready()
|
||||
et = time.perf_counter() - st
|
||||
tflops = (OPS * 1e-12) / et
|
||||
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
|
||||
|
||||
print(
|
||||
f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}"
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
|
||||
# os.environ["METAL"] = "1"
|
||||
import numpy as np
|
||||
|
||||
|
@ -18,14 +19,16 @@ nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float
|
|||
|
||||
try:
|
||||
import time, torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
|
||||
except RuntimeError:
|
||||
|
@ -34,16 +37,23 @@ except RuntimeError:
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad import Device
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return b.conv2d(c, padding=PADDING).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
Device[a.device].synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(5)])
|
||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
|
||||
os.environ["METAL"] = "1"
|
||||
import time
|
||||
import numpy as np
|
||||
|
@ -10,15 +11,22 @@ LID = 2
|
|||
|
||||
a = RawMetalBuffer(N * N, dtypes.float32)
|
||||
|
||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nb = np.random.default_rng().standard_normal(
|
||||
size=(N, N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(
|
||||
size=(N, N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
b = RawMetalBuffer.fromCPU(nb)
|
||||
c = RawMetalBuffer.fromCPU(nc)
|
||||
|
||||
FLOPS = N * N * N * 2
|
||||
BW = N * N * 3 * 4
|
||||
|
||||
prog = MetalProgram("test", compile_metal(f"""
|
||||
prog = MetalProgram(
|
||||
"test",
|
||||
compile_metal(
|
||||
f"""
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
|
||||
using namespace metal;
|
||||
|
@ -80,46 +88,83 @@ kernel void test(device float *a, device const float *data1, device const float
|
|||
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
|
||||
}}"""))
|
||||
}}"""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
|
||||
|
||||
|
||||
tm = min(
|
||||
[
|
||||
timeit(
|
||||
lambda: prog(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
global_size=[N // (8 * 4), N // (8 * 4 * LID), 1],
|
||||
local_size=[32, LID, 1],
|
||||
wait=True,
|
||||
)
|
||||
)
|
||||
for _ in range(20)
|
||||
]
|
||||
)
|
||||
na = a.toCPU().reshape(N, N)
|
||||
comp = nb @ nc
|
||||
if N <= 32:
|
||||
print(na)
|
||||
print(comp)
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
np.testing.assert_allclose(na, comp, atol=1e-3)
|
||||
|
||||
import torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b @ c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch"
|
||||
)
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.runtime.ops_metal import METAL
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return (b @ c).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad"
|
||||
)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
|
||||
# os.environ["METAL"] = "1"
|
||||
import numpy as np
|
||||
import time, torch, torch.mps
|
||||
|
@ -10,6 +11,7 @@ from tinygrad import Device
|
|||
from tinygrad.helpers import colored, getenv, CI
|
||||
|
||||
import os
|
||||
|
||||
os.environ["METAL"] = "1"
|
||||
import time
|
||||
import numpy as np
|
||||
|
@ -20,27 +22,38 @@ N = 16384
|
|||
M = 4096
|
||||
FLOPS = N * M * 2
|
||||
|
||||
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nb = np.random.default_rng().standard_normal(
|
||||
size=(N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(
|
||||
size=(N, M), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
|
||||
import torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b @ c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch"
|
||||
)
|
||||
torch_a = (b @ c).cpu()
|
||||
|
||||
WORKSIZE_ROW = 16
|
||||
WORKSIZE_COL = 1
|
||||
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
|
||||
GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1]
|
||||
prog = compile_metal(f"""
|
||||
prog = compile_metal(
|
||||
f"""
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
|
||||
|
@ -86,41 +99,59 @@ kernel void test(device float* data0, const device float* data1, const device fl
|
|||
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
|
||||
}}
|
||||
}}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
prog = MetalProgram("test", prog)
|
||||
# print(prog_string)
|
||||
na = np.zeros(M, dtype=np.float32)
|
||||
b = RawMetalBuffer.fromCPU(nb)
|
||||
c = RawMetalBuffer.fromCPU(nc)
|
||||
|
||||
|
||||
def metalrun():
|
||||
a = RawMetalBuffer.fromCPU(na)
|
||||
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
||||
return a
|
||||
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([timeit(metalrun) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal"
|
||||
)
|
||||
metal_a = metalrun().toCPU().reshape(M)
|
||||
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.runtime.ops_metal import METAL
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return (b @ c).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad"
|
||||
)
|
||||
tiny_a = tiny_jit(b, c).numpy()
|
||||
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
|
|
@ -2,14 +2,28 @@ import numpy as np
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
||||
N = getenv("N", 4096)
|
||||
CNT = getenv("CNT", 10)
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
a, b = (
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
)
|
||||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
|
||||
a, b = (
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
)
|
||||
c = (
|
||||
(a.reshape(N, 1, N) * b.permute(1, 0).reshape(1, N, N))
|
||||
.float()
|
||||
.sum(axis=2)
|
||||
.realize()
|
||||
if getenv("ACCUM_FP32")
|
||||
else (a @ b).realize()
|
||||
)
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
nc = c.numpy()
|
||||
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
if gpus:
|
||||
try:
|
||||
# Currently, memory growth needs to be the same across GPUs
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
# Memory growth must be set before GPUs have been initialized
|
||||
|
@ -26,8 +26,12 @@ for dtype in [tf.float16, tf.float32]:
|
|||
def tf_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tf.matmul(b, c)
|
||||
tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done.
|
||||
tf.debugging.check_numerics(
|
||||
a, "Nan or Inf in result"
|
||||
) # Ensures that the calculation is done.
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([tf_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||
)
|
||||
|
|
|
@ -13,5 +13,8 @@ for dtype in [torch.float16, torch.float32]:
|
|||
a = b @ c
|
||||
torch.cuda.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||
)
|
||||
|
|
|
@ -5,6 +5,7 @@ M, N, K = 1024, 1024, 1024
|
|||
try:
|
||||
import tvm
|
||||
from tvm import te
|
||||
|
||||
# print(tvm.target.Target.list_kinds())
|
||||
|
||||
# c, opencl
|
||||
|
@ -39,9 +40,13 @@ C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
|||
sched = C.lazydata.schedule()
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
|
||||
|
||||
lin = Linearizer(
|
||||
sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False)
|
||||
)
|
||||
# lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
from tinygrad.runtime.ops_clang import renderer
|
||||
|
||||
src = renderer("mmult", lin.uops)
|
||||
print(src)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
def mask_like(like, mask_inx, mask_value=1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
mask[mask_inx] = mask_value
|
||||
return mask.reshape(like.shape)
|
||||
|
||||
|
||||
def jacobian(func, input):
|
||||
output = func(input)
|
||||
|
||||
|
@ -19,13 +21,14 @@ def jacobian(func, input):
|
|||
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
|
||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
||||
J[o, i] = grad
|
||||
return J
|
||||
|
||||
|
||||
def numerical_jacobian(func, input, eps=1e-3):
|
||||
output = func(input)
|
||||
|
||||
|
@ -36,14 +39,19 @@ def numerical_jacobian(func, input, eps = 1e-3):
|
|||
for i in range(ji):
|
||||
eps_perturb = mask_like(input.numpy(), i, mask_value=eps)
|
||||
|
||||
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||
output_perturb_add = (
|
||||
func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||
)
|
||||
output_perturb_sub = (
|
||||
func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||
)
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps)
|
||||
|
||||
NJ[:, i] = grad_approx
|
||||
return NJ
|
||||
|
||||
|
||||
def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3):
|
||||
NJ = numerical_jacobian(func, input, eps)
|
||||
J = jacobian(func, input)
|
||||
|
|
|
@ -2,6 +2,7 @@ import multiprocessing, subprocess
|
|||
import cloudpickle
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _early_exec_process(qin, qout):
|
||||
while True:
|
||||
path, inp = qin.get()
|
||||
|
@ -10,41 +11,62 @@ def _early_exec_process(qin, qout):
|
|||
except Exception as e:
|
||||
qout.put(e)
|
||||
|
||||
|
||||
def enable_early_exec():
|
||||
qin: multiprocessing.Queue = multiprocessing.Queue()
|
||||
qout: multiprocessing.Queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
def early_exec(x):
|
||||
qin.put(x)
|
||||
ret = qout.get()
|
||||
if isinstance(ret, Exception): raise ret
|
||||
else: return ret
|
||||
if isinstance(ret, Exception):
|
||||
raise ret
|
||||
else:
|
||||
return ret
|
||||
|
||||
return early_exec
|
||||
|
||||
|
||||
def proc(itermaker, q) -> None:
|
||||
try:
|
||||
for x in itermaker(): q.put(x)
|
||||
for x in itermaker():
|
||||
q.put(x)
|
||||
except Exception as e:
|
||||
q.put(e)
|
||||
finally:
|
||||
q.put(None)
|
||||
q.close()
|
||||
|
||||
|
||||
class _CloudpickleFunctionWrapper:
|
||||
def __init__(self, fn): self.fn = fn
|
||||
def __getstate__(self): return cloudpickle.dumps(self.fn)
|
||||
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
|
||||
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, pfn):
|
||||
self.fn = cloudpickle.loads(pfn)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def cross_process(itermaker, maxsize=16):
|
||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
||||
p = multiprocessing.Process(
|
||||
target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q)
|
||||
)
|
||||
p.start()
|
||||
while True:
|
||||
ret = q.get()
|
||||
if isinstance(ret, Exception): raise ret
|
||||
elif ret is None: break
|
||||
else: yield ret
|
||||
if isinstance(ret, Exception):
|
||||
raise ret
|
||||
elif ret is None:
|
||||
break
|
||||
else:
|
||||
yield ret
|
||||
|
|
|
@ -6,6 +6,7 @@ from tinygrad.lazy import LazyBuffer
|
|||
from tinygrad.runtime.ops_gpu import CLBuffer
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
|
||||
|
||||
def print_objects():
|
||||
# gc.collect()
|
||||
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
|
||||
|
@ -15,8 +16,12 @@ def print_objects():
|
|||
realized_buffers = [x.realized for x in lazybuffers if x.realized]
|
||||
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
|
||||
|
||||
print(f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB")
|
||||
print(f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers")
|
||||
print(
|
||||
f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB"
|
||||
)
|
||||
print(
|
||||
f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers"
|
||||
)
|
||||
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
|
||||
|
||||
cnt = 0
|
||||
|
@ -33,11 +38,14 @@ def print_objects():
|
|||
cnt += 1
|
||||
|
||||
for x in gpubuffers_orphaned:
|
||||
if getattr(x, '_buf', None): del x._buf
|
||||
if getattr(x, '_image', None): del x._image
|
||||
if getattr(x, "_buf", None):
|
||||
del x._buf
|
||||
if getattr(x, "_image", None):
|
||||
del x._image
|
||||
|
||||
return len(gpubuffers_orphaned)
|
||||
|
||||
|
||||
"""
|
||||
import gc
|
||||
|
||||
|
|
|
@ -7,39 +7,44 @@ from google.protobuf import descriptor as _descriptor
|
|||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
|
||||
)
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'H\003'
|
||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
|
||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
|
||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
|
||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
|
||||
_globals['_TRAINERSPEC']._serialized_start=45
|
||||
_globals['_TRAINERSPEC']._serialized_end=1581
|
||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
|
||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
|
||||
_globals['_NORMALIZERSPEC']._serialized_start=1584
|
||||
_globals['_NORMALIZERSPEC']._serialized_end=1793
|
||||
_globals['_SELFTESTDATA']._serialized_start=1795
|
||||
_globals['_SELFTESTDATA']._serialized_end=1916
|
||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
|
||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
|
||||
_globals['_MODELPROTO']._serialized_start=1919
|
||||
_globals['_MODELPROTO']._serialized_end=2429
|
||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
|
||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
|
||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
|
||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
|
||||
_globals["DESCRIPTOR"]._options = None
|
||||
_globals["DESCRIPTOR"]._serialized_options = b"H\003"
|
||||
_globals["_TRAINERSPEC"].fields_by_name["mining_sentence_size"]._options = None
|
||||
_globals["_TRAINERSPEC"].fields_by_name[
|
||||
"mining_sentence_size"
|
||||
]._serialized_options = b"\030\001"
|
||||
_globals["_TRAINERSPEC"].fields_by_name["training_sentence_size"]._options = None
|
||||
_globals["_TRAINERSPEC"].fields_by_name[
|
||||
"training_sentence_size"
|
||||
]._serialized_options = b"\030\001"
|
||||
_globals["_TRAINERSPEC"]._serialized_start = 45
|
||||
_globals["_TRAINERSPEC"]._serialized_end = 1581
|
||||
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
|
||||
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
|
||||
_globals["_NORMALIZERSPEC"]._serialized_start = 1584
|
||||
_globals["_NORMALIZERSPEC"]._serialized_end = 1793
|
||||
_globals["_SELFTESTDATA"]._serialized_start = 1795
|
||||
_globals["_SELFTESTDATA"]._serialized_end = 1916
|
||||
_globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
|
||||
_globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
|
||||
_globals["_MODELPROTO"]._serialized_start = 1919
|
||||
_globals["_MODELPROTO"]._serialized_end = 2429
|
||||
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
|
||||
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
|
||||
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
|
||||
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -3,17 +3,22 @@ from typing import List
|
|||
from tinygrad.nn.optim import Optimizer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
class LR_Scheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
|
||||
self.epoch_counter = Tensor(
|
||||
[0], requires_grad=False, device=self.optimizer.device
|
||||
)
|
||||
|
||||
def get_lr(self): pass
|
||||
def get_lr(self):
|
||||
pass
|
||||
|
||||
def step(self) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
self.optimizer.lr.assign(self.get_lr()).realize()
|
||||
|
||||
|
||||
class MultiStepLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
||||
super().__init__(optimizer)
|
||||
|
@ -25,18 +30,38 @@ class MultiStepLR(LR_Scheduler):
|
|||
return self.optimizer.lr
|
||||
return self.optimizer.lr * self.gamma
|
||||
|
||||
|
||||
class ReduceLROnPlateau(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
mode="min",
|
||||
factor=0.1,
|
||||
patience=10,
|
||||
threshold=1e-4,
|
||||
threshold_mode="rel",
|
||||
):
|
||||
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
|
||||
super().__init__(optimizer)
|
||||
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode
|
||||
self.best = float('inf') if mode == "min" else float('-inf')
|
||||
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = (
|
||||
mode,
|
||||
factor,
|
||||
patience,
|
||||
threshold,
|
||||
threshold_mode,
|
||||
)
|
||||
self.best = float("inf") if mode == "min" else float("-inf")
|
||||
self.bad_epoch = 0
|
||||
|
||||
if mode == "min": self.threshold *= -1
|
||||
if mode == "min":
|
||||
self.threshold *= -1
|
||||
|
||||
def is_better(self, current: float) -> bool:
|
||||
dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold
|
||||
dynamic_threshold = (
|
||||
self.best * (1 + self.threshold)
|
||||
if self.threshold_mode == "rel"
|
||||
else self.best + self.threshold
|
||||
)
|
||||
if self.mode == "min":
|
||||
return current < dynamic_threshold
|
||||
return current > dynamic_threshold
|
||||
|
@ -53,6 +78,7 @@ class ReduceLROnPlateau(LR_Scheduler):
|
|||
self.optimizer.lr *= self.factor
|
||||
self.bad_epoch = 0
|
||||
|
||||
|
||||
class CosineAnnealingLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
|
||||
super().__init__(optimizer)
|
||||
|
@ -61,26 +87,54 @@ class CosineAnnealingLR(LR_Scheduler):
|
|||
self.eta_max = optimizer.lr.numpy()[0]
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
|
||||
return Tensor(
|
||||
[
|
||||
self.eta_min
|
||||
+ 0.5
|
||||
* (self.eta_max - self.eta_min)
|
||||
* (1 + math.cos((self.epoch_counter.numpy()[0] / self.T_max) * math.pi))
|
||||
],
|
||||
device=self.optimizer.device,
|
||||
)
|
||||
|
||||
|
||||
class OneCycleLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
|
||||
anneal_strategy: str = 'linear', cycle_momentum: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
max_lr: float,
|
||||
div_factor: float,
|
||||
final_div_factor: float,
|
||||
total_steps: int,
|
||||
pct_start: float,
|
||||
anneal_strategy: str = "linear",
|
||||
cycle_momentum: bool = False,
|
||||
):
|
||||
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
|
||||
self.max_lr = Tensor([max_lr]).contiguous()
|
||||
self.min_lr = self.initial_lr / final_div_factor
|
||||
super().__init__(optimizer)
|
||||
self.total_steps = total_steps
|
||||
self.pct_start = pct_start
|
||||
assert anneal_strategy == 'linear', 'only linear annealing supported'
|
||||
assert not cycle_momentum, 'cycle momentum not supported'
|
||||
assert anneal_strategy == "linear", "only linear annealing supported"
|
||||
assert not cycle_momentum, "cycle momentum not supported"
|
||||
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
|
||||
|
||||
@staticmethod
|
||||
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: return ((end - start) * pct + start)
|
||||
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor:
|
||||
return (end - start) * pct + start
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return (self.epoch_counter < self.total_steps * self.pct_start).where(
|
||||
self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
|
||||
self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
|
||||
self._annealing_linear(
|
||||
self.initial_lr,
|
||||
self.max_lr,
|
||||
self.epoch_counter / (self.total_steps * self.pct_start),
|
||||
),
|
||||
self._annealing_linear(
|
||||
self.max_lr,
|
||||
self.min_lr,
|
||||
(self.epoch_counter - (self.total_steps * self.pct_start))
|
||||
/ (self.total_steps * (1 - self.pct_start)),
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,8 +5,29 @@ from pathlib import Path
|
|||
|
||||
|
||||
class BertForQuestionAnswering:
|
||||
def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
|
||||
self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=24,
|
||||
type_vocab_size=2,
|
||||
vocab_size=30522,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_dropout_prob=0.1,
|
||||
):
|
||||
self.bert = Bert(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
max_position_embeddings,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.qa_outputs = Linear(hidden_size, 2)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
|
@ -16,15 +37,20 @@ class BertForQuestionAnswering:
|
|||
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
||||
|
||||
import torch
|
||||
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if "dropout" in k: continue # skip dropout
|
||||
if "pooler" in k: continue # skip pooler
|
||||
if "dropout" in k:
|
||||
continue # skip dropout
|
||||
if "pooler" in k:
|
||||
continue # skip pooler
|
||||
get_child(self, k).assign(v.numpy()).realize()
|
||||
|
||||
def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
|
||||
def __call__(
|
||||
self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor
|
||||
):
|
||||
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.chunk(2, dim=-1)
|
||||
|
@ -33,10 +59,35 @@ class BertForQuestionAnswering:
|
|||
|
||||
return Tensor.stack([start_logits, end_logits])
|
||||
|
||||
|
||||
class Bert:
|
||||
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
|
||||
self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
max_position_embeddings,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.embeddings = BertEmbeddings(
|
||||
hidden_size,
|
||||
max_position_embeddings,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.encoder = BertEncoder(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids):
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
@ -47,8 +98,16 @@ class Bert:
|
|||
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
class BertEmbeddings:
|
||||
def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
max_position_embeddings,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.word_embeddings = Embedding(vocab_size, hidden_size)
|
||||
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
|
||||
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
|
||||
|
@ -59,7 +118,11 @@ class BertEmbeddings:
|
|||
input_shape = input_ids.shape
|
||||
seq_length = input_shape[1]
|
||||
|
||||
position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
|
||||
position_ids = (
|
||||
Tensor.arange(seq_length, requires_grad=False)
|
||||
.unsqueeze(0)
|
||||
.expand(*input_shape)
|
||||
)
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
@ -69,18 +132,49 @@ class BertEmbeddings:
|
|||
embeddings = embeddings.dropout(self.dropout)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertEncoder:
|
||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.layer = [
|
||||
BertLayer(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLayer:
|
||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.attention = BertAttention(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
|
||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
|
||||
|
||||
|
@ -90,6 +184,7 @@ class BertLayer:
|
|||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class BertOutput:
|
||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
|
||||
self.dense = Linear(intermediate_size, hidden_size)
|
||||
|
@ -102,10 +197,21 @@ class BertOutput:
|
|||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# approximation of the error function
|
||||
def erf(x):
|
||||
t = (1 + 0.3275911 * x.abs()).reciprocal()
|
||||
return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
|
||||
return x.sign() * (
|
||||
1
|
||||
- (
|
||||
(((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736)
|
||||
* t
|
||||
+ 0.254829592
|
||||
)
|
||||
* t
|
||||
* (-(x.square())).exp()
|
||||
)
|
||||
|
||||
|
||||
class BertIntermediate:
|
||||
def __init__(self, hidden_size, intermediate_size):
|
||||
|
@ -116,9 +222,18 @@ class BertIntermediate:
|
|||
# tinygrad gelu is openai gelu but we need the original bert gelu
|
||||
return x * 0.5 * (1.0 + erf(x / 1.41421))
|
||||
|
||||
|
||||
class BertAttention:
|
||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.self = BertSelfAttention(
|
||||
hidden_size, num_attention_heads, attention_probs_dropout_prob
|
||||
)
|
||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
|
@ -126,6 +241,7 @@ class BertAttention:
|
|||
attention_output = self.output(self_output, hidden_states)
|
||||
return attention_output
|
||||
|
||||
|
||||
class BertSelfAttention:
|
||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
@ -147,17 +263,24 @@ class BertSelfAttention:
|
|||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)
|
||||
context_layer = Tensor.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, attention_mask, self.dropout
|
||||
)
|
||||
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
|
||||
context_layer = context_layer.reshape(
|
||||
context_layer.shape[0], context_layer.shape[1], self.all_head_size
|
||||
)
|
||||
|
||||
return context_layer
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
|
||||
x = x.reshape(
|
||||
x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size
|
||||
)
|
||||
return x.transpose(1, 2)
|
||||
|
||||
|
||||
class BertSelfOutput:
|
||||
def __init__(self, hidden_size, hidden_dropout_prob):
|
||||
self.dense = Linear(hidden_size, hidden_size)
|
||||
|
|
|
@ -2,6 +2,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||
from tinygrad.helpers import fetch, get_child
|
||||
|
||||
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
|
||||
|
@ -11,18 +12,43 @@ class Block:
|
|||
self.gamma = Tensor.ones(dim)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
return x + x.sequential([
|
||||
self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
|
||||
self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
|
||||
])
|
||||
return x + x.sequential(
|
||||
[
|
||||
self.dwconv,
|
||||
lambda x: x.permute(0, 2, 3, 1),
|
||||
self.norm,
|
||||
self.pwconv1,
|
||||
Tensor.gelu,
|
||||
self.pwconv2,
|
||||
lambda x: (self.gamma * x).permute(0, 3, 1, 2),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConvNeXt:
|
||||
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 768],
|
||||
):
|
||||
self.downsample_layers = [
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
|
||||
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
[
|
||||
Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
||||
LayerNorm2d(dims[0], eps=1e-6),
|
||||
],
|
||||
*[
|
||||
[
|
||||
LayerNorm2d(dims[i], eps=1e-6),
|
||||
Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
||||
]
|
||||
for i in range(len(dims) - 1)
|
||||
],
|
||||
]
|
||||
self.stages = [
|
||||
[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))
|
||||
]
|
||||
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
|
||||
self.norm = LayerNorm(dims[-1])
|
||||
self.head = Linear(dims[-1], num_classes)
|
||||
|
||||
|
@ -31,6 +57,7 @@ class ConvNeXt:
|
|||
x = x.sequential(downsample).sequential(stage)
|
||||
return x.mean([-2, -1]).sequential([self.norm, self.head])
|
||||
|
||||
|
||||
# *** model definition is done ***
|
||||
|
||||
versions = {
|
||||
|
@ -38,24 +65,32 @@ versions = {
|
|||
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
|
||||
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
|
||||
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
|
||||
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
|
||||
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]},
|
||||
}
|
||||
|
||||
|
||||
def get_model(version, load_weights=False):
|
||||
model = ConvNeXt(**versions[version])
|
||||
if load_weights:
|
||||
from tinygrad.nn.state import torch_load
|
||||
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
|
||||
|
||||
weights = torch_load(
|
||||
fetch(
|
||||
f"https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth"
|
||||
)
|
||||
)["model"]
|
||||
for k, v in weights.items():
|
||||
mv = get_child(model, k)
|
||||
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = get_model("tiny", True)
|
||||
|
||||
# load image
|
||||
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
|
||||
|
||||
img = Tensor(preprocess(chicken_img))
|
||||
|
||||
Tensor.training = False
|
||||
|
|
|
@ -4,8 +4,19 @@ from tinygrad.nn import BatchNorm2d
|
|||
from tinygrad.helpers import get_child, fetch
|
||||
from tinygrad.nn.state import torch_load
|
||||
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
has_se,
|
||||
track_running_stats=True,
|
||||
):
|
||||
oup = expand_ratio * input_filters
|
||||
if expand_ratio != 1:
|
||||
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
|
||||
|
@ -37,12 +48,19 @@ class MBConvBlock:
|
|||
x = inputs
|
||||
if self._expand_conv:
|
||||
x = self._bn0(x.conv2d(self._expand_conv)).swish()
|
||||
x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
|
||||
x = x.conv2d(
|
||||
self._depthwise_conv,
|
||||
padding=self.pad,
|
||||
stride=self.strides,
|
||||
groups=self._depthwise_conv.shape[0],
|
||||
)
|
||||
x = self._bn1(x).swish()
|
||||
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
||||
x_squeezed = x_squeezed.conv2d(
|
||||
self._se_reduce, self._se_reduce_bias
|
||||
).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
|
@ -51,8 +69,17 @@ class MBConvBlock:
|
|||
x = x.add(inputs)
|
||||
return x
|
||||
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
|
||||
def __init__(
|
||||
self,
|
||||
number=0,
|
||||
classes=1000,
|
||||
has_se=True,
|
||||
track_running_stats=True,
|
||||
input_channels=3,
|
||||
has_fc_output=True,
|
||||
):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
|
@ -106,10 +133,31 @@ class EfficientNet:
|
|||
]
|
||||
|
||||
self._blocks = []
|
||||
for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
|
||||
input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
|
||||
for (
|
||||
num_repeats,
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
) in blocks_args:
|
||||
input_filters, output_filters = round_filters(input_filters), round_filters(
|
||||
output_filters
|
||||
)
|
||||
for n in range(round_repeats(num_repeats)):
|
||||
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
|
||||
self._blocks.append(
|
||||
MBConvBlock(
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
has_se=has_se,
|
||||
track_running_stats=track_running_stats,
|
||||
)
|
||||
)
|
||||
input_filters = output_filters
|
||||
strides = (1, 1)
|
||||
|
||||
|
@ -140,25 +188,34 @@ class EfficientNet:
|
|||
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
|
||||
}
|
||||
|
||||
b0 = torch_load(fetch(model_urls[self.number]))
|
||||
for k, v in b0.items():
|
||||
if k.endswith("num_batches_tracked"): continue
|
||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
||||
if k.endswith("num_batches_tracked"):
|
||||
continue
|
||||
for cat in [
|
||||
"_conv_head",
|
||||
"_conv_stem",
|
||||
"_depthwise_conv",
|
||||
"_expand_conv",
|
||||
"_fc",
|
||||
"_project_conv",
|
||||
"_se_reduce",
|
||||
"_se_expand",
|
||||
]:
|
||||
if cat in k:
|
||||
k = k.replace('.bias', '_bias')
|
||||
k = k.replace('.weight', '')
|
||||
k = k.replace(".bias", "_bias")
|
||||
k = k.replace(".weight", "")
|
||||
|
||||
# print(k, v.shape)
|
||||
mv = get_child(self, k)
|
||||
vnp = v # .astype(np.float32)
|
||||
vnp = vnp if k != '_fc' else vnp.cpu().T
|
||||
vnp = vnp if k != "_fc" else vnp.cpu().T
|
||||
# vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||
|
||||
if mv.shape == vnp.shape:
|
||||
mv.assign(vnp.to(mv.device))
|
||||
else:
|
||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||
|
||||
|
|
|
@ -2,11 +2,15 @@ from typing import Tuple, Union, Optional, Dict
|
|||
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
|
||||
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(
|
||||
1, end, 1, dim // 2, 2
|
||||
)
|
||||
|
||||
|
||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||
def complex_mult(A, c, d):
|
||||
|
@ -15,20 +19,33 @@ def complex_mult(A, c, d):
|
|||
co = a * d + b * c
|
||||
return ro.cat(co, dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
|
||||
assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||
assert (
|
||||
freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1]
|
||||
), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
||||
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
||||
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
|
||||
c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
|
||||
c, d = (
|
||||
freqs_cis[:, : xq.shape[1], :, :, 0:1],
|
||||
freqs_cis[:, : xq.shape[1], :, :, 1:2],
|
||||
)
|
||||
xq_out = complex_mult(xq, c, d)
|
||||
xk_out = complex_mult(xk, c, d)
|
||||
return xq_out.flatten(3), xk_out.flatten(3)
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1: return x
|
||||
return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x.reshape(bs, seqlen, n_kv_heads, 1, head_dim)
|
||||
.expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm:
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
|
@ -39,10 +56,13 @@ class RMSNorm:
|
|||
# TODO: convert to float?
|
||||
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
|
||||
|
||||
|
||||
class Attention:
|
||||
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||
self.n_kv_heads = (
|
||||
n_kv_heads if n_kv_heads is not None else n_heads
|
||||
) # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||
self.head_dim = dim // n_heads
|
||||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
|
@ -52,7 +72,13 @@ class Attention:
|
|||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
start_pos: Union[Variable, int],
|
||||
freqs_cis: Tensor,
|
||||
mask: Optional[Tensor],
|
||||
) -> Tensor:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
|
@ -62,21 +88,40 @@ class Attention:
|
|||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||
self.cache_k, self.cache_v = Tensor.zeros(
|
||||
bsz, self.max_context, self.n_kv_heads, self.head_dim
|
||||
), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
||||
# update the cache
|
||||
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_k.assign(
|
||||
keys.pad(
|
||||
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
values.pad(
|
||||
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
|
||||
xq, keys, values = (
|
||||
xq.transpose(1, 2),
|
||||
keys.transpose(1, 2),
|
||||
values.transpose(1, 2),
|
||||
)
|
||||
attn = (
|
||||
xq.scaled_dot_product_attention(keys, values, mask)
|
||||
.transpose(1, 2)
|
||||
.reshape(bsz, seqlen, -1)
|
||||
)
|
||||
return self.wo(attn)
|
||||
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim, linear=nn.Linear):
|
||||
self.w1 = linear(dim, hidden_dim, bias=False)
|
||||
|
@ -84,36 +129,88 @@ class FeedForward:
|
|||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
return self.w2(
|
||||
self.w1(x).silu() * self.w3(x)
|
||||
) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
max_context: int,
|
||||
linear=nn.Linear,
|
||||
):
|
||||
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
||||
self.feed_forward = FeedForward(dim, hidden_dim, linear)
|
||||
self.attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = RMSNorm(dim, norm_eps)
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
start_pos: Union[Variable, int],
|
||||
freqs_cis: Tensor,
|
||||
mask: Optional[Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
||||
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
|
||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
norm_eps: float,
|
||||
vocab_size,
|
||||
linear=nn.Linear,
|
||||
n_kv_heads=None,
|
||||
rope_theta=10000,
|
||||
max_context=1024,
|
||||
jit=True,
|
||||
):
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear
|
||||
)
|
||||
for _ in range(n_layers)
|
||||
]
|
||||
self.norm = RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||
self.output = linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
dim // n_heads, self.max_context * 2, rope_theta
|
||||
)
|
||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
|
||||
def forward(
|
||||
self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float = 0.0
|
||||
):
|
||||
_bsz, seqlen = tokens.shape
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
freqs_cis = self.freqs_cis.shrink(
|
||||
(None, (start_pos, start_pos + seqlen), None, None, None)
|
||||
)
|
||||
mask = (
|
||||
Tensor.full(
|
||||
(1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32
|
||||
)
|
||||
.triu(start_pos + 1)
|
||||
.realize()
|
||||
if seqlen > 1
|
||||
else None
|
||||
)
|
||||
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h))
|
||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize()
|
||||
|
||||
|
@ -121,27 +218,54 @@ class Transformer:
|
|||
# TODO: better way to handle the first call v.s. the rest?
|
||||
if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1):
|
||||
assert start_pos > 0
|
||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
|
||||
return self.forward_jit(
|
||||
tokens,
|
||||
Variable("start_pos", 1, self.max_context).bind(start_pos),
|
||||
temperature,
|
||||
)
|
||||
return self.forward(tokens, start_pos, temperature)
|
||||
|
||||
|
||||
# *** helpers ***
|
||||
|
||||
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
|
||||
|
||||
def convert_from_huggingface(
|
||||
weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int
|
||||
):
|
||||
def permute(v: Tensor, n_heads: int):
|
||||
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
|
||||
return (
|
||||
v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1])
|
||||
.transpose(1, 2)
|
||||
.reshape(*v.shape[:2])
|
||||
)
|
||||
|
||||
keymap = {
|
||||
"model.embed_tokens.weight": "tok_embeddings.weight",
|
||||
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
|
||||
**{
|
||||
f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
|
||||
for x in ["q", "k", "v", "o"]
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
|
||||
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
"model.norm.weight": "norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
sd = {}
|
||||
for k, v in weights.items():
|
||||
if ".rotary_emb." in k: continue
|
||||
if ".rotary_emb." in k:
|
||||
continue
|
||||
v = v.to(Device.DEFAULT)
|
||||
if "model.layers" in k:
|
||||
if "q_proj" in k:
|
||||
|
|
|
@ -10,64 +10,95 @@ from tinygrad.nn.state import torch_load
|
|||
from extra.models.resnet import ResNet
|
||||
from extra.models.retinanet import nms as _box_nms
|
||||
|
||||
USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'
|
||||
USE_NP_GATHER = os.getenv("FULL_TINYGRAD", "0") == "0"
|
||||
|
||||
|
||||
def rint(tensor):
|
||||
x = (tensor * 2).cast(dtypes.int32).contiguous().cast(dtypes.float32) / 2
|
||||
return (x < 0).where(x.floor(), x.ceil())
|
||||
|
||||
|
||||
def nearest_interpolate(tensor, scale_factor):
|
||||
bs, c, py, px = tensor.shape
|
||||
return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
|
||||
return (
|
||||
tensor.reshape(bs, c, py, 1, px, 1)
|
||||
.expand(bs, c, py, scale_factor, px, scale_factor)
|
||||
.reshape(bs, c, py * scale_factor, px * scale_factor)
|
||||
)
|
||||
|
||||
|
||||
def meshgrid(x, y):
|
||||
grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
|
||||
grid_x = Tensor.cat(
|
||||
*[x[idx : idx + 1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]
|
||||
)
|
||||
grid_y = Tensor.cat(*[y.unsqueeze(0)] * x.shape[0])
|
||||
return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
|
||||
|
||||
|
||||
def topk(input_, k, dim=-1, largest=True, sorted=False):
|
||||
k = min(k, input_.shape[dim] - 1)
|
||||
input_ = input_.numpy()
|
||||
if largest: input_ *= -1
|
||||
if largest:
|
||||
input_ *= -1
|
||||
ind = np.argpartition(input_, k, axis=dim)
|
||||
if largest: input_ *= -1
|
||||
if largest:
|
||||
input_ *= -1
|
||||
ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
|
||||
input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
|
||||
if not sorted: return Tensor(input_), ind
|
||||
if largest: input_ *= -1
|
||||
if not sorted:
|
||||
return Tensor(input_), ind
|
||||
if largest:
|
||||
input_ *= -1
|
||||
ind_part = np.argsort(input_, axis=dim)
|
||||
ind = np.take_along_axis(ind, ind_part, axis=dim)
|
||||
if largest: input_ *= -1
|
||||
if largest:
|
||||
input_ *= -1
|
||||
val = np.take_along_axis(input_, ind_part, axis=dim)
|
||||
return Tensor(val), ind
|
||||
|
||||
|
||||
# This is very slow for large arrays, or indices
|
||||
def _gather(array, indices):
|
||||
indices = indices.float().to(array.device)
|
||||
reshape_arg = [1] * array.ndim + [array.shape[-1]]
|
||||
return Tensor.where(
|
||||
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
|
||||
array, 0,
|
||||
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1])
|
||||
== Tensor.arange(array.shape[-1])
|
||||
.reshape(*reshape_arg)
|
||||
.expand(*indices.shape, array.shape[-1]),
|
||||
array,
|
||||
0,
|
||||
).sum(indices.ndim)
|
||||
|
||||
|
||||
# TODO: replace npgather with a faster gather using tinygrad only
|
||||
# NOTE: this blocks the gradient
|
||||
def npgather(array, indices):
|
||||
if isinstance(array, Tensor): array = array.numpy()
|
||||
if isinstance(indices, Tensor): indices = indices.numpy()
|
||||
if isinstance(indices, list): indices = np.asarray(indices)
|
||||
if isinstance(array, Tensor):
|
||||
array = array.numpy()
|
||||
if isinstance(indices, Tensor):
|
||||
indices = indices.numpy()
|
||||
if isinstance(indices, list):
|
||||
indices = np.asarray(indices)
|
||||
return Tensor(array[indices.astype(int)])
|
||||
|
||||
|
||||
def get_strides(shape):
|
||||
prod = [1]
|
||||
for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
|
||||
for idx in range(len(shape) - 1, -1, -1):
|
||||
prod.append(prod[-1] * shape[idx])
|
||||
# something about ints is broken with gpu, cuda
|
||||
return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu()
|
||||
|
||||
|
||||
# with keys as integer array for all axes
|
||||
def tensor_getitem(tensor, *keys):
|
||||
# something about ints is broken with gpu, cuda
|
||||
flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32)
|
||||
flat_keys = (
|
||||
Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1)
|
||||
.cpu()
|
||||
.cast(dtypes.int32)
|
||||
)
|
||||
strides = get_strides(tensor.shape)
|
||||
idxs = (flat_keys * strides).sum(1)
|
||||
gatherer = npgather if USE_NP_GATHER else _gather
|
||||
|
@ -97,7 +128,8 @@ def tensor_gather(tensor, indices):
|
|||
|
||||
|
||||
class LastLevelMaxPool:
|
||||
def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]
|
||||
def __call__(self, x):
|
||||
return [Tensor.max_pool2d(x, 1, 2)]
|
||||
|
||||
|
||||
# transpose
|
||||
|
@ -117,9 +149,7 @@ class BoxList:
|
|||
if not isinstance(bbox, Tensor):
|
||||
bbox = Tensor(bbox)
|
||||
if bbox.ndim != 2:
|
||||
raise ValueError(
|
||||
"bbox should have 2 dimensions, got {}".format(bbox.ndim)
|
||||
)
|
||||
raise ValueError("bbox should have 2 dimensions, got {}".format(bbox.ndim))
|
||||
if bbox.shape[-1] != 4:
|
||||
raise ValueError(
|
||||
"last dimenion of bbox should have a "
|
||||
|
@ -145,7 +175,9 @@ class BoxList:
|
|||
box = self.bbox
|
||||
if self.mode == "xyxy":
|
||||
TO_REMOVE = 1
|
||||
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
|
||||
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (
|
||||
box[:, 3] - box[:, 1] + TO_REMOVE
|
||||
)
|
||||
elif self.mode == "xywh":
|
||||
area = box[:, 2] * box[:, 3]
|
||||
return area
|
||||
|
@ -241,7 +273,8 @@ class BoxList:
|
|||
transposed_ymax = image_height - ymin
|
||||
|
||||
transposed_boxes = Tensor.cat(
|
||||
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
|
||||
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax),
|
||||
dim=-1,
|
||||
)
|
||||
bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
|
||||
for k, v in self.extra_fields.items():
|
||||
|
@ -289,7 +322,11 @@ def cat_boxlist(bboxes):
|
|||
else:
|
||||
cat_boxes = BoxList(bboxes[0].bbox, size, mode)
|
||||
for field in fields:
|
||||
cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
|
||||
cat_field_list = [
|
||||
bbox.get_field(field)
|
||||
for bbox in bboxes
|
||||
if bbox.get_field(field).shape[0] > 0
|
||||
]
|
||||
|
||||
if len(cat_box_list) > 0:
|
||||
data = Tensor.cat(*cat_field_list, dim=0)
|
||||
|
@ -305,8 +342,12 @@ class FPN:
|
|||
def __init__(self, in_channels_list, out_channels):
|
||||
self.inner_blocks, self.layer_blocks = [], []
|
||||
for in_channels in in_channels_list:
|
||||
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
||||
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
|
||||
self.inner_blocks.append(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
)
|
||||
self.layer_blocks.append(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
self.top_block = LastLevelMaxPool()
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
|
@ -357,9 +398,7 @@ class AnchorGenerator:
|
|||
):
|
||||
if len(anchor_strides) == 1:
|
||||
anchor_stride = anchor_strides[0]
|
||||
cell_anchors = [
|
||||
generate_anchors(anchor_stride, sizes, aspect_ratios)
|
||||
]
|
||||
cell_anchors = [generate_anchors(anchor_stride, sizes, aspect_ratios)]
|
||||
else:
|
||||
if len(anchor_strides) != len(sizes):
|
||||
raise RuntimeError("FPN should have #anchor_strides == #sizes")
|
||||
|
@ -368,7 +407,7 @@ class AnchorGenerator:
|
|||
generate_anchors(
|
||||
anchor_stride,
|
||||
size if isinstance(size, (tuple, list)) else (size,),
|
||||
aspect_ratios
|
||||
aspect_ratios,
|
||||
)
|
||||
for anchor_stride, size in zip(anchor_strides, sizes)
|
||||
]
|
||||
|
@ -387,10 +426,18 @@ class AnchorGenerator:
|
|||
grid_height, grid_width = size
|
||||
device = base_anchors.device
|
||||
shifts_x = Tensor.arange(
|
||||
start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
|
||||
start=0,
|
||||
stop=grid_width * stride,
|
||||
step=stride,
|
||||
dtype=dtypes.float32,
|
||||
device=device,
|
||||
)
|
||||
shifts_y = Tensor.arange(
|
||||
start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
|
||||
start=0,
|
||||
stop=grid_height * stride,
|
||||
step=stride,
|
||||
dtype=dtypes.float32,
|
||||
device=device,
|
||||
)
|
||||
shift_y, shift_x = meshgrid(shifts_y, shifts_x)
|
||||
shift_x = shift_x.reshape(-1)
|
||||
|
@ -398,7 +445,9 @@ class AnchorGenerator:
|
|||
shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
||||
|
||||
anchors.append(
|
||||
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
|
||||
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(
|
||||
-1, 4
|
||||
)
|
||||
)
|
||||
|
||||
return anchors
|
||||
|
@ -415,14 +464,16 @@ class AnchorGenerator:
|
|||
)
|
||||
else:
|
||||
device = anchors.device
|
||||
inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
|
||||
inds_inside = Tensor.ones(
|
||||
anchors.shape[0], dtype=dtypes.uint8, device=device
|
||||
)
|
||||
boxlist.add_field("visibility", inds_inside)
|
||||
|
||||
def __call__(self, image_list, feature_maps):
|
||||
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
|
||||
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
|
||||
anchors = []
|
||||
for (image_height, image_width) in image_list.image_sizes:
|
||||
for image_height, image_width in image_list.image_sizes:
|
||||
anchors_in_image = []
|
||||
for anchors_per_feature_map in anchors_over_all_feature_maps:
|
||||
boxlist = BoxList(
|
||||
|
@ -437,14 +488,19 @@ class AnchorGenerator:
|
|||
def generate_anchors(
|
||||
stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
|
||||
):
|
||||
return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
|
||||
return _generate_anchors(
|
||||
stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))
|
||||
)
|
||||
|
||||
|
||||
def _generate_anchors(base_size, scales, aspect_ratios):
|
||||
anchor = Tensor([1, 1, base_size, base_size]) - 1
|
||||
anchors = _ratio_enum(anchor, aspect_ratios)
|
||||
anchors = Tensor.cat(
|
||||
*[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
|
||||
*[
|
||||
_scale_enum(anchors[i, :], scales).reshape(-1, 4)
|
||||
for i in range(anchors.shape[0])
|
||||
]
|
||||
)
|
||||
return anchors
|
||||
|
||||
|
@ -460,12 +516,15 @@ def _whctrs(anchor):
|
|||
def _mkanchors(ws, hs, x_ctr, y_ctr):
|
||||
ws = ws[:, None]
|
||||
hs = hs[:, None]
|
||||
anchors = Tensor.cat(*(
|
||||
anchors = Tensor.cat(
|
||||
*(
|
||||
x_ctr - 0.5 * (ws - 1),
|
||||
y_ctr - 0.5 * (hs - 1),
|
||||
x_ctr + 0.5 * (ws - 1),
|
||||
y_ctr + 0.5 * (hs - 1),
|
||||
), dim=1)
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
return anchors
|
||||
|
||||
|
||||
|
@ -504,7 +563,7 @@ class RPNHead:
|
|||
|
||||
|
||||
class BoxCoder(object):
|
||||
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
|
||||
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
|
||||
self.weights = weights
|
||||
self.bbox_xform_clip = bbox_xform_clip
|
||||
|
||||
|
@ -557,7 +616,11 @@ class BoxCoder(object):
|
|||
y = pred_ctr_y - 0.5 * pred_h
|
||||
w = pred_ctr_x + 0.5 * pred_w - 1
|
||||
h = pred_ctr_y + 0.5 * pred_h - 1
|
||||
pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
|
||||
pred_boxes = (
|
||||
Tensor.stack([x, y, w, h])
|
||||
.permute(1, 2, 0)
|
||||
.reshape(rel_codes.shape[0], rel_codes.shape[1])
|
||||
)
|
||||
return pred_boxes
|
||||
|
||||
|
||||
|
@ -578,9 +641,7 @@ def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
|
|||
def remove_small_boxes(boxlist, min_size):
|
||||
xywh_boxes = boxlist.convert("xywh").bbox
|
||||
_, _, ws, hs = xywh_boxes.chunk(4, dim=1)
|
||||
keep = ((
|
||||
(ws >= min_size) * (hs >= min_size)
|
||||
) > 0).reshape(-1)
|
||||
keep = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1)
|
||||
if keep.sum().numpy() == len(boxlist):
|
||||
return boxlist
|
||||
else:
|
||||
|
@ -630,8 +691,12 @@ class RPNPostProcessor:
|
|||
box_regression_list = []
|
||||
concat_anchors_list = []
|
||||
for batch_idx in range(N):
|
||||
box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
|
||||
concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
|
||||
box_regression_list.append(
|
||||
tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])
|
||||
)
|
||||
concat_anchors_list.append(
|
||||
tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])
|
||||
)
|
||||
|
||||
box_regression = Tensor.stack(box_regression_list)
|
||||
concat_anchors = Tensor.stack(concat_anchors_list)
|
||||
|
@ -677,9 +742,7 @@ class RPNPostProcessor:
|
|||
for i in range(num_images):
|
||||
objectness = boxlists[i].get_field("objectness")
|
||||
post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
|
||||
_, inds_sorted = topk(objectness,
|
||||
post_nms_top_n, dim=0, sorted=False
|
||||
)
|
||||
_, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
|
||||
boxlists[i] = boxlists[i][inds_sorted]
|
||||
return boxlists
|
||||
|
||||
|
@ -689,9 +752,7 @@ class RPN:
|
|||
self.anchor_generator = AnchorGenerator()
|
||||
|
||||
in_channels = 256
|
||||
head = RPNHead(
|
||||
in_channels, self.anchor_generator.num_anchors_per_location()[0]
|
||||
)
|
||||
head = RPNHead(in_channels, self.anchor_generator.num_anchors_per_location()[0])
|
||||
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
|
||||
box_selector_test = RPNPostProcessor(
|
||||
pre_nms_top_n=1000,
|
||||
|
@ -699,7 +760,7 @@ class RPN:
|
|||
nms_thresh=0.7,
|
||||
min_size=0,
|
||||
box_coder=rpn_box_coder,
|
||||
fpn_post_nms_top_n=1000
|
||||
fpn_post_nms_top_n=1000,
|
||||
)
|
||||
self.head = head
|
||||
self.box_selector_test = box_selector_test
|
||||
|
@ -725,7 +786,7 @@ def make_conv3x3(
|
|||
stride=stride,
|
||||
padding=dilation,
|
||||
dilation=dilation,
|
||||
bias=False if use_gn else True
|
||||
bias=False if use_gn else True,
|
||||
)
|
||||
return conv
|
||||
|
||||
|
@ -746,10 +807,18 @@ class MaskRCNNFPNFeatureExtractor:
|
|||
use_gn = False
|
||||
layers = (256, 256, 256, 256)
|
||||
dilation = 1
|
||||
self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
|
||||
self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
|
||||
self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
|
||||
self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
|
||||
self.mask_fcn1 = make_conv3x3(
|
||||
input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn
|
||||
)
|
||||
self.mask_fcn2 = make_conv3x3(
|
||||
layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn
|
||||
)
|
||||
self.mask_fcn3 = make_conv3x3(
|
||||
layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn
|
||||
)
|
||||
self.mask_fcn4 = make_conv3x3(
|
||||
layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn
|
||||
)
|
||||
self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
|
||||
|
||||
def __call__(self, x, proposals):
|
||||
|
@ -833,7 +902,9 @@ def _bilinear_interpolate(
|
|||
y = Tensor.where(ymask[:, None, :], y, 0)
|
||||
x = Tensor.where(xmask[:, None, :], x, 0)
|
||||
key1 = roi_batch_ind[:, None, None, None, None, None]
|
||||
key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
|
||||
key2 = Tensor.arange(channels, device=input.device)[
|
||||
None, :, None, None, None, None
|
||||
]
|
||||
key3 = y[:, None, :, None, :, None]
|
||||
key4 = x[:, None, None, :, None, :]
|
||||
return tensor_getitem(input, key1, key2, key3, key4) # [K, C, PH, PW, IY, IX]
|
||||
|
@ -855,8 +926,11 @@ def _bilinear_interpolate(
|
|||
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
|
||||
return val
|
||||
|
||||
|
||||
# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
|
||||
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
|
||||
def _roi_align(
|
||||
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned
|
||||
):
|
||||
orig_dtype = input.dtype
|
||||
_, _, height, width = input.shape
|
||||
ph = Tensor.arange(pooled_height, device=input.device)
|
||||
|
@ -879,8 +953,12 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
|||
bin_size_w = roi_width / pooled_width
|
||||
|
||||
exact_sampling = sampling_ratio > 0
|
||||
roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
|
||||
roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
|
||||
roi_bin_grid_h = (
|
||||
sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
|
||||
)
|
||||
roi_bin_grid_w = (
|
||||
sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
|
||||
)
|
||||
|
||||
if exact_sampling:
|
||||
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
|
||||
|
@ -923,6 +1001,7 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
|||
output = output.cast(orig_dtype)
|
||||
return output
|
||||
|
||||
|
||||
class ROIAlign:
|
||||
def __init__(self, output_size, spatial_scale, sampling_ratio):
|
||||
self.output_size = output_size
|
||||
|
@ -931,7 +1010,13 @@ class ROIAlign:
|
|||
|
||||
def __call__(self, input, rois):
|
||||
output = _roi_align(
|
||||
input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
|
||||
input,
|
||||
rois,
|
||||
self.spatial_scale,
|
||||
self.output_size[0],
|
||||
self.output_size[1],
|
||||
self.sampling_ratio,
|
||||
aligned=False,
|
||||
)
|
||||
return output
|
||||
|
||||
|
@ -1002,7 +1087,16 @@ class Pooler:
|
|||
all_idxs.extend(idx_in_level)
|
||||
results.append(pooler_output)
|
||||
|
||||
return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
|
||||
return tensor_gather(
|
||||
Tensor.cat(*results),
|
||||
[
|
||||
x[0]
|
||||
for x in sorted(
|
||||
{i: idx for i, idx in enumerate(all_idxs)}.items(),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class FPNPredictor:
|
||||
|
@ -1027,13 +1121,13 @@ class PostProcessor:
|
|||
nms=0.5,
|
||||
detections_per_img=100,
|
||||
box_coder=None,
|
||||
cls_agnostic_bbox_reg=False
|
||||
cls_agnostic_bbox_reg=False,
|
||||
):
|
||||
self.score_thresh = score_thresh
|
||||
self.nms = nms
|
||||
self.detections_per_img = detections_per_img
|
||||
if box_coder is None:
|
||||
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
|
||||
box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
|
||||
self.box_coder = box_coder
|
||||
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
|
||||
|
||||
|
@ -1090,9 +1184,7 @@ class PostProcessor:
|
|||
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
|
||||
boxlist_for_class.add_field("scores", scores_j)
|
||||
if len(boxlist_for_class):
|
||||
boxlist_for_class = boxlist_nms(
|
||||
boxlist_for_class, self.nms
|
||||
)
|
||||
boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
|
||||
num_labels = len(boxlist_for_class)
|
||||
boxlist_for_class.add_field(
|
||||
"labels", Tensor.full((num_labels,), j, device=device)
|
||||
|
@ -1119,8 +1211,8 @@ class RoIBoxHead:
|
|||
score_thresh=0.05,
|
||||
nms=0.5,
|
||||
detections_per_img=100,
|
||||
box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
|
||||
cls_agnostic_bbox_reg=False
|
||||
box_coder=BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)),
|
||||
cls_agnostic_bbox_reg=False,
|
||||
)
|
||||
|
||||
def __call__(self, features, proposals, targets=None):
|
||||
|
@ -1210,7 +1302,6 @@ def to_image_list(tensors, size_divisible=32):
|
|||
elif isinstance(tensors, (tuple, list)):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
|
||||
if size_divisible > 0:
|
||||
|
||||
stride = size_divisible
|
||||
max_size = list(max_size)
|
||||
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
|
||||
|
@ -1237,10 +1328,13 @@ class MaskRCNN:
|
|||
self.roi_heads = RoIHeads(self.backbone.out_channels)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path('./') / "weights/maskrcnn.pt"
|
||||
fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
|
||||
fn = Path("./") / "weights/maskrcnn.pt"
|
||||
fetch(
|
||||
"https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth",
|
||||
fn,
|
||||
)
|
||||
|
||||
state_dict = torch_load(fn)['model']
|
||||
state_dict = torch_load(fn)["model"]
|
||||
loaded_keys = []
|
||||
for k, v in state_dict.items():
|
||||
if "module." in k:
|
||||
|
@ -1265,7 +1359,7 @@ class MaskRCNN:
|
|||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||
model = MaskRCNN(backbone=resnet)
|
||||
model.load_from_pretrained()
|
||||
|
|
|
@ -3,20 +3,33 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.helpers import fetch, get_child
|
||||
|
||||
|
||||
class BasicBlock:
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
|
||||
assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
assert (
|
||||
groups == 1 and base_width == 64
|
||||
), "BasicBlock only supports groups=1 and base_width=64"
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, padding=1, stride=1, bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
|
@ -31,20 +44,44 @@ class Bottleneck:
|
|||
# NOTE: stride_in_1x1=False, this is the v1.5 variant
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
|
||||
def __init__(
|
||||
self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64
|
||||
):
|
||||
width = int(planes * (base_width / 64.0)) * groups
|
||||
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
|
||||
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=stride if stride_in_1x1 else 1,
|
||||
bias=False,
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
|
||||
self.conv2 = nn.Conv2d(
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1 if stride_in_1x1 else stride,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
|
||||
self.conv3 = nn.Conv2d(
|
||||
width, self.expansion * planes, kernel_size=1, bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
|
@ -55,15 +92,18 @@ class Bottleneck:
|
|||
out = out.relu()
|
||||
return out
|
||||
|
||||
|
||||
class ResNet:
|
||||
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
|
||||
def __init__(
|
||||
self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False
|
||||
):
|
||||
self.num = num
|
||||
self.block = {
|
||||
18: BasicBlock,
|
||||
34: BasicBlock,
|
||||
50: Bottleneck,
|
||||
101: Bottleneck,
|
||||
152: Bottleneck
|
||||
152: Bottleneck,
|
||||
}[num]
|
||||
|
||||
self.num_blocks = {
|
||||
|
@ -71,7 +111,7 @@ class ResNet:
|
|||
34: [3, 4, 6, 3],
|
||||
50: [3, 4, 6, 3],
|
||||
101: [3, 4, 23, 3],
|
||||
152: [3,8,36,3]
|
||||
152: [3, 8, 36, 3],
|
||||
}[num]
|
||||
|
||||
self.in_planes = 64
|
||||
|
@ -80,36 +120,64 @@ class ResNet:
|
|||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
|
||||
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
|
||||
self.layer1 = self._make_layer(
|
||||
self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer4 = self._make_layer(
|
||||
self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.fc = (
|
||||
nn.Linear(512 * self.block.expansion, num_classes)
|
||||
if num_classes is not None
|
||||
else None
|
||||
)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
if block == Bottleneck:
|
||||
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
|
||||
layers.append(
|
||||
block(
|
||||
self.in_planes,
|
||||
planes,
|
||||
stride,
|
||||
stride_in_1x1,
|
||||
self.groups,
|
||||
self.base_width,
|
||||
)
|
||||
)
|
||||
else:
|
||||
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
|
||||
layers.append(
|
||||
block(self.in_planes, planes, stride, self.groups, self.base_width)
|
||||
)
|
||||
self.in_planes = planes * block.expansion
|
||||
return layers
|
||||
|
||||
def forward(self, x):
|
||||
is_feature_only = self.fc is None
|
||||
if is_feature_only: features = []
|
||||
if is_feature_only:
|
||||
features = []
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
|
||||
out = out.sequential(self.layer1)
|
||||
if is_feature_only: features.append(out)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer2)
|
||||
if is_feature_only: features.append(out)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer3)
|
||||
if is_feature_only: features.append(out)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer4)
|
||||
if is_feature_only: features.append(out)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
if not is_feature_only:
|
||||
out = out.mean([2, 3])
|
||||
out = self.fc(out).log_softmax()
|
||||
|
@ -123,12 +191,16 @@ class ResNet:
|
|||
# TODO replace with fake torch load
|
||||
|
||||
model_urls = {
|
||||
(18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
(34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
(50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
(101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
(152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
(18, 1, 64): "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
||||
(34, 1, 64): "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
||||
(50, 1, 64): "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
||||
(
|
||||
50,
|
||||
32,
|
||||
4,
|
||||
): "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||
(101, 1, 64): "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
||||
(152, 1, 64): "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
||||
}
|
||||
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
|
@ -136,17 +208,24 @@ class ResNet:
|
|||
obj: Tensor = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
|
||||
if 'fc.' in k and obj.shape != dat.shape:
|
||||
if "fc." in k and obj.shape != dat.shape:
|
||||
print("skipping fully connected layer")
|
||||
continue # Skip FC if transfer learning
|
||||
|
||||
# TODO: remove or when #777 is merged
|
||||
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
|
||||
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (
|
||||
k,
|
||||
obj.shape,
|
||||
dat.shape,
|
||||
)
|
||||
obj.assign(dat)
|
||||
|
||||
|
||||
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
||||
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
||||
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
|
||||
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
|
||||
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
|
||||
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
|
||||
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(
|
||||
50, num_classes=num_classes, groups=32, width_per_group=4
|
||||
)
|
||||
|
|
|
@ -4,6 +4,7 @@ import tinygrad.nn as nn
|
|||
from extra.models.resnet import ResNet
|
||||
import numpy as np
|
||||
|
||||
|
||||
def nms(boxes, scores, thresh=0.5):
|
||||
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
@ -15,11 +16,14 @@ def nms(boxes, scores, thresh=0.5):
|
|||
inter_y1 = np.maximum(y1[cur], y1[to_process])
|
||||
inter_x2 = np.minimum(x2[cur], x2[to_process])
|
||||
inter_y2 = np.minimum(y2[cur], y2[to_process])
|
||||
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
|
||||
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(
|
||||
0, inter_y2 - inter_y1 + 1
|
||||
)
|
||||
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
|
||||
to_process = to_process[np.where(iou <= thresh)[0]]
|
||||
return keep
|
||||
|
||||
|
||||
def decode_bbox(offsets, anchors):
|
||||
dx, dy, dw, dh = np.rollaxis(offsets, 1)
|
||||
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
|
||||
|
@ -30,6 +34,7 @@ def decode_bbox(offsets, anchors):
|
|||
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
|
||||
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
|
||||
|
||||
|
||||
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
||||
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
|
||||
anchors = []
|
||||
|
@ -41,39 +46,87 @@ def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
|||
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
|
||||
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
|
||||
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
|
||||
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
|
||||
shifts_x, shifts_y = np.meshgrid(
|
||||
np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h
|
||||
)
|
||||
shifts_x = shifts_x.reshape(-1)
|
||||
shifts_y = shifts_y.reshape(-1)
|
||||
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
|
||||
shifts = np.stack(
|
||||
[shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32
|
||||
)
|
||||
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
|
||||
return anchors
|
||||
|
||||
|
||||
class RetinaNet:
|
||||
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
|
||||
def __init__(
|
||||
self,
|
||||
backbone: ResNet,
|
||||
num_classes=264,
|
||||
num_anchors=9,
|
||||
scales=None,
|
||||
aspect_ratios=None,
|
||||
):
|
||||
assert isinstance(backbone, ResNet)
|
||||
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
|
||||
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
||||
scales = (
|
||||
tuple(
|
||||
(i, int(i * 2 ** (1 / 3)), int(i * 2 ** (2 / 3)))
|
||||
for i in 2 ** np.arange(5, 10)
|
||||
)
|
||||
if scales is None
|
||||
else scales
|
||||
)
|
||||
aspect_ratios = (
|
||||
((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
||||
)
|
||||
self.num_anchors, self.num_classes = num_anchors, num_classes
|
||||
assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
|
||||
assert len(scales) == len(aspect_ratios) and all(
|
||||
self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)
|
||||
)
|
||||
|
||||
self.backbone = ResNetFPN(backbone)
|
||||
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
|
||||
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
|
||||
self.head = RetinaHead(
|
||||
self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes
|
||||
)
|
||||
self.anchor_gen = lambda input_size: generate_anchors(
|
||||
input_size,
|
||||
self.backbone.compute_grid_sizes(input_size),
|
||||
scales,
|
||||
aspect_ratios,
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.forward(x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.head(self.backbone(x))
|
||||
|
||||
def load_from_pretrained(self):
|
||||
model_urls = {
|
||||
(50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
||||
(50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
||||
(
|
||||
50,
|
||||
1,
|
||||
64,
|
||||
): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
||||
(
|
||||
50,
|
||||
32,
|
||||
4,
|
||||
): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
||||
}
|
||||
self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
|
||||
self.url = model_urls[
|
||||
(
|
||||
self.backbone.body.num,
|
||||
self.backbone.body.groups,
|
||||
self.backbone.body.base_width,
|
||||
)
|
||||
]
|
||||
from torch.hub import load_state_dict_from_url
|
||||
state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
|
||||
state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
|
||||
|
||||
state_dict = load_state_dict_from_url(
|
||||
self.url, progress=True, map_location="cpu"
|
||||
)
|
||||
state_dict = state_dict["model"] if "model" in state_dict.keys() else state_dict
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
|
@ -81,10 +134,21 @@ class RetinaNet:
|
|||
obj.assign(dat)
|
||||
|
||||
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
|
||||
def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
|
||||
def postprocess_detections(
|
||||
self,
|
||||
predictions,
|
||||
input_size=(800, 800),
|
||||
image_sizes=None,
|
||||
orig_image_sizes=None,
|
||||
score_thresh=0.05,
|
||||
topk_candidates=1000,
|
||||
nms_thresh=0.5,
|
||||
):
|
||||
anchors = self.anchor_gen(input_size)
|
||||
grid_sizes = self.backbone.compute_grid_sizes(input_size)
|
||||
split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
|
||||
split_idx = np.cumsum(
|
||||
[int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]
|
||||
)
|
||||
detections = []
|
||||
for i, predictions_per_image in enumerate(predictions):
|
||||
h, w = input_size if image_sizes is None else image_sizes[i]
|
||||
|
@ -94,7 +158,9 @@ class RetinaNet:
|
|||
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
|
||||
|
||||
image_boxes, image_scores, image_labels = [], [], []
|
||||
for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
|
||||
for offsets_per_level, scores_per_level, anchors_per_level in zip(
|
||||
offsets_per_image, scores_per_image, anchors
|
||||
):
|
||||
# remove low scoring boxes
|
||||
scores_per_level = scores_per_level.flatten()
|
||||
keep_idxs = scores_per_level > score_thresh
|
||||
|
@ -104,16 +170,23 @@ class RetinaNet:
|
|||
topk_idxs = np.where(keep_idxs)[0]
|
||||
num_topk = min(len(topk_idxs), topk_candidates)
|
||||
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
|
||||
topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]
|
||||
topk_idxs, scores_per_level = (
|
||||
topk_idxs[sort_idxs],
|
||||
scores_per_level[sort_idxs],
|
||||
)
|
||||
|
||||
# bbox coords from offsets
|
||||
anchor_idxs = topk_idxs // self.num_classes
|
||||
labels_per_level = topk_idxs % self.num_classes
|
||||
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
|
||||
boxes_per_level = decode_bbox(
|
||||
offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
|
||||
)
|
||||
# clip to image size
|
||||
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
|
||||
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
|
||||
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)
|
||||
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(
|
||||
-1, 4
|
||||
)
|
||||
|
||||
image_boxes.append(boxes_per_level)
|
||||
image_scores.append(scores_per_level)
|
||||
|
@ -127,7 +200,9 @@ class RetinaNet:
|
|||
keep_mask = np.zeros_like(image_scores, dtype=bool)
|
||||
for class_id in np.unique(image_labels):
|
||||
curr_indices = np.where(image_labels == class_id)[0]
|
||||
curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
|
||||
curr_keep_indices = nms(
|
||||
image_boxes[curr_indices], image_scores[curr_indices], nms_thresh
|
||||
)
|
||||
keep_mask[curr_indices[curr_keep_indices]] = True
|
||||
keep = np.where(keep_mask)[0]
|
||||
keep = keep[image_scores[keep].argsort()[::-1]]
|
||||
|
@ -139,42 +214,91 @@ class RetinaNet:
|
|||
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
|
||||
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
|
||||
# xywh format
|
||||
image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)
|
||||
image_boxes = np.concatenate(
|
||||
[image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1
|
||||
)
|
||||
|
||||
detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]})
|
||||
detections.append(
|
||||
{
|
||||
"boxes": image_boxes,
|
||||
"scores": image_scores[keep],
|
||||
"labels": image_labels[keep],
|
||||
}
|
||||
)
|
||||
return detections
|
||||
|
||||
|
||||
class ClassificationHead:
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
||||
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
|
||||
self.conv = flatten(
|
||||
[
|
||||
(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||
lambda x: x.relu(),
|
||||
)
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
self.cls_logits = nn.Conv2d(
|
||||
in_channels, num_anchors * num_classes, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
|
||||
out = [
|
||||
self.cls_logits(feat.sequential(self.conv))
|
||||
.permute(0, 2, 3, 1)
|
||||
.reshape(feat.shape[0], -1, self.num_classes)
|
||||
for feat in x
|
||||
]
|
||||
return out[0].cat(*out[1:], dim=1).sigmoid()
|
||||
|
||||
|
||||
class RegressionHead:
|
||||
def __init__(self, in_channels, num_anchors):
|
||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
||||
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
|
||||
self.conv = flatten(
|
||||
[
|
||||
(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||
lambda x: x.relu(),
|
||||
)
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
self.bbox_reg = nn.Conv2d(
|
||||
in_channels, num_anchors * 4, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
|
||||
out = [
|
||||
self.bbox_reg(feat.sequential(self.conv))
|
||||
.permute(0, 2, 3, 1)
|
||||
.reshape(feat.shape[0], -1, 4)
|
||||
for feat in x
|
||||
]
|
||||
return out[0].cat(*out[1:], dim=1)
|
||||
|
||||
|
||||
class RetinaHead:
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
|
||||
self.classification_head = ClassificationHead(
|
||||
in_channels, num_anchors, num_classes
|
||||
)
|
||||
self.regression_head = RegressionHead(in_channels, num_anchors)
|
||||
|
||||
def __call__(self, x):
|
||||
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
|
||||
out = pred_bbox.cat(pred_class, dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetFPN:
|
||||
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
|
||||
self.out_channels = out_channels
|
||||
self.body = resnet
|
||||
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
|
||||
in_channels_list = [
|
||||
(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers
|
||||
]
|
||||
self.fpn = FPN(in_channels_list, out_channels)
|
||||
|
||||
# this is needed to decouple inference from postprocessing (anchors generation)
|
||||
|
@ -190,10 +314,15 @@ class ResNetFPN:
|
|||
p5 = p4.sequential(self.body.layer4)
|
||||
return self.fpn([p3, p4, p5])
|
||||
|
||||
|
||||
class ExtraFPNBlock:
|
||||
def __init__(self, in_channels, out_channels):
|
||||
self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
||||
self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
||||
self.p6 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.p7 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.use_P5 = in_channels == out_channels
|
||||
|
||||
def __call__(self, p, c):
|
||||
|
@ -204,13 +333,20 @@ class ExtraFPNBlock:
|
|||
p.extend([p6, p7])
|
||||
return p
|
||||
|
||||
|
||||
class FPN:
|
||||
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
|
||||
self.inner_blocks, self.layer_blocks = [], []
|
||||
for in_channels in in_channels_list:
|
||||
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
||||
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
|
||||
self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
||||
self.inner_blocks.append(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
)
|
||||
self.layer_blocks.append(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
self.extra_blocks = (
|
||||
ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
last_inner = self.inner_blocks[-1](x[-1])
|
||||
|
@ -219,9 +355,17 @@ class FPN:
|
|||
inner_lateral = self.inner_blocks[idx](x[idx])
|
||||
|
||||
# upsample to inner_lateral's shape
|
||||
(ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
|
||||
(ih, iw), (oh, ow), prefix = (
|
||||
last_inner.shape[-2:],
|
||||
inner_lateral.shape[-2:],
|
||||
last_inner.shape[:-2],
|
||||
)
|
||||
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
|
||||
inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]
|
||||
inner_top_down = (
|
||||
last_inner.reshape(*prefix, ih, 1, iw, 1)
|
||||
.expand(*prefix, ih, eh, iw, ew)
|
||||
.reshape(*prefix, ih * eh, iw * ew)[:, :, :oh, :ow]
|
||||
)
|
||||
|
||||
last_inner = inner_lateral + inner_top_down
|
||||
results.insert(0, self.layer_blocks[idx](last_inner))
|
||||
|
@ -229,8 +373,10 @@ class FPN:
|
|||
results = self.extra_blocks(results, x)
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
|
||||
backbone = ResNeXt50_32X4D()
|
||||
retina = RetinaNet(backbone)
|
||||
retina.load_from_pretrained()
|
||||
|
|
|
@ -7,10 +7,31 @@ from pathlib import Path
|
|||
|
||||
|
||||
class RNNT:
|
||||
def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
|
||||
self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
|
||||
def __init__(
|
||||
self,
|
||||
input_features=240,
|
||||
vocab_size=29,
|
||||
enc_hidden_size=1024,
|
||||
pred_hidden_size=320,
|
||||
joint_hidden_size=512,
|
||||
pre_enc_layers=2,
|
||||
post_enc_layers=3,
|
||||
pred_layers=2,
|
||||
stack_time_factor=2,
|
||||
dropout=0.32,
|
||||
):
|
||||
self.encoder = Encoder(
|
||||
input_features,
|
||||
enc_hidden_size,
|
||||
pre_enc_layers,
|
||||
post_enc_layers,
|
||||
stack_time_factor,
|
||||
dropout,
|
||||
)
|
||||
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
|
||||
self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)
|
||||
self.joint = Joint(
|
||||
vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||
)
|
||||
|
||||
@TinyJit
|
||||
def __call__(self, x, y, hc=None):
|
||||
|
@ -30,7 +51,12 @@ class RNNT:
|
|||
return outputs
|
||||
|
||||
def _greedy_decode(self, logits, logit_len):
|
||||
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
|
||||
hc = Tensor.zeros(
|
||||
self.prediction.rnn.layers,
|
||||
2,
|
||||
self.prediction.hidden_size,
|
||||
requires_grad=False,
|
||||
)
|
||||
labels = []
|
||||
label = Tensor.zeros(1, 1, requires_grad=False)
|
||||
mask = Tensor.zeros(1, requires_grad=False)
|
||||
|
@ -41,7 +67,14 @@ class RNNT:
|
|||
while not_blank and added < 30:
|
||||
if len(labels) > 0:
|
||||
mask = (mask + 1).clip(0, 1)
|
||||
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
|
||||
label = (
|
||||
Tensor(
|
||||
[[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]],
|
||||
requires_grad=False,
|
||||
)
|
||||
+ 1
|
||||
- 1
|
||||
)
|
||||
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
||||
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
||||
not_blank = k != 28
|
||||
|
@ -61,31 +94,59 @@ class RNNT:
|
|||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
||||
fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)
|
||||
fetch(
|
||||
"https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1",
|
||||
fn,
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")["state_dict"]
|
||||
|
||||
# encoder
|
||||
for i in range(2):
|
||||
self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
for i in range(3):
|
||||
self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
|
||||
# prediction
|
||||
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
|
||||
for i in range(2):
|
||||
self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
|
||||
# joint
|
||||
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
|
||||
|
@ -104,7 +165,9 @@ class LSTMCell:
|
|||
self.bias_hh = Tensor.uniform(hidden_size * 4)
|
||||
|
||||
def __call__(self, x, hc):
|
||||
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)
|
||||
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[: x.shape[0]].linear(
|
||||
self.weights_hh.T, self.bias_hh
|
||||
)
|
||||
|
||||
i, f, g, o = gates.chunk(4, 1)
|
||||
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
||||
|
@ -121,7 +184,12 @@ class LSTM:
|
|||
self.hidden_size = hidden_size
|
||||
self.layers = layers
|
||||
|
||||
self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]
|
||||
self.cells = [
|
||||
LSTMCell(input_size, hidden_size, dropout)
|
||||
if i == 0
|
||||
else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0)
|
||||
for i in range(layers)
|
||||
]
|
||||
|
||||
def __call__(self, x, hc):
|
||||
@TinyJit
|
||||
|
@ -129,7 +197,9 @@ class LSTM:
|
|||
return self.do_step(x_, hc_)
|
||||
|
||||
if hc is None:
|
||||
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
|
||||
hc = Tensor.zeros(
|
||||
self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False
|
||||
)
|
||||
|
||||
output = None
|
||||
for t in range(x.shape[0]):
|
||||
|
@ -159,10 +229,20 @@ class StackTime:
|
|||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
pre_layers,
|
||||
post_layers,
|
||||
stack_time_factor,
|
||||
dropout,
|
||||
):
|
||||
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
|
||||
self.stack_time = StackTime(stack_time_factor)
|
||||
self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)
|
||||
self.post_rnn = LSTM(
|
||||
stack_time_factor * hidden_size, hidden_size, post_layers, dropout
|
||||
)
|
||||
|
||||
def __call__(self, x, x_lens):
|
||||
x, _ = self.pre_rnn(x, None)
|
||||
|
@ -185,7 +265,9 @@ class Prediction:
|
|||
|
||||
|
||||
class Joint:
|
||||
def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
|
||||
def __init__(
|
||||
self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||
):
|
||||
self.dropout = dropout
|
||||
|
||||
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
|
||||
|
|
|
@ -1,8 +1,17 @@
|
|||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
ff_dim,
|
||||
prenorm=False,
|
||||
act=lambda x: x.relu(),
|
||||
dropout=0.1,
|
||||
):
|
||||
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
@ -10,11 +19,23 @@ class TransformerBlock:
|
|||
self.prenorm, self.act = prenorm, act
|
||||
self.dropout = dropout
|
||||
|
||||
self.query = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.key = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.value = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.query = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
self.key = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
self.value = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
|
||||
self.out = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.out = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
|
||||
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
|
||||
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
|
@ -24,25 +45,41 @@ class TransformerBlock:
|
|||
|
||||
def attn(self, x):
|
||||
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
|
||||
query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]]
|
||||
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2)
|
||||
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
|
||||
query, key, value = [
|
||||
x.linear(*y)
|
||||
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size))
|
||||
.transpose(1, 2)
|
||||
for y in [self.query, self.key, self.value]
|
||||
]
|
||||
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(
|
||||
1, 2
|
||||
)
|
||||
return attention.reshape(
|
||||
shape=(x.shape[0], -1, self.num_heads * self.head_size)
|
||||
).linear(*self.out)
|
||||
|
||||
def __call__(self, x):
|
||||
if self.prenorm:
|
||||
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
|
||||
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
||||
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(
|
||||
*self.ff2
|
||||
).dropout(self.dropout)
|
||||
else:
|
||||
x = x + self.attn(x).dropout(self.dropout)
|
||||
x = x.layernorm().linear(*self.ln1)
|
||||
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
||||
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(
|
||||
self.dropout
|
||||
)
|
||||
x = x.layernorm().linear(*self.ln2)
|
||||
return x
|
||||
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
||||
self.maxlen, self.syms = maxlen, syms
|
||||
self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False)
|
||||
self.embed = Tensor.scaled_uniform(
|
||||
maxlen + syms, embed_dim, requires_grad=False
|
||||
)
|
||||
self.tbs = []
|
||||
for i in range(layers):
|
||||
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
|
||||
|
@ -57,8 +94,11 @@ class Transformer:
|
|||
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
||||
onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms)
|
||||
|
||||
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
|
||||
x = (
|
||||
Tensor(onehot, device=x.device)
|
||||
.dot(self.embed)
|
||||
.reshape(shape=(bs, x.shape[1], -1))
|
||||
)
|
||||
x = x.sequential(self.tbs)
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
||||
|
|
|
@ -4,25 +4,63 @@ from tinygrad import nn
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import fetch, get_child
|
||||
|
||||
|
||||
class DownsampleBlock:
|
||||
def __init__(self, c0, c1, stride=2):
|
||||
self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
||||
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
||||
self.conv1 = [
|
||||
nn.Conv2d(
|
||||
c0,
|
||||
c1,
|
||||
kernel_size=(3, 3, 3),
|
||||
stride=stride,
|
||||
padding=(1, 1, 1, 1, 1, 1),
|
||||
bias=False,
|
||||
),
|
||||
nn.InstanceNorm(c1),
|
||||
Tensor.relu,
|
||||
]
|
||||
self.conv2 = [
|
||||
nn.Conv2d(
|
||||
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
|
||||
),
|
||||
nn.InstanceNorm(c1),
|
||||
Tensor.relu,
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
return x.sequential(self.conv1).sequential(self.conv2)
|
||||
|
||||
|
||||
class UpsampleBlock:
|
||||
def __init__(self, c0, c1):
|
||||
self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
|
||||
self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
||||
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
||||
self.upsample_conv = [
|
||||
nn.ConvTranspose2d(c0, c1, kernel_size=(2, 2, 2), stride=2)
|
||||
]
|
||||
self.conv1 = [
|
||||
nn.Conv2d(
|
||||
2 * c1,
|
||||
c1,
|
||||
kernel_size=(3, 3, 3),
|
||||
padding=(1, 1, 1, 1, 1, 1),
|
||||
bias=False,
|
||||
),
|
||||
nn.InstanceNorm(c1),
|
||||
Tensor.relu,
|
||||
]
|
||||
self.conv2 = [
|
||||
nn.Conv2d(
|
||||
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
|
||||
),
|
||||
nn.InstanceNorm(c1),
|
||||
Tensor.relu,
|
||||
]
|
||||
|
||||
def __call__(self, x, skip):
|
||||
x = x.sequential(self.upsample_conv)
|
||||
x = Tensor.cat(x, skip, dim=1)
|
||||
return x.sequential(self.conv1).sequential(self.conv2)
|
||||
|
||||
|
||||
class UNet3D:
|
||||
def __init__(self, in_channels=1, n_class=3):
|
||||
filters = [32, 64, 128, 256, 320]
|
||||
|
@ -30,7 +68,9 @@ class UNet3D:
|
|||
self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
|
||||
self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
|
||||
self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
|
||||
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
|
||||
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [
|
||||
UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])
|
||||
]
|
||||
self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
|
||||
|
||||
def __call__(self, x):
|
||||
|
@ -47,13 +87,17 @@ class UNet3D:
|
|||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
|
||||
fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
|
||||
fetch(
|
||||
"https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1",
|
||||
fn,
|
||||
)
|
||||
state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
assert obj.shape == v.shape, (k, obj.shape, v.shape)
|
||||
obj.assign(v.numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue