Reformat, uh, everything, with black
parent
01503ca90d
commit
661dcc5ed0
|
@ -4,13 +4,17 @@ import pathlib
|
||||||
from hexdump import hexdump
|
from hexdump import hexdump
|
||||||
|
|
||||||
fxn = None
|
fxn = None
|
||||||
|
|
||||||
|
|
||||||
def disasm(buf):
|
def disasm(buf):
|
||||||
global fxn
|
global fxn
|
||||||
if fxn is None:
|
if fxn is None:
|
||||||
shared = pathlib.Path(__file__).parent / "disasm.so"
|
shared = pathlib.Path(__file__).parent / "disasm.so"
|
||||||
if not shared.is_file():
|
if not shared.is_file():
|
||||||
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
|
os.system(
|
||||||
fxn = ctypes.CDLL(shared.as_posix())['disasm']
|
f"cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so"
|
||||||
|
)
|
||||||
|
fxn = ctypes.CDLL(shared.as_posix())["disasm"]
|
||||||
# hexdump(buf)
|
# hexdump(buf)
|
||||||
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
|
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
|
||||||
buf = buf[0x510:] # this right?
|
buf = buf[0x510:] # this right?
|
||||||
|
|
|
@ -23,21 +23,24 @@ from abc import ABC
|
||||||
|
|
||||||
# we will be using the clang backend
|
# we will be using the clang backend
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
|
|
||||||
Device.DEFAULT = "CLANG"
|
Device.DEFAULT = "CLANG"
|
||||||
|
|
||||||
# first, 2+3 as a Tensor, the highest level
|
# first, 2+3 as a Tensor, the highest level
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
a = Tensor([2])
|
a = Tensor([2])
|
||||||
b = Tensor([3])
|
b = Tensor([3])
|
||||||
result = a + b
|
result = a + b
|
||||||
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
|
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
|
||||||
assert result.numpy()[0] == 5.
|
assert result.numpy()[0] == 5.0
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
|
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
|
||||||
# it's worth reading tinygrad/tensor.py. it's pretty beautiful
|
# it's worth reading tinygrad/tensor.py. it's pretty beautiful
|
||||||
import tinygrad.mlops as mlops
|
import tinygrad.mlops as mlops
|
||||||
|
|
||||||
|
|
||||||
# this is the good old familiar Tensor class
|
# this is the good old familiar Tensor class
|
||||||
class Tensor:
|
class Tensor:
|
||||||
# these two are pretty straightforward
|
# these two are pretty straightforward
|
||||||
|
@ -51,10 +54,13 @@ class Tensor:
|
||||||
lazydata: LazyBuffer
|
lazydata: LazyBuffer
|
||||||
|
|
||||||
# high level ops (hlops) are defined on this class. example: relu
|
# high level ops (hlops) are defined on this class. example: relu
|
||||||
def relu(self): return self.maximum(0)
|
def relu(self):
|
||||||
|
return self.maximum(0)
|
||||||
|
|
||||||
# log is an mlop, this is the wrapper function in Tensor
|
# log is an mlop, this is the wrapper function in Tensor
|
||||||
def log(self): return mlops.Log.apply(self)
|
def log(self):
|
||||||
|
return mlops.Log.apply(self)
|
||||||
|
|
||||||
|
|
||||||
# all the definitions of the derivatives are subclasses of Function (like mlops.Log)
|
# all the definitions of the derivatives are subclasses of Function (like mlops.Log)
|
||||||
# there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10)
|
# there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10)
|
||||||
|
@ -62,13 +68,18 @@ class Tensor:
|
||||||
# you can differentiate the world using the chain rule
|
# you can differentiate the world using the chain rule
|
||||||
class Function:
|
class Function:
|
||||||
# example types of forward and backward
|
# example types of forward and backward
|
||||||
def forward(self, x:LazyBuffer) -> LazyBuffer: pass
|
def forward(self, x: LazyBuffer) -> LazyBuffer:
|
||||||
def backward(self, x:LazyBuffer) -> LazyBuffer: pass
|
pass
|
||||||
|
|
||||||
|
def backward(self, x: LazyBuffer) -> LazyBuffer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
|
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
|
||||||
from tinygrad.helpers import DType
|
from tinygrad.helpers import DType
|
||||||
|
|
||||||
|
|
||||||
# this is where the properties live that you thought were a part of Tensor
|
# this is where the properties live that you thought were a part of Tensor
|
||||||
# LazyBuffer is like a Tensor without derivatives, at the mlop layer
|
# LazyBuffer is like a Tensor without derivatives, at the mlop layer
|
||||||
class LazyBuffer:
|
class LazyBuffer:
|
||||||
|
@ -91,6 +102,7 @@ class LazyBuffer:
|
||||||
# this LazyOp describes the computation needed to realize this LazyBuffer
|
# this LazyOp describes the computation needed to realize this LazyBuffer
|
||||||
op: Optional[LazyOp]
|
op: Optional[LazyOp]
|
||||||
|
|
||||||
|
|
||||||
# LazyOp (in tinygrad/ops.py, code 4/10)
|
# LazyOp (in tinygrad/ops.py, code 4/10)
|
||||||
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
|
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
|
||||||
class LazyOp:
|
class LazyOp:
|
||||||
|
@ -98,13 +110,52 @@ class LazyOp:
|
||||||
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
|
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
|
||||||
arg: Optional[Any] = None # and an optional static argument
|
arg: Optional[Any] = None # and an optional static argument
|
||||||
|
|
||||||
|
|
||||||
# there's currently 26 Ops you have to implement for an accelerator.
|
# there's currently 26 Ops you have to implement for an accelerator.
|
||||||
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
|
class UnaryOps(Enum):
|
||||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
|
EXP2 = auto()
|
||||||
class ReduceOps(Enum): SUM = auto(); MAX = auto()
|
LOG2 = auto()
|
||||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
|
CAST = auto()
|
||||||
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
|
SIN = auto()
|
||||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
|
SQRT = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryOps(Enum):
|
||||||
|
ADD = auto()
|
||||||
|
SUB = auto()
|
||||||
|
MUL = auto()
|
||||||
|
DIV = auto()
|
||||||
|
CMPLT = auto()
|
||||||
|
MAX = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceOps(Enum):
|
||||||
|
SUM = auto()
|
||||||
|
MAX = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class MovementOps(Enum):
|
||||||
|
RESHAPE = auto()
|
||||||
|
PERMUTE = auto()
|
||||||
|
EXPAND = auto()
|
||||||
|
PAD = auto()
|
||||||
|
SHRINK = auto()
|
||||||
|
STRIDE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class TernaryOps(Enum):
|
||||||
|
MULACC = auto()
|
||||||
|
WHERE = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadOps(Enum):
|
||||||
|
EMPTY = auto()
|
||||||
|
CONST = auto()
|
||||||
|
FROM = auto()
|
||||||
|
CONTIGUOUS = auto()
|
||||||
|
CUSTOM = auto()
|
||||||
|
|
||||||
|
|
||||||
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
|
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
|
||||||
# you do not need to implement the MovementOps
|
# you do not need to implement the MovementOps
|
||||||
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
|
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
|
||||||
|
@ -135,7 +186,9 @@ assert len(lazyop.src) == 2
|
||||||
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
|
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
|
||||||
assert lazyop.src[0].op.op == LoadOps.FROM
|
assert lazyop.src[0].op.op == LoadOps.FROM
|
||||||
assert lazyop.src[0].op.src[0].device == "CPU"
|
assert lazyop.src[0].op.src[0].device == "CPU"
|
||||||
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
assert (
|
||||||
|
lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2
|
||||||
|
), "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||||
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
||||||
|
|
||||||
# now we realize the LazyBuffer
|
# now we realize the LazyBuffer
|
||||||
|
@ -151,12 +204,15 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU,
|
||||||
|
|
||||||
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
|
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
|
||||||
|
|
||||||
|
|
||||||
# Interpreted backends are very simple (example: CPU and TORCH)
|
# Interpreted backends are very simple (example: CPU and TORCH)
|
||||||
class Interpreted:
|
class Interpreted:
|
||||||
# and they have a lookup table to functions for the Ops
|
# and they have a lookup table to functions for the Ops
|
||||||
fxn_for_op: Dict[Op, Callable] = {
|
fxn_for_op: Dict[Op, Callable] = {
|
||||||
UnaryOps.EXP2: lambda x: np.exp2(x),
|
UnaryOps.EXP2: lambda x: np.exp2(x),
|
||||||
BinaryOps.ADD: lambda x,y: x+y}
|
BinaryOps.ADD: lambda x, y: x + y,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Compiled backends take a little more (example: GPU and LLVM)
|
# Compiled backends take a little more (example: GPU and LLVM)
|
||||||
class Compiled:
|
class Compiled:
|
||||||
|
@ -166,26 +222,40 @@ class Compiled:
|
||||||
# and a runtime, which runs the generated code
|
# and a runtime, which runs the generated code
|
||||||
runtime: Type[Runtime]
|
runtime: Type[Runtime]
|
||||||
|
|
||||||
|
|
||||||
# Runtime is what actually runs the kernels for a compiled backend
|
# Runtime is what actually runs the kernels for a compiled backend
|
||||||
class Runtime(ABC):
|
class Runtime(ABC):
|
||||||
# `name` is the name of the function, and `prg` is the code
|
# `name` is the name of the function, and `prg` is the code
|
||||||
# the constructor compiles the code
|
# the constructor compiles the code
|
||||||
def __init__(self, name:str, prg:str): pass
|
def __init__(self, name: str, prg: str):
|
||||||
|
pass
|
||||||
|
|
||||||
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
|
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
|
||||||
def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
|
def __call__(
|
||||||
|
self,
|
||||||
|
*bufs: List[Buffer],
|
||||||
|
global_size: Optional[List[int]],
|
||||||
|
local_size: Optional[List[int]],
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# == Buffer (in tinygrad/device.py, code 6/10) ==
|
# == Buffer (in tinygrad/device.py, code 6/10) ==
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
# Buffer is where the data is actually held. it's pretty close to just memory
|
# Buffer is where the data is actually held. it's pretty close to just memory
|
||||||
class Buffer(ABC):
|
class Buffer(ABC):
|
||||||
# create an empty rawbuffer that holds `size` elements of type `dtype`
|
# create an empty rawbuffer that holds `size` elements of type `dtype`
|
||||||
# `opaque` is an opaque container class
|
# `opaque` is an opaque container class
|
||||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass
|
def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None):
|
||||||
|
pass
|
||||||
|
|
||||||
# toCPU converts the RawBuffer to a numpy array with shape (size,)
|
# toCPU converts the RawBuffer to a numpy array with shape (size,)
|
||||||
def toCPU(self) -> np.ndarray: pass
|
def toCPU(self) -> np.ndarray:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# == Example: 2+3 in raw clang ==
|
# == Example: 2+3 in raw clang ==
|
||||||
|
@ -205,6 +275,7 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
|
||||||
# then we copy the numpy in to RawMallocBuffers
|
# then we copy the numpy in to RawMallocBuffers
|
||||||
# last, we create an empty output buffer
|
# last, we create an empty output buffer
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes
|
||||||
|
|
||||||
input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4)
|
input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4)
|
||||||
output = MallocAllocator.alloc(4)
|
output = MallocAllocator.alloc(4)
|
||||||
|
|
||||||
|
@ -214,7 +285,9 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B"))
|
||||||
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
|
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
|
||||||
|
|
||||||
# compile the program, run it, and 2+3 does indeed equal 5
|
# compile the program, run it, and 2+3 does indeed equal 5
|
||||||
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
|
program = ClangProgram(
|
||||||
|
"add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
|
||||||
|
)
|
||||||
program(output, input_a, input_b)
|
program(output, input_a, input_b)
|
||||||
numpy_out = np.empty(1, dtype=np.float32)
|
numpy_out = np.empty(1, dtype=np.float32)
|
||||||
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
|
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
|
||||||
|
@ -229,7 +302,16 @@ np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
|
||||||
# the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST
|
# the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST
|
||||||
# for that, we use the Linearizer, which turns an AST into a list of (linear) UOps
|
# for that, we use the Linearizer, which turns an AST into a list of (linear) UOps
|
||||||
|
|
||||||
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto();
|
|
||||||
|
class UOps(Enum):
|
||||||
|
LOOP = auto()
|
||||||
|
DEFINE_LOCAL = auto()
|
||||||
|
LOAD = auto()
|
||||||
|
ALU = auto()
|
||||||
|
CONST = auto()
|
||||||
|
ENDLOOP = auto()
|
||||||
|
STORE = auto()
|
||||||
|
|
||||||
|
|
||||||
class UOp:
|
class UOp:
|
||||||
uop: UOps
|
uop: UOps
|
||||||
|
@ -238,26 +320,34 @@ class UOp:
|
||||||
arg: Any
|
arg: Any
|
||||||
num: int # UOps are unique
|
num: int # UOps are unique
|
||||||
|
|
||||||
|
|
||||||
class Linearizer:
|
class Linearizer:
|
||||||
# create the kernel with the AST
|
# create the kernel with the AST
|
||||||
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
|
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
|
||||||
def __init__(self, ast:LazyOp): pass
|
def __init__(self, ast: LazyOp):
|
||||||
def linearize(self): pass
|
pass
|
||||||
|
|
||||||
|
def linearize(self):
|
||||||
|
pass
|
||||||
|
|
||||||
# when linearize is run, it fills in this list
|
# when linearize is run, it fills in this list
|
||||||
uops: List[UOp]
|
uops: List[UOp]
|
||||||
|
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
result = Tensor(2).realize() + Tensor(3).realize()
|
result = Tensor(2).realize() + Tensor(3).realize()
|
||||||
|
|
||||||
# use the real Linearizer to linearize 2+3
|
# use the real Linearizer to linearize 2+3
|
||||||
from tinygrad.codegen.linearizer import Linearizer
|
from tinygrad.codegen.linearizer import Linearizer
|
||||||
|
|
||||||
sched = result.lazydata.schedule()
|
sched = result.lazydata.schedule()
|
||||||
linearizer = Linearizer(sched[-1].ast)
|
linearizer = Linearizer(sched[-1].ast)
|
||||||
linearizer.linearize()
|
linearizer.linearize()
|
||||||
|
|
||||||
# print the uops
|
# print the uops
|
||||||
for uop in linearizer.uops: print(uop)
|
for uop in linearizer.uops:
|
||||||
|
print(uop)
|
||||||
|
|
||||||
# output:
|
# output:
|
||||||
"""
|
"""
|
||||||
|
@ -275,11 +365,13 @@ for uop in linearizer.uops: print(uop)
|
||||||
# here, we have an example where we fetch the generated code from the JIT
|
# here, we have an example where we fetch the generated code from the JIT
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
result = Tensor(2) + Tensor(3)
|
result = Tensor(2) + Tensor(3)
|
||||||
|
|
||||||
# we have a global cache used by the JIT
|
# we have a global cache used by the JIT
|
||||||
# from there, we can see the generated clang code
|
# from there, we can see the generated clang code
|
||||||
from tinygrad.jit import CacheCollector
|
from tinygrad.jit import CacheCollector
|
||||||
|
|
||||||
CacheCollector.start() # enables the cache
|
CacheCollector.start() # enables the cache
|
||||||
result.realize() # create the program and runs it
|
result.realize() # create the program and runs it
|
||||||
cache_saved = CacheCollector.finish() # disable the cache
|
cache_saved = CacheCollector.finish() # disable the cache
|
||||||
|
@ -319,7 +411,9 @@ print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||||
# we can then reshape it, and the strides change again
|
# we can then reshape it, and the strides change again
|
||||||
# note how the permute stays applied
|
# note how the permute stays applied
|
||||||
a = a.reshape((5, 2, 5, 2))
|
a = a.reshape((5, 2, 5, 2))
|
||||||
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
print(
|
||||||
|
a
|
||||||
|
) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||||
|
|
||||||
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
|
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
|
||||||
a = a.reshape((100,))
|
a = a.reshape((100,))
|
||||||
|
|
|
@ -49,14 +49,21 @@ b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I",
|
||||||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||||
|
|
||||||
# describe the computation
|
# describe the computation
|
||||||
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
|
ld_1 = LazyOp(
|
||||||
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
|
BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||||
|
)
|
||||||
|
ld_2 = LazyOp(
|
||||||
|
BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||||
|
)
|
||||||
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
||||||
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
|
st_0 = LazyOp(
|
||||||
|
BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||||
|
)
|
||||||
|
|
||||||
# convert the computation to a "linearized" format (print the format)
|
# convert the computation to a "linearized" format (print the format)
|
||||||
lin = Device[DEVICE].get_linearizer(st_0).linearize()
|
lin = Device[DEVICE].get_linearizer(st_0).linearize()
|
||||||
for u in lin.uops: print(u)
|
for u in lin.uops:
|
||||||
|
print(u)
|
||||||
|
|
||||||
# compile a program (and print the source)
|
# compile a program (and print the source)
|
||||||
fxn = Device[DEVICE].to_program(lin)
|
fxn = Device[DEVICE].to_program(lin)
|
||||||
|
@ -79,6 +86,7 @@ from tinygrad.realize import run_schedule
|
||||||
# allocate some values + load in values
|
# allocate some values + load in values
|
||||||
# TODO: remove numpy here
|
# TODO: remove numpy here
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE)
|
a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE)
|
||||||
b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE)
|
b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE)
|
||||||
|
|
||||||
|
@ -87,10 +95,12 @@ out = a.e(BinaryOps.ADD, b)
|
||||||
|
|
||||||
# schedule the computation as a list of kernels
|
# schedule the computation as a list of kernels
|
||||||
sched = out.schedule()
|
sched = out.schedule()
|
||||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
|
for si in sched:
|
||||||
|
print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||||
|
|
||||||
# DEBUGGING: print the compute ast as a tree
|
# DEBUGGING: print the compute ast as a tree
|
||||||
from tinygrad.graph import print_tree
|
from tinygrad.graph import print_tree
|
||||||
|
|
||||||
print_tree(sched[-1].ast)
|
print_tree(sched[-1].ast)
|
||||||
# NOTE: sched[-1].ast is the same as st_0 above
|
# NOTE: sched[-1].ast is the same as st_0 above
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import time
|
import time
|
||||||
from tinygrad import Tensor, TinyJit, nn, Variable
|
from tinygrad import Tensor, TinyJit, nn, Variable
|
||||||
from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype
|
from tinygrad.helpers import (
|
||||||
|
dtypes,
|
||||||
|
) # TODO: wouldn't need this if argmax returned the right dtype
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
import numpy as np # TODO: remove numpy import
|
import numpy as np # TODO: remove numpy import
|
||||||
|
|
||||||
|
|
||||||
class ActorCritic:
|
class ActorCritic:
|
||||||
def __init__(self, in_features, out_features, hidden_state=32):
|
def __init__(self, in_features, out_features, hidden_state=32):
|
||||||
self.l1 = nn.Linear(in_features, hidden_state)
|
self.l1 = nn.Linear(in_features, hidden_state)
|
||||||
|
@ -20,6 +23,7 @@ class ActorCritic:
|
||||||
x = self.c1(obs).relu()
|
x = self.c1(obs).relu()
|
||||||
return act, self.c2(x)
|
return act, self.c2(x)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model: ActorCritic, test_env: gym.Env) -> float:
|
def evaluate(model: ActorCritic, test_env: gym.Env) -> float:
|
||||||
(obs, _), terminated, truncated = test_env.reset(), False, False
|
(obs, _), terminated, truncated = test_env.reset(), False, False
|
||||||
total_rew = 0.0
|
total_rew = 0.0
|
||||||
|
@ -29,21 +33,26 @@ def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
|
||||||
total_rew += float(rew)
|
total_rew += float(rew)
|
||||||
return total_rew
|
return total_rew
|
||||||
|
|
||||||
|
|
||||||
# TODO: time should be < 5s on M1 Max
|
# TODO: time should be < 5s on M1 Max
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = gym.make('CartPole-v1')
|
env = gym.make("CartPole-v1")
|
||||||
|
|
||||||
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
|
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
|
||||||
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
|
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
def train_step(
|
||||||
|
x: Tensor, selected_action: Tensor, reward: Tensor, old_log_dist: Tensor
|
||||||
|
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
with Tensor.train():
|
with Tensor.train():
|
||||||
log_dist, value = model(x)
|
log_dist, value = model(x)
|
||||||
|
|
||||||
# get advantage
|
# get advantage
|
||||||
advantage = reward.reshape(-1, 1) - value
|
advantage = reward.reshape(-1, 1) - value
|
||||||
mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
|
mask = selected_action.reshape(-1, 1) == Tensor.arange(
|
||||||
|
log_dist.shape[1]
|
||||||
|
).reshape(1, -1).expand(selected_action.shape[0], -1)
|
||||||
masked_advantage = mask * advantage.detach()
|
masked_advantage = mask * advantage.detach()
|
||||||
|
|
||||||
# PPO
|
# PPO
|
||||||
|
@ -51,7 +60,9 @@ if __name__ == "__main__":
|
||||||
clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage
|
clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage
|
||||||
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
|
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
|
||||||
|
|
||||||
entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
|
entropy_loss = (
|
||||||
|
(log_dist.exp() * log_dist).sum(-1).mean()
|
||||||
|
) # this encourages diversity
|
||||||
critic_loss = advantage.square().mean()
|
critic_loss = advantage.square().mean()
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
(action_loss + entropy_loss * 0.0005 + critic_loss).backward()
|
(action_loss + entropy_loss * 0.0005 + critic_loss).backward()
|
||||||
|
@ -96,7 +107,11 @@ if __name__ == "__main__":
|
||||||
discounts = np.power(0.99, np.arange(len(rews)))
|
discounts = np.power(0.99, np.arange(len(rews)))
|
||||||
Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))]
|
Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))]
|
||||||
|
|
||||||
Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
|
Xn, An, Rn = (
|
||||||
|
Xn[-MAX_REPLAY_BUFFER:],
|
||||||
|
An[-MAX_REPLAY_BUFFER:],
|
||||||
|
Rn[-MAX_REPLAY_BUFFER:],
|
||||||
|
)
|
||||||
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
|
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
|
||||||
|
|
||||||
# TODO: make this work
|
# TODO: make this work
|
||||||
|
@ -105,10 +120,16 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
|
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
samples = Tensor.randint(BS, high=X.shape[0]).realize() # TODO: remove the need for this
|
samples = Tensor.randint(
|
||||||
|
BS, high=X.shape[0]
|
||||||
|
).realize() # TODO: remove the need for this
|
||||||
# TODO: is this recompiling based on the shape?
|
# TODO: is this recompiling based on the shape?
|
||||||
action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
|
action_loss, entropy_loss, critic_loss = train_step(
|
||||||
t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")
|
X[samples], A[samples], R[samples], old_log_dist[samples]
|
||||||
|
)
|
||||||
|
t.set_description(
|
||||||
|
f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
|
test_rew = evaluate(model, gym.make("CartPole-v1", render_mode="human"))
|
||||||
print(f"test reward: {test_rew}")
|
print(f"test reward: {test_rew}")
|
||||||
|
|
|
@ -4,18 +4,29 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
|
||||||
class Model:
|
class Model:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
nn.Conv2d(1, 32, 5),
|
||||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
Tensor.relu,
|
||||||
nn.BatchNorm2d(32), Tensor.max_pool2d,
|
nn.Conv2d(32, 32, 5),
|
||||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
Tensor.relu,
|
||||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
nn.BatchNorm2d(32),
|
||||||
nn.BatchNorm2d(64), Tensor.max_pool2d,
|
Tensor.max_pool2d,
|
||||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
nn.Conv2d(32, 64, 3),
|
||||||
|
Tensor.relu,
|
||||||
|
nn.Conv2d(64, 64, 3),
|
||||||
|
Tensor.relu,
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
Tensor.max_pool2d,
|
||||||
|
lambda x: x.flatten(1),
|
||||||
|
nn.Linear(576, 10),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
|
return x.sequential(self.layers)
|
||||||
|
|
||||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||||
|
@ -29,17 +40,25 @@ if __name__ == "__main__":
|
||||||
with Tensor.train():
|
with Tensor.train():
|
||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
loss = (
|
||||||
|
model(X_train[samples])
|
||||||
|
.sparse_categorical_crossentropy(Y_train[samples])
|
||||||
|
.backward()
|
||||||
|
)
|
||||||
opt.step()
|
opt.step()
|
||||||
return loss.realize()
|
return loss.realize()
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
|
def get_test_acc() -> Tensor:
|
||||||
|
return ((model(X_test).argmax(axis=1) == Y_test).mean() * 100).realize()
|
||||||
|
|
||||||
test_acc = float('nan')
|
test_acc = float("nan")
|
||||||
for i in (t := trange(70)):
|
for i in (t := trange(70)):
|
||||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||||
samples = Tensor.randint(512, high=X_train.shape[0]) # TODO: put this in the JIT when rand is fixed
|
samples = Tensor.randint(
|
||||||
|
512, high=X_train.shape[0]
|
||||||
|
) # TODO: put this in the JIT when rand is fixed
|
||||||
loss = train_step(samples)
|
loss = train_step(samples)
|
||||||
if i%10 == 9: test_acc = get_test_acc().item()
|
if i % 10 == 9:
|
||||||
|
test_acc = get_test_acc().item()
|
||||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||||
|
|
|
@ -10,9 +10,11 @@ from tinygrad.helpers import GlobalCounters
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.jit import CacheCollector
|
from tinygrad.jit import CacheCollector
|
||||||
|
|
||||||
|
|
||||||
def tensors_allocated():
|
def tensors_allocated():
|
||||||
return sum(isinstance(x, Tensor) for x in gc.get_objects())
|
return sum(isinstance(x, Tensor) for x in gc.get_objects())
|
||||||
|
|
||||||
|
|
||||||
NUM = getenv("NUM", 2)
|
NUM = getenv("NUM", 2)
|
||||||
BS = getenv("BS", 8)
|
BS = getenv("BS", 8)
|
||||||
CNT = getenv("CNT", 10)
|
CNT = getenv("CNT", 10)
|
||||||
|
@ -25,9 +27,12 @@ if __name__ == "__main__":
|
||||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||||
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
||||||
parameters = get_parameters(model)
|
parameters = get_parameters(model)
|
||||||
for p in parameters: p.realize()
|
for p in parameters:
|
||||||
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
|
p.realize()
|
||||||
else: optimizer = optim.SGD(parameters, lr=0.001)
|
if ADAM:
|
||||||
|
optimizer = optim.Adam(parameters, lr=0.001)
|
||||||
|
else:
|
||||||
|
optimizer = optim.SGD(parameters, lr=0.001)
|
||||||
|
|
||||||
Tensor.training = TRAINING
|
Tensor.training = TRAINING
|
||||||
Tensor.no_grad = not BACKWARD
|
Tensor.no_grad = not BACKWARD
|
||||||
|
@ -42,7 +47,8 @@ if __name__ == "__main__":
|
||||||
st = time.monotonic()
|
st = time.monotonic()
|
||||||
out = model.forward(x_train)
|
out = model.forward(x_train)
|
||||||
loss = out.log_softmax().mul(y_train).mean()
|
loss = out.log_softmax().mul(y_train).mean()
|
||||||
if i == 2 and CLCACHE: CacheCollector.start()
|
if i == 2 and CLCACHE:
|
||||||
|
CacheCollector.start()
|
||||||
if BACKWARD:
|
if BACKWARD:
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -54,7 +60,8 @@ if __name__ == "__main__":
|
||||||
et = time.monotonic()
|
et = time.monotonic()
|
||||||
else:
|
else:
|
||||||
st = mt = time.monotonic()
|
st = mt = time.monotonic()
|
||||||
for prg, args in cl_cache: prg(*args)
|
for prg, args in cl_cache:
|
||||||
|
prg(*args)
|
||||||
et = time.monotonic()
|
et = time.monotonic()
|
||||||
|
|
||||||
if i == 2 and CLCACHE:
|
if i == 2 and CLCACHE:
|
||||||
|
@ -64,4 +71,6 @@ if __name__ == "__main__":
|
||||||
loss_cpu = loss.detach().numpy()
|
loss_cpu = loss.detach().numpy()
|
||||||
cl = time.monotonic()
|
cl = time.monotonic()
|
||||||
|
|
||||||
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
print(
|
||||||
|
f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||||
|
)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os, sys, traceback
|
import os, sys, traceback
|
||||||
|
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
@ -9,27 +10,46 @@ from tinygrad.helpers import Timing, colored, getenv, fetch
|
||||||
from extra.models.llama import Transformer, convert_from_huggingface
|
from extra.models.llama import Transformer, convert_from_huggingface
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
|
||||||
def create_fixed_tokenizer(output_file):
|
def create_fixed_tokenizer(output_file):
|
||||||
print("creating fixed tokenizer")
|
print("creating fixed tokenizer")
|
||||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||||
|
|
||||||
mp = spb2.ModelProto()
|
mp = spb2.ModelProto()
|
||||||
mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
|
mp.ParseFromString(
|
||||||
|
fetch(
|
||||||
|
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true"
|
||||||
|
).read_bytes()
|
||||||
|
)
|
||||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||||
with open(output_file, "wb") as f:
|
with open(output_file, "wb") as f:
|
||||||
f.write(mp.SerializeToString())
|
f.write(mp.SerializeToString())
|
||||||
|
|
||||||
|
|
||||||
# TODO: make loading bf16 fast so we can remove this
|
# TODO: make loading bf16 fast so we can remove this
|
||||||
def create_model_cache(output_file, model):
|
def create_model_cache(output_file, model):
|
||||||
print(f"creating model cache at {output_file}")
|
print(f"creating model cache at {output_file}")
|
||||||
# TODO: add read only Tensors
|
# TODO: add read only Tensors
|
||||||
with Timing("download weights: "):
|
with Timing("download weights: "):
|
||||||
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
|
part1 = nn.state.torch_load(
|
||||||
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
|
fetch(
|
||||||
|
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
part2 = nn.state.torch_load(
|
||||||
|
fetch(
|
||||||
|
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
with Timing("weights -> model: "):
|
with Timing("weights -> model: "):
|
||||||
nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
|
nn.state.load_state_dict(
|
||||||
nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)
|
model, convert_from_huggingface(part1, model, 32, 8), strict=False
|
||||||
|
)
|
||||||
|
nn.state.load_state_dict(
|
||||||
|
model, convert_from_huggingface(part2, model, 32, 8), strict=False
|
||||||
|
)
|
||||||
|
|
||||||
with Timing("saving float16 cache: "):
|
with Timing("saving float16 cache: "):
|
||||||
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
|
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
|
||||||
|
@ -37,27 +57,44 @@ def create_model_cache(output_file, model):
|
||||||
print("cache created, rerun to use")
|
print("cache created, rerun to use")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
|
|
||||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
||||||
with Timing("create model: "):
|
with Timing("create model: "):
|
||||||
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
|
model = Transformer(
|
||||||
|
4096,
|
||||||
|
14336,
|
||||||
|
n_heads=32,
|
||||||
|
n_layers=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
vocab_size=32002,
|
||||||
|
n_kv_heads=8,
|
||||||
|
max_context=4096,
|
||||||
|
)
|
||||||
|
|
||||||
cached_model = "/tmp/cached_openhermes.safetensors"
|
cached_model = "/tmp/cached_openhermes.safetensors"
|
||||||
if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
|
if not os.path.isfile(cached_model):
|
||||||
|
create_model_cache(cached_model, model)
|
||||||
with Timing("loading float16 cache: "):
|
with Timing("loading float16 cache: "):
|
||||||
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
|
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
|
||||||
|
|
||||||
if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
|
if not os.path.isfile("/tmp/tokenizer.model"):
|
||||||
|
create_fixed_tokenizer("/tmp/tokenizer.model")
|
||||||
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
|
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
|
||||||
|
|
||||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
||||||
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
IM_END = 32000
|
IM_END = 32000
|
||||||
IM_START = 32001
|
IM_START = 32001
|
||||||
def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
|
||||||
def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
|
def encode_prompt(k, v):
|
||||||
|
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||||
|
|
||||||
|
def start_prompt(k):
|
||||||
|
return [IM_START] + spp.encode(f"{k}\n")
|
||||||
|
|
||||||
def output(outputted, toks, color):
|
def output(outputted, toks, color):
|
||||||
cur = spp.decode(toks)[len(outputted) :]
|
cur = spp.decode(toks)[len(outputted) :]
|
||||||
sys.stdout.write(colored(cur, color))
|
sys.stdout.write(colored(cur, color))
|
||||||
|
@ -67,7 +104,10 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# *** app below this line ***
|
# *** app below this line ***
|
||||||
|
|
||||||
toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
|
toks = [spp.bos_id()] + encode_prompt(
|
||||||
|
"system",
|
||||||
|
"You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input",
|
||||||
|
)
|
||||||
|
|
||||||
PROMPT = getenv("PROMPT", 1)
|
PROMPT = getenv("PROMPT", 1)
|
||||||
temperature = getenv("TEMP", 0.7)
|
temperature = getenv("TEMP", 0.7)
|
||||||
|
@ -83,24 +123,34 @@ if __name__ == "__main__":
|
||||||
turn = not turn
|
turn = not turn
|
||||||
old_output_len = len(outputted)
|
old_output_len = len(outputted)
|
||||||
while 1:
|
while 1:
|
||||||
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
|
tok = (
|
||||||
|
model(Tensor([toks[start_pos:]]), start_pos, temperature)
|
||||||
|
.multinomial()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
start_pos = len(toks)
|
start_pos = len(toks)
|
||||||
toks.append(tok)
|
toks.append(tok)
|
||||||
outputted = output(outputted, toks, "blue" if not turn else "cyan")
|
outputted = output(outputted, toks, "blue" if not turn else "cyan")
|
||||||
if tok == IM_END: break
|
if tok == IM_END:
|
||||||
if tok == spp.eos_id(): break
|
break
|
||||||
|
if tok == spp.eos_id():
|
||||||
|
break
|
||||||
new_output = outputted[old_output_len:]
|
new_output = outputted[old_output_len:]
|
||||||
|
|
||||||
if new_output.endswith("```") and '```python\n' in new_output:
|
if new_output.endswith("```") and "```python\n" in new_output:
|
||||||
python_code = new_output.split('```python\n')[1].split("```")[0]
|
python_code = new_output.split("```python\n")[1].split("```")[0]
|
||||||
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
|
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
|
||||||
if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
|
if (
|
||||||
|
input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower()
|
||||||
|
== "y"
|
||||||
|
):
|
||||||
my_stdout = StringIO()
|
my_stdout = StringIO()
|
||||||
try:
|
try:
|
||||||
with redirect_stdout(my_stdout): exec(python_code)
|
with redirect_stdout(my_stdout):
|
||||||
|
exec(python_code)
|
||||||
result = my_stdout.getvalue()
|
result = my_stdout.getvalue()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result = ''.join(traceback.format_exception_only(e))
|
result = "".join(traceback.format_exception_only(e))
|
||||||
toks += spp.encode(f"\nOutput:\n```\n{result}```")
|
toks += spp.encode(f"\nOutput:\n```\n{result}```")
|
||||||
outputted = output(outputted, toks, "yellow")
|
outputted = output(outputted, toks, "yellow")
|
||||||
old_output_len = len(outputted)
|
old_output_len = len(outputted)
|
||||||
|
|
|
@ -9,8 +9,16 @@ import ast
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = EfficientNet(0)
|
model = EfficientNet(0)
|
||||||
model.load_from_pretrained()
|
model.load_from_pretrained()
|
||||||
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
|
mode = (
|
||||||
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
|
"clang"
|
||||||
|
if getenv("CLANG", "") != ""
|
||||||
|
else "webgpu"
|
||||||
|
if getenv("WEBGPU", "") != ""
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
prg, inp_sizes, out_sizes, state = export_model(
|
||||||
|
model, mode, Tensor.randn(1, 3, 224, 224)
|
||||||
|
)
|
||||||
dirname = Path(__file__).parent
|
dirname = Path(__file__).parent
|
||||||
if getenv("CLANG", "") == "":
|
if getenv("CLANG", "") == "":
|
||||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||||
|
@ -20,19 +28,33 @@ if __name__ == "__main__":
|
||||||
else:
|
else:
|
||||||
cprog = [prg]
|
cprog = [prg]
|
||||||
# image library!
|
# image library!
|
||||||
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
|
cprog += [
|
||||||
|
"#define STB_IMAGE_IMPLEMENTATION",
|
||||||
|
fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h")
|
||||||
|
.read_text()
|
||||||
|
.replace("half", "_half"),
|
||||||
|
]
|
||||||
|
|
||||||
# imagenet labels, move to datasets?
|
# imagenet labels, move to datasets?
|
||||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
lbls = ast.literal_eval(
|
||||||
|
fetch(
|
||||||
|
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||||
|
).read_text()
|
||||||
|
)
|
||||||
lbls = ['"' + lbls[i] + '"' for i in range(1000)]
|
lbls = ['"' + lbls[i] + '"' for i in range(1000)]
|
||||||
inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
|
inputs = "\n".join(
|
||||||
outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
|
[f"float {inp}[{inp_size}];" for inp, inp_size in inp_sizes.items()]
|
||||||
|
)
|
||||||
|
outputs = "\n".join(
|
||||||
|
[f"float {out}[{out_size}];" for out, out_size in out_sizes.items()]
|
||||||
|
)
|
||||||
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
|
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
|
||||||
cprog.append(inputs)
|
cprog.append(inputs)
|
||||||
cprog.append(outputs)
|
cprog.append(outputs)
|
||||||
|
|
||||||
# buffers (empty + weights)
|
# buffers (empty + weights)
|
||||||
cprog.append("""
|
cprog.append(
|
||||||
|
"""
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
|
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
|
||||||
int X=0, Y=0, chan=0;
|
int X=0, Y=0, chan=0;
|
||||||
|
@ -62,8 +84,9 @@ if __name__ == "__main__":
|
||||||
}
|
}
|
||||||
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
|
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
|
||||||
else printf("%s\\n", lbls[best_idx]);
|
else printf("%s\\n", lbls[best_idx]);
|
||||||
}""")
|
}"""
|
||||||
|
)
|
||||||
|
|
||||||
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||||
# category : 281 (tabby, tabby cat) with 9.452788
|
# category : 281 (tabby, tabby cat) with 9.452788
|
||||||
print('\n'.join(cprog))
|
print("\n".join(cprog))
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
# An example to compile a small Tensorflow model to extremely portable C code
|
# An example to compile a small Tensorflow model to extremely portable C code
|
||||||
|
|
||||||
import os, sys
|
import os, sys
|
||||||
os.environ["CLANG"] = '1'
|
|
||||||
os.environ["GPU"] = '1'
|
os.environ["CLANG"] = "1"
|
||||||
|
os.environ["GPU"] = "1"
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -12,32 +13,42 @@ from examples.compile_efficientnet import compile_net
|
||||||
from extra.onnx import get_run_onnx
|
from extra.onnx import get_run_onnx
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
def get_uncompiled_model2(dataset_size=32, output_size=4):
|
def get_uncompiled_model2(dataset_size=32, output_size=4):
|
||||||
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
|
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
|
||||||
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
|
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
|
||||||
x = tf.keras.layers.BatchNormalization()(x)
|
x = tf.keras.layers.BatchNormalization()(x)
|
||||||
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
|
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
|
||||||
outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x)
|
outputs = tf.keras.layers.Dense(
|
||||||
|
output_size, activation="sigmoid", name="predictions"
|
||||||
|
)(x)
|
||||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def create_onnx_model(keras_model):
|
def create_onnx_model(keras_model):
|
||||||
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
|
input_signature = [tf.TensorSpec([1, 32], tf.float32, name="x")]
|
||||||
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
|
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
|
||||||
return onnx_model
|
return onnx_model
|
||||||
|
|
||||||
|
|
||||||
def compile_onnx_model(onnx_model):
|
def compile_onnx_model(onnx_model):
|
||||||
run_onnx = get_run_onnx(onnx_model)
|
run_onnx = get_run_onnx(onnx_model)
|
||||||
|
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()
|
def run(x):
|
||||||
|
return run_onnx({"x": x}, debug=False)["predictions"].realize()
|
||||||
|
|
||||||
the_input = Tensor.randn(1, 32)
|
the_input = Tensor.randn(1, 32)
|
||||||
the_output = run(the_input)
|
the_output = run(the_input)
|
||||||
the_output = run(the_input)
|
the_output = run(the_input)
|
||||||
|
|
||||||
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
|
special_names = {
|
||||||
|
id(the_input.lazydata.realized.cl): "input",
|
||||||
|
id(the_output.lazydata.realized.cl): "outputs",
|
||||||
|
}
|
||||||
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||||
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
|
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
|
||||||
|
|
||||||
|
@ -60,7 +71,8 @@ def compile_onnx_model(onnx_model):
|
||||||
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
|
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
|
||||||
|
|
||||||
# test program
|
# test program
|
||||||
cprog.append(f"""int main(int argc, char *argv[]) {{
|
cprog.append(
|
||||||
|
f"""int main(int argc, char *argv[]) {{
|
||||||
// read in the weights from disk
|
// read in the weights from disk
|
||||||
FILE *f = fopen("/tmp/tf_weights", "rb");
|
FILE *f = fopen("/tmp/tf_weights", "rb");
|
||||||
float *weights = (float *)malloc({len(weights)});
|
float *weights = (float *)malloc({len(weights)});
|
||||||
|
@ -75,25 +87,38 @@ def compile_onnx_model(onnx_model):
|
||||||
for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
|
for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
|
||||||
float *outputs = infer(input);
|
float *outputs = infer(input);
|
||||||
printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
|
printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
|
||||||
}}""")
|
}}"""
|
||||||
|
)
|
||||||
|
|
||||||
# ready the program
|
# ready the program
|
||||||
prg = '\n'.join(cprog)
|
prg = "\n".join(cprog)
|
||||||
print(prg)
|
print(prg)
|
||||||
|
|
||||||
# add test weights
|
# add test weights
|
||||||
subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8'))
|
subprocess.check_output(
|
||||||
|
["clang", "-O2", "-lm", "-fPIC", "-x", "c", "-", "-o", "/tmp/tf_test"],
|
||||||
|
input=prg.encode("utf-8"),
|
||||||
|
)
|
||||||
|
|
||||||
tinygrad_output = [x for x in the_output.numpy()[0]]
|
tinygrad_output = [x for x in the_output.numpy()[0]]
|
||||||
print("tinygrad:", tinygrad_output, file=sys.stderr)
|
print("tinygrad:", tinygrad_output, file=sys.stderr)
|
||||||
|
|
||||||
c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n"
|
c_input = " ".join(["%f" % x for x in the_input[0].numpy()]) + "\n"
|
||||||
c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")]
|
c_output = [
|
||||||
|
float(x)
|
||||||
|
for x in subprocess.check_output(
|
||||||
|
["/tmp/tf_test"], input=c_input.encode("utf-8")
|
||||||
|
)
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
.split(" ")
|
||||||
|
]
|
||||||
print("compiled:", c_output, file=sys.stderr)
|
print("compiled:", c_output, file=sys.stderr)
|
||||||
|
|
||||||
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
|
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
|
||||||
return the_input.numpy(), c_output
|
return the_input.numpy(), c_output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
keras_model = get_uncompiled_model2()
|
keras_model = get_uncompiled_model2()
|
||||||
onnx_model = create_onnx_model(keras_model)
|
onnx_model = create_onnx_model(keras_model)
|
||||||
|
@ -101,4 +126,3 @@ if __name__ == "__main__":
|
||||||
tf_output = keras_model(test_input).numpy()[0]
|
tf_output = keras_model(test_input).numpy()[0]
|
||||||
print("keras: ", tf_output, file=sys.stderr)
|
print("keras: ", tf_output, file=sys.stderr)
|
||||||
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
|
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,14 @@ import pyaudio
|
||||||
import yaml
|
import yaml
|
||||||
from llama import LLaMa
|
from llama import LLaMa
|
||||||
from vits import MODELS as VITS_MODELS
|
from vits import MODELS as VITS_MODELS
|
||||||
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
|
from vits import (
|
||||||
|
Y_LENGTH_ESTIMATE_SCALARS,
|
||||||
|
HParams,
|
||||||
|
Synthesizer,
|
||||||
|
TextMapper,
|
||||||
|
get_hparams_from_file,
|
||||||
|
load_model,
|
||||||
|
)
|
||||||
from whisper import init_whisper, transcribe_waveform
|
from whisper import init_whisper, transcribe_waveform
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
|
||||||
|
@ -29,16 +36,26 @@ IM_END = 32002
|
||||||
|
|
||||||
|
|
||||||
# Functions for encoding prompts to chatml md
|
# Functions for encoding prompts to chatml md
|
||||||
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
def encode_prompt(spp, k, v):
|
||||||
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
|
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def start_prompt(spp, k):
|
||||||
|
return [IM_START] + spp.encode(f"{k}\n")
|
||||||
|
|
||||||
|
|
||||||
def chunks(lst, n):
|
def chunks(lst, n):
|
||||||
for i in range(0, len(lst), n): yield lst[i:i + n]
|
for i in range(0, len(lst), n):
|
||||||
|
yield lst[i : i + n]
|
||||||
|
|
||||||
|
|
||||||
def create_fixed_tokenizer():
|
def create_fixed_tokenizer():
|
||||||
"""Function needed for extending tokenizer with additional chat tokens"""
|
"""Function needed for extending tokenizer with additional chat tokens"""
|
||||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||||
tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
|
|
||||||
|
tokenizer_path = fetch(
|
||||||
|
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model"
|
||||||
|
)
|
||||||
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
||||||
print("creating fixed tokenizer")
|
print("creating fixed tokenizer")
|
||||||
mp = spb2.ModelProto()
|
mp = spb2.ModelProto()
|
||||||
|
@ -50,16 +67,28 @@ def create_fixed_tokenizer():
|
||||||
tokenizer_path.write_bytes(mp.SerializeToString())
|
tokenizer_path.write_bytes(mp.SerializeToString())
|
||||||
return tokenizer_path
|
return tokenizer_path
|
||||||
|
|
||||||
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]:
|
|
||||||
|
def llama_prepare(
|
||||||
|
llama: LLaMa, temperature: float, pre_prompt_path: Path
|
||||||
|
) -> tuple[list[int], str, str, str]:
|
||||||
"""Prepares a llama model from a specified pre-prompt file"""
|
"""Prepares a llama model from a specified pre-prompt file"""
|
||||||
with open(str(pre_prompt_path)) as f:
|
with open(str(pre_prompt_path)) as f:
|
||||||
config = yaml.safe_load(f.read())
|
config = yaml.safe_load(f.read())
|
||||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
|
toks = [llama.tokenizer.bos_id()] + encode_prompt(
|
||||||
|
llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")
|
||||||
|
)
|
||||||
for i in config["examples"]:
|
for i in config["examples"]:
|
||||||
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
||||||
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
||||||
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
||||||
return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
|
return (
|
||||||
|
toks,
|
||||||
|
config["user_delim"],
|
||||||
|
config["resp_delim"],
|
||||||
|
len(toks),
|
||||||
|
llama.tokenizer.decode(toks),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_generate(
|
def llama_generate(
|
||||||
llama: LLaMa,
|
llama: LLaMa,
|
||||||
|
@ -70,7 +99,7 @@ def llama_generate(
|
||||||
user_delim: str,
|
user_delim: str,
|
||||||
resp_delim: str,
|
resp_delim: str,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=1000
|
max_tokens=1000,
|
||||||
):
|
):
|
||||||
"""Generates an output for the specified prompt"""
|
"""Generates an output for the specified prompt"""
|
||||||
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
||||||
|
@ -79,7 +108,9 @@ def llama_generate(
|
||||||
outputted = llama.tokenizer.decode(toks)
|
outputted = llama.tokenizer.decode(toks)
|
||||||
init_length = len(outputted)
|
init_length = len(outputted)
|
||||||
for _ in range(max_tokens):
|
for _ in range(max_tokens):
|
||||||
probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy()
|
probs_np = llama.model(
|
||||||
|
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||||
|
).numpy()
|
||||||
token = int(np.random.choice(len(probs_np), p=probs_np))
|
token = int(np.random.choice(len(probs_np), p=probs_np))
|
||||||
start_pos = len(toks)
|
start_pos = len(toks)
|
||||||
toks.append(token)
|
toks.append(token)
|
||||||
|
@ -90,12 +121,14 @@ def llama_generate(
|
||||||
sys.stdout.write(cur[len(outputted) :])
|
sys.stdout.write(cur[len(outputted) :])
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
outputted = cur
|
outputted = cur
|
||||||
if toks[-1] == IM_END: break
|
if toks[-1] == IM_END:
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
toks.append(IM_END)
|
toks.append(IM_END)
|
||||||
print() # because the output is flushed
|
print() # because the output is flushed
|
||||||
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
||||||
|
|
||||||
|
|
||||||
def tts(
|
def tts(
|
||||||
text_to_synthesize: str,
|
text_to_synthesize: str,
|
||||||
synth: Synthesizer,
|
synth: Synthesizer,
|
||||||
|
@ -110,24 +143,45 @@ def tts(
|
||||||
text_mapper: TextMapper,
|
text_mapper: TextMapper,
|
||||||
model_has_multiple_speakers: bool,
|
model_has_multiple_speakers: bool,
|
||||||
batch_size=600,
|
batch_size=600,
|
||||||
vits_batch_size=1000
|
vits_batch_size=1000,
|
||||||
):
|
):
|
||||||
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
if model_to_use == "mmts-tts":
|
||||||
|
text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||||
|
|
||||||
# Convert the input text to a tensor.
|
# Convert the input text to a tensor.
|
||||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
stn_tst = text_mapper.get_text(
|
||||||
|
text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners
|
||||||
|
)
|
||||||
init_shape = stn_tst.shape
|
init_shape = stn_tst.shape
|
||||||
assert init_shape[0] < batch_size, "text is too long"
|
assert init_shape[0] < batch_size, "text is too long"
|
||||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(
|
||||||
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
0
|
||||||
|
), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||||
|
sid = (
|
||||||
|
Tensor([speaker_id], dtype=dtypes.int64)
|
||||||
|
if model_has_multiple_speakers
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Perform inference.
|
# Perform inference.
|
||||||
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
|
audio_tensor = synth.infer(
|
||||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
|
x_tst,
|
||||||
|
x_tst_lengths,
|
||||||
|
sid,
|
||||||
|
noise_scale,
|
||||||
|
length_scale,
|
||||||
|
noise_scale_w,
|
||||||
|
emotion_embedding=emotion_embedding,
|
||||||
|
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use]
|
||||||
|
if estimate_max_y_length
|
||||||
|
else None,
|
||||||
|
batch_size=vits_batch_size,
|
||||||
|
)[0, 0]
|
||||||
# Save the audio output.
|
# Save the audio output.
|
||||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|
||||||
def init_vits(
|
def init_vits(
|
||||||
model_to_use: str,
|
model_to_use: str,
|
||||||
emotion_path: Path,
|
emotion_path: Path,
|
||||||
|
@ -142,21 +196,44 @@ def init_vits(
|
||||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||||
if model_has_multiple_speakers:
|
if model_has_multiple_speakers:
|
||||||
if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
if speaker_id >= hps.data.n_speakers:
|
||||||
|
raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||||
speakers = hps.speakers
|
speakers = hps.speakers
|
||||||
if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
if isinstance(speakers, list):
|
||||||
|
speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||||
|
|
||||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||||
emotion_embedding = None
|
emotion_embedding = None
|
||||||
if emotion_path is not None:
|
if emotion_path is not None:
|
||||||
if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
|
if emotion_path.endswith(".npy"):
|
||||||
else: raise ValueError("Emotion path must be a .npy file.")
|
emotion_embedding = Tensor(
|
||||||
|
np.load(emotion_path), dtype=dtypes.int64
|
||||||
|
).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
raise ValueError("Emotion path must be a .npy file.")
|
||||||
|
|
||||||
# Load symbols, instantiate TextMapper and clean the text.
|
# Load symbols, instantiate TextMapper and clean the text.
|
||||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
if hps.__contains__("symbols"):
|
||||||
elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
symbols = hps.symbols
|
||||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
elif model_to_use == "mmts-tts":
|
||||||
|
symbols = [
|
||||||
|
x.replace("\n", "")
|
||||||
|
for x in fetch(
|
||||||
|
"https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"
|
||||||
|
)
|
||||||
|
.open(encoding="utf-8")
|
||||||
|
.readlines()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
symbols = (
|
||||||
|
["_"]
|
||||||
|
+ list(';:,.!?¡¿—…"«»“” ')
|
||||||
|
+ list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
|
||||||
|
+ list(
|
||||||
|
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||||
|
)
|
||||||
|
)
|
||||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||||
|
|
||||||
# Load the model.
|
# Load the model.
|
||||||
|
@ -168,18 +245,23 @@ def init_vits(
|
||||||
|
|
||||||
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def output_stream(num_channels: int, sample_rate: int):
|
def output_stream(num_channels: int, sample_rate: int):
|
||||||
try:
|
try:
|
||||||
p = pyaudio.PyAudio()
|
p = pyaudio.PyAudio()
|
||||||
stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
|
stream = p.open(
|
||||||
|
format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True
|
||||||
|
)
|
||||||
yield stream
|
yield stream
|
||||||
except KeyboardInterrupt: pass
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
stream.stop_stream()
|
stream.stop_stream()
|
||||||
stream.close()
|
stream.close()
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def log_writer():
|
def log_writer():
|
||||||
try:
|
try:
|
||||||
|
@ -191,10 +273,17 @@ def log_writer():
|
||||||
print(*logs, sep="\n")
|
print(*logs, sep="\n")
|
||||||
print(sep)
|
print(sep)
|
||||||
|
|
||||||
|
|
||||||
def listener(q: mp.Queue, event: mp.Event):
|
def listener(q: mp.Queue, event: mp.Event):
|
||||||
try:
|
try:
|
||||||
p = pyaudio.PyAudio()
|
p = pyaudio.PyAudio()
|
||||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
stream = p.open(
|
||||||
|
format=pyaudio.paInt16,
|
||||||
|
channels=1,
|
||||||
|
rate=RATE,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=CHUNK,
|
||||||
|
)
|
||||||
did_print = False
|
did_print = False
|
||||||
while True:
|
while True:
|
||||||
data = stream.read(CHUNK) # read data to avoid overflow
|
data = stream.read(CHUNK) # read data to avoid overflow
|
||||||
|
@ -210,7 +299,10 @@ def listener(q: mp.Queue, event: mp.Event):
|
||||||
stream.close()
|
stream.close()
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
|
|
||||||
|
def mp_output_stream(
|
||||||
|
q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int
|
||||||
|
):
|
||||||
with output_stream(num_channels, sample_rate) as stream:
|
with output_stream(num_channels, sample_rate) as stream:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -219,8 +311,10 @@ def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_r
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import nltk
|
import nltk
|
||||||
|
|
||||||
nltk.download("punkt")
|
nltk.download("punkt")
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
# Parse CLI arguments
|
# Parse CLI arguments
|
||||||
|
@ -230,75 +324,212 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
||||||
|
|
||||||
# LLAMA args
|
# LLAMA args
|
||||||
parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
|
parser.add_argument(
|
||||||
parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
|
"--llama_pre_prompt_path",
|
||||||
parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
|
type=Path,
|
||||||
parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml",
|
||||||
parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
help="Path to yaml file which contains all pre-prompt data needed. ",
|
||||||
parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
|
)
|
||||||
parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
|
parser.add_argument(
|
||||||
parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
|
"--llama_count", type=int, default=1000, help="Max number of tokens to generate"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
help="Temperature in the softmax",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_quantize",
|
||||||
|
action="store_true",
|
||||||
|
help="Quantize the weights to int8 in memory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_model",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_gen",
|
||||||
|
type=str,
|
||||||
|
default="tiny",
|
||||||
|
required=False,
|
||||||
|
help="Generation of the model to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_size",
|
||||||
|
type=str,
|
||||||
|
default="1B-Chat",
|
||||||
|
required=False,
|
||||||
|
help="Size of model to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llama_tokenizer",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="Path to llama tokenizer.model",
|
||||||
|
)
|
||||||
|
|
||||||
# vits args
|
# vits args
|
||||||
parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
|
parser.add_argument(
|
||||||
parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
|
"--vits_model_to_use",
|
||||||
parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
|
default="vctk",
|
||||||
parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
|
help="Specify the model to use. Default is 'vctk'.",
|
||||||
parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
|
)
|
||||||
parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
|
parser.add_argument(
|
||||||
parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
|
"--vits_speaker_id",
|
||||||
parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
|
type=int,
|
||||||
parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
|
default=12,
|
||||||
parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
|
help="Specify the speaker ID. Default is 6.",
|
||||||
parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_noise_scale",
|
||||||
|
type=float,
|
||||||
|
default=0.667,
|
||||||
|
help="Specify the noise scale. Default is 0.667.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_noise_scale_w",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="Specify the noise scale w. Default is 0.8.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_length_scale",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="Specify the length scale. Default is 1.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_seed",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Specify the seed (set to None if no seed). Default is 1337.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_num_channels",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Specify the number of audio output channels. Default is 1.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_sample_width",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Specify the number of bytes per sample, adjust if necessary. Default is 2.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_emotion_path",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="Specify the path to emotion reference.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_estimate_max_y_length",
|
||||||
|
type=str,
|
||||||
|
default=False,
|
||||||
|
help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary."
|
||||||
|
)
|
||||||
|
|
||||||
# conversation args
|
# conversation args
|
||||||
parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
|
parser.add_argument(
|
||||||
|
"--max_sentence_length",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Max words in one sentence to pass to vits",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Init models
|
# Init models
|
||||||
model, enc = init_whisper(args.whisper_model_name)
|
model, enc = init_whisper(args.whisper_model_name)
|
||||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
|
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(
|
||||||
|
args.vits_model_to_use,
|
||||||
|
args.vits_emotion_path,
|
||||||
|
args.vits_speaker_id,
|
||||||
|
args.vits_seed,
|
||||||
|
)
|
||||||
|
|
||||||
# Download tinyllama chat as a default model
|
# Download tinyllama chat as a default model
|
||||||
if args.llama_model is None:
|
if args.llama_model is None:
|
||||||
args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
|
args.llama_model = fetch(
|
||||||
|
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors",
|
||||||
|
"tinyllamachat.safetensors",
|
||||||
|
)
|
||||||
args.llama_gen = "tiny"
|
args.llama_gen = "tiny"
|
||||||
args.llama_size = "1B-Chat"
|
args.llama_size = "1B-Chat"
|
||||||
# Add 3 more tokens to the tokenizer
|
# Add 3 more tokens to the tokenizer
|
||||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
|
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"):
|
||||||
|
args.llama_tokenizer = create_fixed_tokenizer()
|
||||||
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
||||||
llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
|
llama = LLaMa.build(
|
||||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
|
args.llama_model,
|
||||||
|
tokenizer_path,
|
||||||
|
args.llama_gen,
|
||||||
|
args.llama_size,
|
||||||
|
args.llama_quantize,
|
||||||
|
)
|
||||||
|
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(
|
||||||
|
llama, args.llama_temperature, args.llama_pre_prompt_path
|
||||||
|
)
|
||||||
|
|
||||||
# Start child process for mic input
|
# Start child process for mic input
|
||||||
q = mp.Queue()
|
q = mp.Queue()
|
||||||
is_listening_event = mp.Event()
|
is_listening_event = mp.Event()
|
||||||
p = mp.Process(target=listener, args=(q, is_listening_event,))
|
p = mp.Process(
|
||||||
|
target=listener,
|
||||||
|
args=(
|
||||||
|
q,
|
||||||
|
is_listening_event,
|
||||||
|
),
|
||||||
|
)
|
||||||
p.daemon = True
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
# Start child process for speaker output
|
# Start child process for speaker output
|
||||||
out_q = mp.Queue()
|
out_q = mp.Queue()
|
||||||
out_counter = mp.Value("i", 0)
|
out_counter = mp.Value("i", 0)
|
||||||
out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
|
out_p = mp.Process(
|
||||||
|
target=mp_output_stream,
|
||||||
|
args=(
|
||||||
|
out_q,
|
||||||
|
out_counter,
|
||||||
|
args.vits_num_channels,
|
||||||
|
hps.data.sampling_rate,
|
||||||
|
),
|
||||||
|
)
|
||||||
out_p.daemon = True
|
out_p.daemon = True
|
||||||
out_p.start()
|
out_p.start()
|
||||||
|
|
||||||
# JIT tts
|
# JIT tts
|
||||||
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
||||||
tts(
|
tts(
|
||||||
i, synth, hps, emotion_embedding,
|
i,
|
||||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
synth,
|
||||||
args.vits_noise_scale_w, args.vits_length_scale,
|
hps,
|
||||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
emotion_embedding,
|
||||||
|
args.vits_speaker_id,
|
||||||
|
args.vits_model_to_use,
|
||||||
|
args.vits_noise_scale,
|
||||||
|
args.vits_noise_scale_w,
|
||||||
|
args.vits_length_scale,
|
||||||
|
args.vits_estimate_max_y_length,
|
||||||
|
text_mapper,
|
||||||
|
model_has_multiple_speakers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start the pipeline
|
# Start the pipeline
|
||||||
with log_writer() as log:
|
with log_writer() as log:
|
||||||
while True:
|
while True:
|
||||||
tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
tokens = [
|
||||||
|
enc._special_tokens["<|startoftranscript|>"],
|
||||||
|
enc._special_tokens["<|notimestamps|>"],
|
||||||
|
]
|
||||||
total = np.array([])
|
total = np.array([])
|
||||||
out_counter.value = 0
|
out_counter.value = 0
|
||||||
|
|
||||||
|
@ -306,10 +537,12 @@ if __name__ == "__main__":
|
||||||
is_listening_event.set()
|
is_listening_event.set()
|
||||||
prev_text = None
|
prev_text = None
|
||||||
while True:
|
while True:
|
||||||
for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
|
for _ in range(RATE // CHUNK):
|
||||||
|
total = np.concatenate([total, q.get()])
|
||||||
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
||||||
print(txt, end="\r")
|
print(txt, end="\r")
|
||||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
|
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()):
|
||||||
|
continue
|
||||||
if prev_text is not None and prev_text == txt:
|
if prev_text is not None and prev_text == txt:
|
||||||
is_listening_event.clear()
|
is_listening_event.clear()
|
||||||
break
|
break
|
||||||
|
@ -320,9 +553,15 @@ if __name__ == "__main__":
|
||||||
# Generate with llama
|
# Generate with llama
|
||||||
with Timing("llama generation: "):
|
with Timing("llama generation: "):
|
||||||
outputted, start_pos, response = llama_generate(
|
outputted, start_pos, response = llama_generate(
|
||||||
llama, toks, outputted, txt, start_pos,
|
llama,
|
||||||
user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
|
toks,
|
||||||
max_tokens=args.llama_count
|
outputted,
|
||||||
|
txt,
|
||||||
|
start_pos,
|
||||||
|
user_delim=user_delim,
|
||||||
|
resp_delim=resp_delim,
|
||||||
|
temperature=args.llama_temperature,
|
||||||
|
max_tokens=args.llama_count,
|
||||||
)
|
)
|
||||||
log.append(f"{resp_delim.capitalize()}: {response}")
|
log.append(f"{resp_delim.capitalize()}: {response}")
|
||||||
|
|
||||||
|
@ -333,12 +572,21 @@ if __name__ == "__main__":
|
||||||
total = np.array([], dtype=np.int16)
|
total = np.array([], dtype=np.int16)
|
||||||
for j in chunks(i.split(), args.max_sentence_length):
|
for j in chunks(i.split(), args.max_sentence_length):
|
||||||
audio_data = tts(
|
audio_data = tts(
|
||||||
" ".join(j), synth, hps, emotion_embedding,
|
" ".join(j),
|
||||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
synth,
|
||||||
args.vits_noise_scale_w, args.vits_length_scale,
|
hps,
|
||||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
emotion_embedding,
|
||||||
|
args.vits_speaker_id,
|
||||||
|
args.vits_model_to_use,
|
||||||
|
args.vits_noise_scale,
|
||||||
|
args.vits_noise_scale_w,
|
||||||
|
args.vits_length_scale,
|
||||||
|
args.vits_estimate_max_y_length,
|
||||||
|
text_mapper,
|
||||||
|
model_has_multiple_speakers,
|
||||||
)
|
)
|
||||||
total = np.concatenate([total, audio_data])
|
total = np.concatenate([total, audio_data])
|
||||||
out_q.put(total.tobytes())
|
out_q.put(total.tobytes())
|
||||||
while out_counter.value < len(sentences): continue
|
while out_counter.value < len(sentences):
|
||||||
|
continue
|
||||||
log.append(f"Total: {time.perf_counter() - s}")
|
log.append(f"Total: {time.perf_counter() - s}")
|
||||||
|
|
|
@ -11,12 +11,14 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import getenv, fetch, Timing
|
from tinygrad.helpers import getenv, fetch, Timing
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
from extra.models.efficientnet import EfficientNet
|
from extra.models.efficientnet import EfficientNet
|
||||||
|
|
||||||
np.set_printoptions(suppress=True)
|
np.set_printoptions(suppress=True)
|
||||||
|
|
||||||
# TODO: you should be able to put these in the jitted function
|
# TODO: you should be able to put these in the jitted function
|
||||||
bias = Tensor([0.485, 0.456, 0.406])
|
bias = Tensor([0.485, 0.456, 0.406])
|
||||||
scale = Tensor([0.229, 0.224, 0.225])
|
scale = Tensor([0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def _infer(model, img):
|
def _infer(model, img):
|
||||||
img = img.permute((2, 0, 1))
|
img = img.permute((2, 0, 1))
|
||||||
|
@ -25,10 +27,13 @@ def _infer(model, img):
|
||||||
img = img / scale.reshape((1, -1, 1, 1))
|
img = img / scale.reshape((1, -1, 1, 1))
|
||||||
return model.forward(img).realize()
|
return model.forward(img).realize()
|
||||||
|
|
||||||
|
|
||||||
def infer(model, img):
|
def infer(model, img):
|
||||||
# preprocess image
|
# preprocess image
|
||||||
aspect_ratio = img.size[0] / img.size[1]
|
aspect_ratio = img.size[0] / img.size[1]
|
||||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
img = img.resize(
|
||||||
|
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||||
|
)
|
||||||
|
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||||
|
@ -52,18 +57,28 @@ def infer(model, img):
|
||||||
"""
|
"""
|
||||||
return out, retimg
|
return out, retimg
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# instantiate my net
|
# instantiate my net
|
||||||
model = EfficientNet(getenv("NUM", 0))
|
model = EfficientNet(getenv("NUM", 0))
|
||||||
model.load_from_pretrained()
|
model.load_from_pretrained()
|
||||||
|
|
||||||
# category labels
|
# category labels
|
||||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
lbls = ast.literal_eval(
|
||||||
|
fetch(
|
||||||
|
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||||
|
).read_text()
|
||||||
|
)
|
||||||
|
|
||||||
# load image and preprocess
|
# load image and preprocess
|
||||||
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
url = (
|
||||||
if url == 'webcam':
|
sys.argv[1]
|
||||||
|
if len(sys.argv) >= 2
|
||||||
|
else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||||
|
)
|
||||||
|
if url == "webcam":
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||||
while 1:
|
while 1:
|
||||||
|
@ -72,12 +87,17 @@ if __name__ == "__main__":
|
||||||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||||
lt = time.monotonic_ns()
|
lt = time.monotonic_ns()
|
||||||
out, retimg = infer(model, img)
|
out, retimg = infer(model, img)
|
||||||
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
print(
|
||||||
|
f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms",
|
||||||
|
np.argmax(out),
|
||||||
|
np.max(out),
|
||||||
|
lbls[np.argmax(out)],
|
||||||
|
)
|
||||||
SCALE = 3
|
SCALE = 3
|
||||||
simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
|
simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
|
||||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||||
cv2.imshow('capture', retimg)
|
cv2.imshow("capture", retimg)
|
||||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
break
|
break
|
||||||
cap.release()
|
cap.release()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
|
|
@ -3,6 +3,7 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
|
|
||||||
|
|
||||||
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
|
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
|
||||||
def bit_extract(x, s, e) -> Tensor:
|
def bit_extract(x, s, e) -> Tensor:
|
||||||
# extract the top bits we don't want
|
# extract the top bits we don't want
|
||||||
|
@ -10,11 +11,16 @@ def bit_extract(x, s, e) -> Tensor:
|
||||||
x = (x - top_bits) / (1 << e)
|
x = (x - top_bits) / (1 << e)
|
||||||
return x.contiguous()
|
return x.contiguous()
|
||||||
|
|
||||||
|
|
||||||
def u16_to_f16(x):
|
def u16_to_f16(x):
|
||||||
sign = bit_extract(x, 15, 15).float()
|
sign = bit_extract(x, 15, 15).float()
|
||||||
exponent = bit_extract(x, 14, 10).float()
|
exponent = bit_extract(x, 14, 10).float()
|
||||||
fraction = bit_extract(x, 9, 0).float()
|
fraction = bit_extract(x, 9, 0).float()
|
||||||
return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
|
return sign.where(-1, 1) * exponent.where(
|
||||||
|
(exponent - 15).exp2() * (1 + fraction / 0x400),
|
||||||
|
6.103515625e-5 * (fraction / 0x400),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def u32_to_f16(oo):
|
def u32_to_f16(oo):
|
||||||
oo1 = (oo / 0x10000).floor().contiguous()
|
oo1 = (oo / 0x10000).floor().contiguous()
|
||||||
|
@ -24,6 +30,7 @@ def u32_to_f16(oo):
|
||||||
f2 = u16_to_f16(oo2)
|
f2 = u16_to_f16(oo2)
|
||||||
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
|
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# random float16
|
# random float16
|
||||||
Tensor.manual_seed(2)
|
Tensor.manual_seed(2)
|
||||||
|
|
222
examples/gpt2.py
222
examples/gpt2.py
|
@ -10,11 +10,20 @@ from tinygrad.shape.symbolic import Variable
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||||
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
|
from tinygrad.helpers import (
|
||||||
|
GlobalCounters,
|
||||||
|
Timing,
|
||||||
|
DEBUG,
|
||||||
|
getenv,
|
||||||
|
fetch,
|
||||||
|
colored,
|
||||||
|
dtypes,
|
||||||
|
)
|
||||||
|
|
||||||
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
|
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
|
||||||
HALF = getenv("HALF")
|
HALF = getenv("HALF")
|
||||||
|
|
||||||
|
|
||||||
class Attention:
|
class Attention:
|
||||||
def __init__(self, dim, n_heads):
|
def __init__(self, dim, n_heads):
|
||||||
self.c_attn = Linear(dim, 3 * dim, bias=True)
|
self.c_attn = Linear(dim, 3 * dim, bias=True)
|
||||||
|
@ -23,18 +32,27 @@ class Attention:
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
|
|
||||||
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
|
def __call__(
|
||||||
|
self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]
|
||||||
|
) -> Tensor:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# no symbolic shape qkv when consuming prompts
|
# no symbolic shape qkv when consuming prompts
|
||||||
start_pos = start_pos.val
|
start_pos = start_pos.val
|
||||||
|
|
||||||
xqkv = self.c_attn(x)
|
xqkv = self.c_attn(x)
|
||||||
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
|
xq, xk, xv = [
|
||||||
|
xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape(
|
||||||
|
xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||||
|
|
||||||
# create kv cache
|
# create kv cache
|
||||||
if not hasattr(self, "cache_k"):
|
if not hasattr(self, "cache_k"):
|
||||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
self.cache_k, self.cache_v = Tensor.zeros(
|
||||||
|
bsz, MAX_CONTEXT, self.n_heads, self.head_dim
|
||||||
|
), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||||
if HALF:
|
if HALF:
|
||||||
self.cache_k = self.cache_k.half()
|
self.cache_k = self.cache_k.half()
|
||||||
self.cache_v = self.cache_v.half()
|
self.cache_v = self.cache_v.half()
|
||||||
|
@ -43,11 +61,28 @@ class Attention:
|
||||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||||
|
|
||||||
# update the cache
|
# update the cache
|
||||||
self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
self.cache_k.assign(
|
||||||
self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
keys.pad(
|
||||||
|
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||||
|
).contiguous()
|
||||||
|
).realize()
|
||||||
|
self.cache_v.assign(
|
||||||
|
values.pad(
|
||||||
|
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||||
|
).contiguous()
|
||||||
|
).realize()
|
||||||
|
|
||||||
|
xq, keys, values = (
|
||||||
|
xq.transpose(1, 2),
|
||||||
|
keys.transpose(1, 2),
|
||||||
|
values.transpose(1, 2),
|
||||||
|
)
|
||||||
|
return self.c_proj(
|
||||||
|
xq.scaled_dot_product_attention(keys, values, mask)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(bsz, seqlen, -1)
|
||||||
|
)
|
||||||
|
|
||||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
|
||||||
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
|
|
||||||
|
|
||||||
class FeedForward:
|
class FeedForward:
|
||||||
def __init__(self, dim, hidden_dim):
|
def __init__(self, dim, hidden_dim):
|
||||||
|
@ -57,6 +92,7 @@ class FeedForward:
|
||||||
def __call__(self, x: Tensor) -> Tensor:
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
return self.c_proj(self.c_fc(x).gelu())
|
return self.c_proj(self.c_fc(x).gelu())
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock:
|
class TransformerBlock:
|
||||||
def __init__(self, dim, n_heads, norm_eps):
|
def __init__(self, dim, n_heads, norm_eps):
|
||||||
self.attn = Attention(dim, n_heads)
|
self.attn = Attention(dim, n_heads)
|
||||||
|
@ -66,7 +102,8 @@ class TransformerBlock:
|
||||||
|
|
||||||
def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
|
def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
|
||||||
h = x + self.attn(self.ln_1(x), start_pos, mask)
|
h = x + self.attn(self.ln_1(x), start_pos, mask)
|
||||||
return (h + self.mlp(self.ln_2(h)))
|
return h + self.mlp(self.ln_2(h))
|
||||||
|
|
||||||
|
|
||||||
class Transformer:
|
class Transformer:
|
||||||
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
||||||
|
@ -78,7 +115,8 @@ class Transformer:
|
||||||
self.forward_jit = TinyJit(self.forward)
|
self.forward_jit = TinyJit(self.forward)
|
||||||
|
|
||||||
def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
|
def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
|
||||||
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
if not hasattr(self, "allpos"):
|
||||||
|
self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
||||||
_bsz, seqlen = tokens.shape
|
_bsz, seqlen = tokens.shape
|
||||||
|
|
||||||
# NOTE: cannot convert token indices into half due to precision
|
# NOTE: cannot convert token indices into half due to precision
|
||||||
|
@ -86,44 +124,73 @@ class Transformer:
|
||||||
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
|
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
|
||||||
h = tok_emb + pos_emb
|
h = tok_emb + pos_emb
|
||||||
|
|
||||||
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
|
mask = (
|
||||||
|
Tensor.full((1, 1, seqlen, start_pos.val + seqlen), float("-inf"))
|
||||||
|
.triu(start_pos.val + 1)
|
||||||
|
.realize()
|
||||||
|
if seqlen > 1
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if HALF:
|
if HALF:
|
||||||
h = h.half()
|
h = h.half()
|
||||||
if mask is not None: mask = mask.half()
|
if mask is not None:
|
||||||
|
mask = mask.half()
|
||||||
|
|
||||||
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
|
for hi in self.h:
|
||||||
|
h = hi(h, start_pos=start_pos, mask=mask)
|
||||||
|
|
||||||
logits = self.lm_head(self.ln_f(h))
|
logits = self.lm_head(self.ln_f(h))
|
||||||
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
||||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize()
|
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize()
|
||||||
|
|
||||||
# TODO: fix empty token
|
# TODO: fix empty token
|
||||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
|
def __call__(
|
||||||
return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
|
self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
|
||||||
|
) -> Tensor:
|
||||||
|
return (
|
||||||
|
self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward
|
||||||
|
)(tokens, start_pos, temperature)
|
||||||
|
|
||||||
|
|
||||||
VOCAB_SIZE = 50257
|
VOCAB_SIZE = 50257
|
||||||
MODEL_PARAMS = {
|
MODEL_PARAMS = {
|
||||||
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
|
"gpt2": dict(
|
||||||
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
|
n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||||
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
|
), # 124M params
|
||||||
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
|
"gpt2-medium": dict(
|
||||||
|
n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||||
|
), # 350M params
|
||||||
|
"gpt2-large": dict(
|
||||||
|
n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||||
|
), # 774M params
|
||||||
|
"gpt2-xl": dict(
|
||||||
|
n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||||
|
), # 1558M params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GPT2:
|
class GPT2:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build(model_size="gpt2"):
|
def build(model_size="gpt2"):
|
||||||
tokenizer = tiktoken.get_encoding("gpt2")
|
tokenizer = tiktoken.get_encoding("gpt2")
|
||||||
|
|
||||||
model = Transformer(**MODEL_PARAMS[model_size])
|
model = Transformer(**MODEL_PARAMS[model_size])
|
||||||
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
|
weights = torch_load(
|
||||||
|
fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin")
|
||||||
|
)
|
||||||
# special treatment for the Conv1D weights we need to transpose
|
# special treatment for the Conv1D weights we need to transpose
|
||||||
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
transposed = [
|
||||||
|
"attn.c_attn.weight",
|
||||||
|
"attn.c_proj.weight",
|
||||||
|
"mlp.c_fc.weight",
|
||||||
|
"mlp.c_proj.weight",
|
||||||
|
]
|
||||||
for k in weights.keys():
|
for k in weights.keys():
|
||||||
if any(k.endswith(w) for w in transposed):
|
if any(k.endswith(w) for w in transposed):
|
||||||
weights[k] = Tensor(weights[k].numpy().T)
|
weights[k] = Tensor(weights[k].numpy().T)
|
||||||
# lm head and wte are tied
|
# lm head and wte are tied
|
||||||
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())
|
weights["lm_head.weight"] = Tensor(weights["wte.weight"].numpy())
|
||||||
|
|
||||||
load_state_dict(model, weights)
|
load_state_dict(model, weights)
|
||||||
return GPT2(model, tokenizer)
|
return GPT2(model, tokenizer)
|
||||||
|
@ -132,42 +199,98 @@ class GPT2:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
|
def greedy_until(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_length: int,
|
||||||
|
temperature: float,
|
||||||
|
timing: bool = False,
|
||||||
|
batch_size: int = 1,
|
||||||
|
):
|
||||||
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
||||||
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
for _ in trange(max_length, disable=(timing == True)):
|
for _ in trange(max_length, disable=(timing == True)):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
if timing: print("")
|
if timing:
|
||||||
|
print("")
|
||||||
st = GlobalCounters.time_sum_s
|
st = GlobalCounters.time_sum_s
|
||||||
with Timing("total ", enabled=timing):
|
with Timing("total ", enabled=timing):
|
||||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
with Timing(
|
||||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
"ran model in ",
|
||||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
|
on_exit=(
|
||||||
probs = self.model(Tensor([x[start_pos:] for x in toks]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
|
lambda et: (
|
||||||
|
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||||
|
if DEBUG >= 2
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||||
|
+ (
|
||||||
|
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||||
|
if DEBUG >= 2
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if DEBUG
|
||||||
|
else None,
|
||||||
|
enabled=timing,
|
||||||
|
):
|
||||||
|
probs = self.model(
|
||||||
|
Tensor([x[start_pos:] for x in toks]),
|
||||||
|
Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(
|
||||||
|
start_pos
|
||||||
|
),
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
# TODO: fix JIT rand so we can put this in the JIT
|
# TODO: fix JIT rand so we can put this in the JIT
|
||||||
tok = probs.multinomial().flatten().numpy().tolist()
|
tok = probs.multinomial().flatten().numpy().tolist()
|
||||||
start_pos = len(toks[0])
|
start_pos = len(toks[0])
|
||||||
for i,t in enumerate(tok): toks[i].append(t)
|
for i, t in enumerate(tok):
|
||||||
|
toks[i].append(t)
|
||||||
output = [self.tokenizer.decode(x) for x in toks]
|
output = [self.tokenizer.decode(x) for x in toks]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# **** main code ****
|
# **** main code ****
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
print(f"using {Device.DEFAULT} backend")
|
print(f"using {Device.DEFAULT} backend")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with")
|
description="Run GPT2 in tinygrad",
|
||||||
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
|
)
|
||||||
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
|
parser.add_argument(
|
||||||
parser.add_argument('--timing', action='store_true', help="Print timing per token")
|
"--prompt",
|
||||||
parser.add_argument('--seed', type=int, help="Set the random seed")
|
type=str,
|
||||||
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
|
default="What is the answer to life, the universe, and everything?",
|
||||||
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
|
help="Phrase to start with",
|
||||||
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--count", type=int, default=100, help="Max number of tokens to generate"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", type=float, default=0.8, help="Temperature in the softmax"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_size",
|
||||||
|
type=str,
|
||||||
|
default="gpt2-medium",
|
||||||
|
help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]",
|
||||||
|
)
|
||||||
|
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||||
|
parser.add_argument("--seed", type=int, help="Set the random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", type=int, default=1, help="Set the input batch size"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--benchmark",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="Benchmark GPT with the given number of tokens",
|
||||||
|
)
|
||||||
|
parser.add_argument("--noshow", action="store_true", help="Don't show the output")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
|
@ -182,11 +305,22 @@ if __name__ == "__main__":
|
||||||
l.assign(l.cast(dtypes.float16).realize())
|
l.assign(l.cast(dtypes.float16).realize())
|
||||||
|
|
||||||
if args.benchmark != -1:
|
if args.benchmark != -1:
|
||||||
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
|
gpt2.model(
|
||||||
|
Tensor.rand(args.batch_size, args.benchmark),
|
||||||
|
Variable("a", 0, MAX_CONTEXT).bind(0),
|
||||||
|
).realize()
|
||||||
else:
|
else:
|
||||||
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
|
texts = gpt2.greedy_until(
|
||||||
|
args.prompt,
|
||||||
|
args.count,
|
||||||
|
args.temperature,
|
||||||
|
timing=args.timing,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
if not args.noshow:
|
if not args.noshow:
|
||||||
print('Generating text...')
|
print("Generating text...")
|
||||||
if len(texts) == 1: print(texts[0])
|
if len(texts) == 1:
|
||||||
|
print(texts[0])
|
||||||
else:
|
else:
|
||||||
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
|
for i, text in enumerate(texts):
|
||||||
|
print(colored(f"Response {i}:", "green"), text)
|
||||||
|
|
|
@ -28,7 +28,8 @@ if __name__ == "__main__":
|
||||||
sched = [x for x in sched if x.ast.op not in LoadOps]
|
sched = [x for x in sched if x.ast.op not in LoadOps]
|
||||||
|
|
||||||
# focus on one kernel
|
# focus on one kernel
|
||||||
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
|
if getenv("KERNEL", -1) >= 0:
|
||||||
|
sched = sched[getenv("KERNEL", -1) : getenv("KERNEL", -1) + 1]
|
||||||
|
|
||||||
# work with the schedule
|
# work with the schedule
|
||||||
total_tm = 0
|
total_tm = 0
|
||||||
|
@ -52,20 +53,33 @@ if __name__ == "__main__":
|
||||||
# try a beam search
|
# try a beam search
|
||||||
if getenv("BEAM"):
|
if getenv("BEAM"):
|
||||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||||
lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)))
|
lin = beam_search(
|
||||||
|
lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1))
|
||||||
|
)
|
||||||
lins.append(lin)
|
lins.append(lin)
|
||||||
|
|
||||||
# benchmark the programs
|
# benchmark the programs
|
||||||
choices = []
|
choices = []
|
||||||
for lin in lins:
|
for lin in lins:
|
||||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
||||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
|
gflops = (
|
||||||
|
sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})
|
||||||
|
* 1e-9
|
||||||
|
/ tm
|
||||||
|
)
|
||||||
choices.append((tm, gflops, lin.linearize()))
|
choices.append((tm, gflops, lin.linearize()))
|
||||||
|
|
||||||
# print all kernels
|
# print all kernels
|
||||||
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
if DEBUG >= 1:
|
||||||
|
print(
|
||||||
|
f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||||
|
)
|
||||||
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
|
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
|
||||||
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
print(
|
||||||
|
f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||||
|
)
|
||||||
total_tm += tm
|
total_tm += tm
|
||||||
running_gflops += gflops * tm
|
running_gflops += gflops * tm
|
||||||
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
|
print(
|
||||||
|
f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS"
|
||||||
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# setup for distributed
|
# setup for distributed
|
||||||
from extra import dist
|
from extra import dist
|
||||||
from tinygrad.helpers import getenv, dtypes
|
from tinygrad.helpers import getenv, dtypes
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if getenv("DIST"):
|
if getenv("DIST"):
|
||||||
dist.preinit()
|
dist.preinit()
|
||||||
|
@ -24,7 +25,7 @@ from tinygrad.shape.symbolic import Node
|
||||||
from extra.lr_scheduler import OneCycleLR
|
from extra.lr_scheduler import OneCycleLR
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
|
|
||||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
|
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv("EVAL_BS", 500), getenv("STEPS", 1000)
|
||||||
|
|
||||||
if getenv("HALF", 0):
|
if getenv("HALF", 0):
|
||||||
Tensor.default_type = dtypes.float16
|
Tensor.default_type = dtypes.float16
|
||||||
|
@ -33,16 +34,28 @@ else:
|
||||||
Tensor.default_type = dtypes.float32
|
Tensor.default_type = dtypes.float32
|
||||||
np_dtype = np.float32
|
np_dtype = np.float32
|
||||||
|
|
||||||
|
|
||||||
class BatchNorm(nn.BatchNorm2d):
|
class BatchNorm(nn.BatchNorm2d):
|
||||||
def __init__(self, num_features):
|
def __init__(self, num_features):
|
||||||
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
|
super().__init__(
|
||||||
|
num_features,
|
||||||
|
track_running_stats=False,
|
||||||
|
eps=1e-12,
|
||||||
|
momentum=0.85,
|
||||||
|
affine=True,
|
||||||
|
)
|
||||||
self.weight.requires_grad = False
|
self.weight.requires_grad = False
|
||||||
self.bias.requires_grad = True
|
self.bias.requires_grad = True
|
||||||
|
|
||||||
|
|
||||||
class ConvGroup:
|
class ConvGroup:
|
||||||
def __init__(self, channels_in, channels_out):
|
def __init__(self, channels_in, channels_out):
|
||||||
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
|
self.conv1 = nn.Conv2d(
|
||||||
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
|
channels_in, channels_out, kernel_size=3, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.conv2 = nn.Conv2d(
|
||||||
|
channels_out, channels_out, kernel_size=3, padding=1, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
self.norm1 = BatchNorm(channels_out)
|
self.norm1 = BatchNorm(channels_out)
|
||||||
self.norm2 = BatchNorm(channels_out)
|
self.norm2 = BatchNorm(channels_out)
|
||||||
|
@ -63,6 +76,7 @@ class ConvGroup:
|
||||||
|
|
||||||
return x + residual
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
class SpeedyResNet:
|
class SpeedyResNet:
|
||||||
def __init__(self, W):
|
def __init__(self, W):
|
||||||
self.whitening = W
|
self.whitening = W
|
||||||
|
@ -74,54 +88,58 @@ class SpeedyResNet:
|
||||||
ConvGroup(256, 512),
|
ConvGroup(256, 512),
|
||||||
lambda x: x.max((2, 3)),
|
lambda x: x.max((2, 3)),
|
||||||
nn.Linear(512, 10, bias=False),
|
nn.Linear(512, 10, bias=False),
|
||||||
lambda x: x.mul(1./9)
|
lambda x: x.mul(1.0 / 9),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x, training=True):
|
def __call__(self, x, training=True):
|
||||||
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
|
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
|
||||||
# TODO: remove the pad but instead let the kernel optimizer itself
|
# TODO: remove the pad but instead let the kernel optimizer itself
|
||||||
forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
|
forward = (
|
||||||
return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
|
lambda x: x.conv2d(self.whitening).pad2d((1, 0, 0, 1)).sequential(self.net)
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
forward(x) if training else forward(x) * 0.5 + forward(x[..., ::-1]) * 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def train_cifar():
|
def train_cifar():
|
||||||
|
|
||||||
# hyper-parameters were exactly the same as the original repo
|
# hyper-parameters were exactly the same as the original repo
|
||||||
bias_scaler = 58
|
bias_scaler = 58
|
||||||
hyp: Dict[str, Any] = {
|
hyp: Dict[str, Any] = {
|
||||||
'seed' : 209,
|
"seed": 209,
|
||||||
'opt': {
|
"opt": {
|
||||||
'bias_lr': 1.76 * bias_scaler/512,
|
"bias_lr": 1.76 * bias_scaler / 512,
|
||||||
'non_bias_lr': 1.76 / 512,
|
"non_bias_lr": 1.76 / 512,
|
||||||
'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
|
"bias_decay": 1.08 * 6.45e-4 * BS / bias_scaler,
|
||||||
'non_bias_decay': 1.08 * 6.45e-4 * BS,
|
"non_bias_decay": 1.08 * 6.45e-4 * BS,
|
||||||
'final_lr_ratio': 0.025,
|
"final_lr_ratio": 0.025,
|
||||||
'initial_div_factor': 1e16,
|
"initial_div_factor": 1e16,
|
||||||
'label_smoothing': 0.20,
|
"label_smoothing": 0.20,
|
||||||
'momentum': 0.85,
|
"momentum": 0.85,
|
||||||
'percent_start': 0.23,
|
"percent_start": 0.23,
|
||||||
'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
"loss_scale_scaler": 1.0 / 128, # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
||||||
},
|
},
|
||||||
'net': {
|
"net": {
|
||||||
'kernel_size': 2, # kernel size for the whitening layer
|
"kernel_size": 2, # kernel size for the whitening layer
|
||||||
'cutmix_size': 3,
|
"cutmix_size": 3,
|
||||||
'cutmix_steps': 499,
|
"cutmix_steps": 499,
|
||||||
'pad_amount': 2
|
"pad_amount": 2,
|
||||||
|
},
|
||||||
|
"ema": {
|
||||||
|
"steps": 399,
|
||||||
|
"decay_base": 0.95,
|
||||||
|
"decay_pow": 1.6,
|
||||||
|
"every_n_steps": 5,
|
||||||
},
|
},
|
||||||
'ema': {
|
|
||||||
'steps': 399,
|
|
||||||
'decay_base': .95,
|
|
||||||
'decay_pow': 1.6,
|
|
||||||
'every_n_steps': 5,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_seed(seed):
|
def set_seed(seed):
|
||||||
Tensor.manual_seed(getenv('SEED', seed))
|
Tensor.manual_seed(getenv("SEED", seed))
|
||||||
random.seed(getenv('SEED', seed))
|
random.seed(getenv("SEED", seed))
|
||||||
|
|
||||||
# ========== Model ==========
|
# ========== Model ==========
|
||||||
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
|
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
|
||||||
def whitening(X, kernel_size=hyp['net']['kernel_size']):
|
def whitening(X, kernel_size=hyp["net"]["kernel_size"]):
|
||||||
def _cov(X):
|
def _cov(X):
|
||||||
X = X / np.sqrt(X.shape[0] - 1)
|
X = X / np.sqrt(X.shape[0] - 1)
|
||||||
return X.T @ X
|
return X.T @ X
|
||||||
|
@ -130,12 +148,18 @@ def train_cifar():
|
||||||
h, w = patch_size
|
h, w = patch_size
|
||||||
c = data.shape[1]
|
c = data.shape[1]
|
||||||
axis: SupportsIndex = (2, 3) # type: ignore
|
axis: SupportsIndex = (2, 3) # type: ignore
|
||||||
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
|
return (
|
||||||
|
np.lib.stride_tricks.sliding_window_view(
|
||||||
|
data, window_shape=(h, w), axis=axis
|
||||||
|
)
|
||||||
|
.transpose((0, 3, 2, 1, 4, 5))
|
||||||
|
.reshape((-1, c, h, w))
|
||||||
|
)
|
||||||
|
|
||||||
def _eigens(patches):
|
def _eigens(patches):
|
||||||
n, c, h, w = patches.shape
|
n, c, h, w = patches.shape
|
||||||
Σ = _cov(patches.reshape(n, c * h * w))
|
Σ = _cov(patches.reshape(n, c * h * w))
|
||||||
Λ, V = np.linalg.eigh(Σ, UPLO='U')
|
Λ, V = np.linalg.eigh(Σ, UPLO="U")
|
||||||
return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0)
|
return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0)
|
||||||
|
|
||||||
Λ, V = _eigens(_patches(X.numpy()))
|
Λ, V = _eigens(_patches(X.numpy()))
|
||||||
|
@ -144,12 +168,16 @@ def train_cifar():
|
||||||
return Tensor(W.astype(np_dtype), requires_grad=False)
|
return Tensor(W.astype(np_dtype), requires_grad=False)
|
||||||
|
|
||||||
# ========== Loss ==========
|
# ========== Loss ==========
|
||||||
def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
|
def cross_entropy(
|
||||||
|
x: Tensor, y: Tensor, reduction: str = "mean", label_smoothing: float = 0.0
|
||||||
|
) -> Tensor:
|
||||||
divisor = y.shape[1]
|
divisor = y.shape[1]
|
||||||
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
||||||
y = (1 - label_smoothing) * y + label_smoothing / divisor
|
y = (1 - label_smoothing) * y + label_smoothing / divisor
|
||||||
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
if reduction == "none":
|
||||||
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||||
|
if reduction == "sum":
|
||||||
|
return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
||||||
|
|
||||||
# ========== Preprocessing ==========
|
# ========== Preprocessing ==========
|
||||||
|
@ -161,12 +189,20 @@ def train_cifar():
|
||||||
p = padding[3]
|
p = padding[3]
|
||||||
s = X.shape[3]
|
s = X.shape[3]
|
||||||
|
|
||||||
X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
|
X_lr = X[..., :, 1 : 1 + p[0]].flip(3).pad(
|
||||||
|
((0, 0), (0, 0), (0, 0), (0, s + p[0]))
|
||||||
|
) + X[..., :, -1 - p[1] : -1].flip(3).pad(
|
||||||
|
((0, 0), (0, 0), (0, 0), (s + p[1], 0))
|
||||||
|
)
|
||||||
X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr
|
X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr
|
||||||
|
|
||||||
p = padding[2]
|
p = padding[2]
|
||||||
s = X.shape[2]
|
s = X.shape[2]
|
||||||
X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
|
X_lr = X[..., 1 : 1 + p[0], :].flip(2).pad(
|
||||||
|
((0, 0), (0, 0), (0, s + p[0]), (0, 0))
|
||||||
|
) + X[..., -1 - p[1] : -1, :].flip(2).pad(
|
||||||
|
((0, 0), (0, 0), (s + p[1], 0), (0, 0))
|
||||||
|
)
|
||||||
X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr
|
X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr
|
||||||
|
|
||||||
return X
|
return X
|
||||||
|
@ -176,10 +212,18 @@ def train_cifar():
|
||||||
is_even = int(mask_size % 2 == 0)
|
is_even = int(mask_size % 2 == 0)
|
||||||
center_max = shape[-2] - mask_size // 2 - is_even
|
center_max = shape[-2] - mask_size // 2 - is_even
|
||||||
center_min = mask_size // 2 - is_even
|
center_min = mask_size // 2 - is_even
|
||||||
center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
center_x = (
|
||||||
center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||||
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
|
).floor()
|
||||||
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
|
center_y = (
|
||||||
|
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||||
|
).floor()
|
||||||
|
d_x = Tensor.arange(0, shape[-1]).reshape(
|
||||||
|
(1, 1, 1, shape[-1])
|
||||||
|
) - center_x.reshape((-1, 1, 1, 1))
|
||||||
|
d_y = Tensor.arange(0, shape[-2]).reshape(
|
||||||
|
(1, 1, shape[-2], 1)
|
||||||
|
) - center_y.reshape((-1, 1, 1, 1))
|
||||||
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
|
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
|
||||||
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
|
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
|
||||||
mask = d_y * d_x
|
mask = d_y * d_x
|
||||||
|
@ -200,7 +244,7 @@ def train_cifar():
|
||||||
Y_patch = Tensor(Y.numpy()[order])
|
Y_patch = Tensor(Y.numpy()[order])
|
||||||
X_cutmix = Tensor.where(mask, X_patch, X)
|
X_cutmix = Tensor.where(mask, X_patch, X)
|
||||||
mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1])
|
mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1])
|
||||||
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
|
Y_cutmix = mix_portion * Y_patch + (1.0 - mix_portion) * Y
|
||||||
return X_cutmix, Y_cutmix
|
return X_cutmix, Y_cutmix
|
||||||
|
|
||||||
# the operations that remain inside batch fetcher is the ones that involves random operations
|
# the operations that remain inside batch fetcher is the ones that involves random operations
|
||||||
|
@ -213,11 +257,16 @@ def train_cifar():
|
||||||
random.shuffle(order)
|
random.shuffle(order)
|
||||||
if is_train:
|
if is_train:
|
||||||
X = random_crop(X, crop_size=32)
|
X = random_crop(X, crop_size=32)
|
||||||
X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
|
X = Tensor.where(
|
||||||
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
|
Tensor.rand(X.shape[0], 1, 1, 1) < 0.5, X[..., ::-1], X
|
||||||
|
) # flip LR
|
||||||
|
if step >= hyp["net"]["cutmix_steps"]:
|
||||||
|
X, Y = cutmix(X, Y, mask_size=hyp["net"]["cutmix_size"])
|
||||||
X, Y = X.numpy(), Y.numpy()
|
X, Y = X.numpy(), Y.numpy()
|
||||||
et = time.monotonic()
|
et = time.monotonic()
|
||||||
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
|
print(
|
||||||
|
f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})"
|
||||||
|
)
|
||||||
for i in range(0, X.shape[0], BS):
|
for i in range(0, X.shape[0], BS):
|
||||||
# pad the last batch
|
# pad the last batch
|
||||||
batch_end = min(i + BS, Y.shape[0])
|
batch_end = min(i + BS, Y.shape[0])
|
||||||
|
@ -226,18 +275,24 @@ def train_cifar():
|
||||||
step += 1
|
step += 1
|
||||||
yield x, y
|
yield x, y
|
||||||
cnt += 1
|
cnt += 1
|
||||||
if not is_train: break
|
if not is_train:
|
||||||
|
break
|
||||||
|
|
||||||
transform = [
|
transform = [
|
||||||
lambda x: x / 255.0,
|
lambda x: x / 255.0,
|
||||||
lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1))
|
lambda x: (
|
||||||
|
x.reshape((-1, 3, 32, 32)) - Tensor(cifar_mean).reshape((1, 3, 1, 1))
|
||||||
|
)
|
||||||
|
/ Tensor(cifar_std).reshape((1, 3, 1, 1)),
|
||||||
]
|
]
|
||||||
|
|
||||||
class modelEMA():
|
class modelEMA:
|
||||||
def __init__(self, w, net):
|
def __init__(self, w, net):
|
||||||
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
||||||
self.net_ema = SpeedyResNet(w)
|
self.net_ema = SpeedyResNet(w)
|
||||||
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
|
for net_ema_param, net_param in zip(
|
||||||
|
get_state_dict(self.net_ema).values(), get_state_dict(net).values()
|
||||||
|
):
|
||||||
net_ema_param.requires_grad = False
|
net_ema_param.requires_grad = False
|
||||||
net_ema_param.assign(net_param.numpy())
|
net_ema_param.assign(net_param.numpy())
|
||||||
|
|
||||||
|
@ -245,23 +300,37 @@ def train_cifar():
|
||||||
def update(self, net, decay):
|
def update(self, net, decay):
|
||||||
# TODO with Tensor.no_grad()
|
# TODO with Tensor.no_grad()
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
|
for net_ema_param, (param_name, net_param) in zip(
|
||||||
|
get_state_dict(self.net_ema).values(), get_state_dict(net).items()
|
||||||
|
):
|
||||||
# batchnorm currently is not being tracked
|
# batchnorm currently is not being tracked
|
||||||
if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
|
if not ("num_batches_tracked" in param_name) and not (
|
||||||
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
|
"running" in param_name
|
||||||
|
):
|
||||||
|
net_ema_param.assign(
|
||||||
|
net_ema_param.detach() * decay
|
||||||
|
+ net_param.detach() * (1.0 - decay)
|
||||||
|
).realize()
|
||||||
Tensor.no_grad = False
|
Tensor.no_grad = False
|
||||||
|
|
||||||
set_seed(hyp['seed'])
|
set_seed(hyp["seed"])
|
||||||
|
|
||||||
# this import needs to be done here because this is running in a subprocess
|
# this import needs to be done here because this is running in a subprocess
|
||||||
from extra.dist import OOB
|
from extra.dist import OOB
|
||||||
|
|
||||||
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
|
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
|
||||||
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
||||||
|
|
||||||
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
||||||
# load data and label into GPU and convert to dtype accordingly
|
# load data and label into GPU and convert to dtype accordingly
|
||||||
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
|
X_train, X_test = (
|
||||||
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
|
X_train.to(device=Device.DEFAULT).float(),
|
||||||
|
X_test.to(device=Device.DEFAULT).float(),
|
||||||
|
)
|
||||||
|
Y_train, Y_test = (
|
||||||
|
Y_train.to(device=Device.DEFAULT).float(),
|
||||||
|
Y_test.to(device=Device.DEFAULT).float(),
|
||||||
|
)
|
||||||
# one-hot encode labels
|
# one-hot encode labels
|
||||||
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
||||||
# preprocess data
|
# preprocess data
|
||||||
|
@ -274,10 +343,15 @@ def train_cifar():
|
||||||
model = SpeedyResNet(W)
|
model = SpeedyResNet(W)
|
||||||
|
|
||||||
# padding is not timed in the original repo since it can be done all at once
|
# padding is not timed in the original repo since it can be done all at once
|
||||||
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
|
X_train = pad_reflect(X_train, size=hyp["net"]["pad_amount"])
|
||||||
|
|
||||||
# Convert data and labels to the default dtype
|
# Convert data and labels to the default dtype
|
||||||
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
|
X_train, Y_train, X_test, Y_test = (
|
||||||
|
X_train.cast(Tensor.default_type),
|
||||||
|
Y_train.cast(Tensor.default_type),
|
||||||
|
X_test.cast(Tensor.default_type),
|
||||||
|
Y_test.cast(Tensor.default_type),
|
||||||
|
)
|
||||||
|
|
||||||
# parse the training params into bias and non-bias
|
# parse the training params into bias and non-bias
|
||||||
params_dict = get_state_dict(model)
|
params_dict = get_state_dict(model)
|
||||||
|
@ -285,26 +359,60 @@ def train_cifar():
|
||||||
params_non_bias = []
|
params_non_bias = []
|
||||||
for params in params_dict:
|
for params in params_dict:
|
||||||
if params_dict[params].requires_grad is not False:
|
if params_dict[params].requires_grad is not False:
|
||||||
if 'bias' in params:
|
if "bias" in params:
|
||||||
params_bias.append(params_dict[params])
|
params_bias.append(params_dict[params])
|
||||||
else:
|
else:
|
||||||
params_non_bias.append(params_dict[params])
|
params_non_bias.append(params_dict[params])
|
||||||
|
|
||||||
opt_bias = optim.SGD(params_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
opt_bias = optim.SGD(
|
||||||
opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
|
params_bias,
|
||||||
|
lr=0.01,
|
||||||
|
momentum=hyp["opt"]["momentum"],
|
||||||
|
nesterov=True,
|
||||||
|
weight_decay=hyp["opt"]["bias_decay"],
|
||||||
|
)
|
||||||
|
opt_non_bias = optim.SGD(
|
||||||
|
params_non_bias,
|
||||||
|
lr=0.01,
|
||||||
|
momentum=hyp["opt"]["momentum"],
|
||||||
|
nesterov=True,
|
||||||
|
weight_decay=hyp["opt"]["non_bias_decay"],
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
|
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
|
||||||
initial_div_factor = hyp['opt']['initial_div_factor']
|
initial_div_factor = hyp["opt"]["initial_div_factor"]
|
||||||
final_lr_ratio = hyp['opt']['final_lr_ratio']
|
final_lr_ratio = hyp["opt"]["final_lr_ratio"]
|
||||||
pct_start = hyp['opt']['percent_start']
|
pct_start = hyp["opt"]["percent_start"]
|
||||||
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
lr_sched_bias = OneCycleLR(
|
||||||
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
opt_bias,
|
||||||
|
max_lr=hyp["opt"]["bias_lr"],
|
||||||
|
pct_start=pct_start,
|
||||||
|
div_factor=initial_div_factor,
|
||||||
|
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||||
|
total_steps=STEPS,
|
||||||
|
)
|
||||||
|
lr_sched_non_bias = OneCycleLR(
|
||||||
|
opt_non_bias,
|
||||||
|
max_lr=hyp["opt"]["non_bias_lr"],
|
||||||
|
pct_start=pct_start,
|
||||||
|
div_factor=initial_div_factor,
|
||||||
|
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||||
|
total_steps=STEPS,
|
||||||
|
)
|
||||||
|
|
||||||
loss_batchsize_scaler = 512 / BS
|
loss_batchsize_scaler = 512 / BS
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
|
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
|
||||||
out = model(X)
|
out = model(X)
|
||||||
loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
loss = (
|
||||||
|
cross_entropy(
|
||||||
|
out, Y, reduction="none", label_smoothing=hyp["opt"]["label_smoothing"]
|
||||||
|
)
|
||||||
|
.mul(hyp["opt"]["loss_scale_scaler"] * loss_batchsize_scaler)
|
||||||
|
.sum()
|
||||||
|
.div(hyp["opt"]["loss_scale_scaler"])
|
||||||
|
)
|
||||||
|
|
||||||
if not getenv("DISABLE_BACKWARD"):
|
if not getenv("DISABLE_BACKWARD"):
|
||||||
# index 0 for bias and 1 for non-bias
|
# index 0 for bias and 1 for non-bias
|
||||||
|
@ -316,11 +424,16 @@ def train_cifar():
|
||||||
# sync gradients across ranks
|
# sync gradients across ranks
|
||||||
bucket, offset = [], 0
|
bucket, offset = [], 0
|
||||||
for _, v in params_dict.items():
|
for _, v in params_dict.items():
|
||||||
if v.grad is not None: bucket.append(v.grad.flatten())
|
if v.grad is not None:
|
||||||
|
bucket.append(v.grad.flatten())
|
||||||
grads = collectives.allreduce(Tensor.cat(*bucket))
|
grads = collectives.allreduce(Tensor.cat(*bucket))
|
||||||
for _, v in params_dict.items():
|
for _, v in params_dict.items():
|
||||||
if v.grad is not None:
|
if v.grad is not None:
|
||||||
v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
|
v.grad.assign(
|
||||||
|
grads[offset : offset + v.grad.numel()].reshape(
|
||||||
|
*v.grad.shape
|
||||||
|
)
|
||||||
|
)
|
||||||
offset += v.grad.numel()
|
offset += v.grad.numel()
|
||||||
|
|
||||||
optimizer[0].step()
|
optimizer[0].step()
|
||||||
|
@ -331,9 +444,10 @@ def train_cifar():
|
||||||
|
|
||||||
def eval_step(model, X, Y):
|
def eval_step(model, X, Y):
|
||||||
out = model(X, training=False)
|
out = model(X, training=False)
|
||||||
loss = cross_entropy(out, Y, reduction='mean')
|
loss = cross_entropy(out, Y, reduction="mean")
|
||||||
correct = out.argmax(axis=1) == Y.argmax(axis=1)
|
correct = out.argmax(axis=1) == Y.argmax(axis=1)
|
||||||
return correct.realize(), loss.realize()
|
return correct.realize(), loss.realize()
|
||||||
|
|
||||||
eval_step_jitted = TinyJit(eval_step)
|
eval_step_jitted = TinyJit(eval_step)
|
||||||
eval_step_ema_jitted = TinyJit(eval_step)
|
eval_step_ema_jitted = TinyJit(eval_step)
|
||||||
|
|
||||||
|
@ -347,7 +461,7 @@ def train_cifar():
|
||||||
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
||||||
|
|
||||||
model_ema: Optional[modelEMA] = None
|
model_ema: Optional[modelEMA] = None
|
||||||
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
|
projected_ema_decay_val = hyp["ema"]["decay_base"] ** hyp["ema"]["every_n_steps"]
|
||||||
i = 0
|
i = 0
|
||||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||||
with Tensor.train():
|
with Tensor.train():
|
||||||
|
@ -363,24 +477,37 @@ def train_cifar():
|
||||||
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
||||||
# further split batch if distributed
|
# further split batch if distributed
|
||||||
if getenv("DIST"):
|
if getenv("DIST"):
|
||||||
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
|
Xt, Yt = (
|
||||||
|
Xt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||||
|
Yt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||||
|
)
|
||||||
|
|
||||||
correct, loss = eval_step_jitted(model, Xt, Yt)
|
correct, loss = eval_step_jitted(model, Xt, Yt)
|
||||||
losses.append(loss.numpy().tolist())
|
losses.append(loss.numpy().tolist())
|
||||||
corrects.extend(correct.numpy().tolist())
|
corrects.extend(correct.numpy().tolist())
|
||||||
if model_ema:
|
if model_ema:
|
||||||
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
|
correct_ema, loss_ema = eval_step_ema_jitted(
|
||||||
|
model_ema.net_ema, Xt, Yt
|
||||||
|
)
|
||||||
losses_ema.append(loss_ema.numpy().tolist())
|
losses_ema.append(loss_ema.numpy().tolist())
|
||||||
corrects_ema.extend(correct_ema.numpy().tolist())
|
corrects_ema.extend(correct_ema.numpy().tolist())
|
||||||
|
|
||||||
# collect accuracy across ranks
|
# collect accuracy across ranks
|
||||||
correct_sum, correct_len = sum(corrects), len(corrects)
|
correct_sum, correct_len = sum(corrects), len(corrects)
|
||||||
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
|
if model_ema:
|
||||||
|
correct_sum_ema, correct_len_ema = sum(corrects_ema), len(
|
||||||
|
corrects_ema
|
||||||
|
)
|
||||||
if getenv("DIST"):
|
if getenv("DIST"):
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
for j in range(1, min(world_size, 5)):
|
for j in range(1, min(world_size, 5)):
|
||||||
if model_ema:
|
if model_ema:
|
||||||
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
|
(
|
||||||
|
recv_sum,
|
||||||
|
recv_len,
|
||||||
|
recv_sum_ema,
|
||||||
|
recv_len_ema,
|
||||||
|
) = OOB.recv(j)
|
||||||
else:
|
else:
|
||||||
recv_sum, recv_len = OOB.recv(j)
|
recv_sum, recv_len = OOB.recv(j)
|
||||||
correct_sum += recv_sum
|
correct_sum += recv_sum
|
||||||
|
@ -390,55 +517,95 @@ def train_cifar():
|
||||||
correct_len_ema += recv_len_ema
|
correct_len_ema += recv_len_ema
|
||||||
elif rank < min(world_size, 5):
|
elif rank < min(world_size, 5):
|
||||||
if model_ema:
|
if model_ema:
|
||||||
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
|
OOB.send(
|
||||||
|
(
|
||||||
|
correct_sum,
|
||||||
|
correct_len,
|
||||||
|
correct_sum_ema,
|
||||||
|
correct_len_ema,
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
OOB.send((correct_sum, correct_len), 0)
|
OOB.send((correct_sum, correct_len), 0)
|
||||||
|
|
||||||
# only rank 0 prints
|
# only rank 0 prints
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
acc = correct_sum / correct_len * 100.0
|
acc = correct_sum / correct_len * 100.0
|
||||||
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
|
if model_ema:
|
||||||
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
|
acc_ema = correct_sum_ema / correct_len_ema * 100.0
|
||||||
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
|
print(
|
||||||
|
f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)"
|
||||||
|
)
|
||||||
|
if model_ema:
|
||||||
|
print(
|
||||||
|
f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}"
|
||||||
|
)
|
||||||
|
|
||||||
if STEPS == 0 or i==STEPS: break
|
if STEPS == 0 or i == STEPS:
|
||||||
|
break
|
||||||
X, Y = next(batcher)
|
X, Y = next(batcher)
|
||||||
if getenv("DIST"):
|
if getenv("DIST"):
|
||||||
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
|
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
|
loss = train_step_jitted(
|
||||||
|
model,
|
||||||
|
[opt_bias, opt_non_bias],
|
||||||
|
[lr_sched_bias, lr_sched_non_bias],
|
||||||
|
X,
|
||||||
|
Y,
|
||||||
|
)
|
||||||
et = time.monotonic()
|
et = time.monotonic()
|
||||||
loss_cpu = loss.numpy()
|
loss_cpu = loss.numpy()
|
||||||
# EMA for network weights
|
# EMA for network weights
|
||||||
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
|
if i > hyp["ema"]["steps"] and (i + 1) % hyp["ema"]["every_n_steps"] == 0:
|
||||||
if model_ema is None:
|
if model_ema is None:
|
||||||
model_ema = modelEMA(W, model)
|
model_ema = modelEMA(W, model)
|
||||||
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
|
model_ema.update(
|
||||||
|
model,
|
||||||
|
Tensor(
|
||||||
|
[
|
||||||
|
projected_ema_decay_val
|
||||||
|
* (i / STEPS) ** hyp["ema"]["decay_pow"]
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
cl = time.monotonic()
|
cl = time.monotonic()
|
||||||
if not getenv("DIST"):
|
if not getenv("DIST"):
|
||||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
print(
|
||||||
|
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
print(
|
||||||
|
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||||
|
)
|
||||||
st = cl
|
st = cl
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if not getenv("DIST"):
|
if not getenv("DIST"):
|
||||||
train_cifar()
|
train_cifar()
|
||||||
else: # distributed
|
else: # distributed
|
||||||
if getenv("HIP"):
|
if getenv("HIP"):
|
||||||
from tinygrad.runtime.ops_hip import HIP
|
from tinygrad.runtime.ops_hip import HIP
|
||||||
|
|
||||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||||
else:
|
else:
|
||||||
from tinygrad.runtime.ops_gpu import CLDevice
|
from tinygrad.runtime.ops_gpu import CLDevice
|
||||||
|
|
||||||
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||||
world_size = len(devices)
|
world_size = len(devices)
|
||||||
|
|
||||||
# ensure that the batch size is divisible by the number of devices
|
# ensure that the batch size is divisible by the number of devices
|
||||||
assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
|
assert (
|
||||||
|
BS % world_size == 0
|
||||||
|
), f"batch size {BS} is not divisible by world size {world_size}"
|
||||||
|
|
||||||
# ensure that the evaluation batch size is divisible by the number of devices
|
# ensure that the evaluation batch size is divisible by the number of devices
|
||||||
assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
assert (
|
||||||
|
EVAL_BS % min(world_size, 5) == 0
|
||||||
|
), f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
||||||
|
|
||||||
# init out-of-band communication
|
# init out-of-band communication
|
||||||
dist.init_oob(world_size)
|
dist.init_oob(world_size)
|
||||||
|
@ -447,4 +614,5 @@ if __name__ == "__main__":
|
||||||
processes = []
|
processes = []
|
||||||
for rank, device in enumerate(devices):
|
for rank, device in enumerate(devices):
|
||||||
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
|
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
|
||||||
for p in processes: p.join()
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys, argparse, json
|
import sys, argparse, json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
np.set_printoptions(linewidth=200)
|
np.set_printoptions(linewidth=200)
|
||||||
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
|
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
|
@ -24,84 +25,225 @@ MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
|
||||||
MODEL_PARAMS = {
|
MODEL_PARAMS = {
|
||||||
"1": {
|
"1": {
|
||||||
"7B": {
|
"7B": {
|
||||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
|
"args": {
|
||||||
|
"dim": 4096,
|
||||||
|
"n_heads": 32,
|
||||||
|
"n_layers": 32,
|
||||||
|
"norm_eps": 1e-06,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 11008,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"13B": {
|
"13B": {
|
||||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
|
"args": {
|
||||||
|
"dim": 5120,
|
||||||
|
"n_heads": 40,
|
||||||
|
"n_layers": 40,
|
||||||
|
"norm_eps": 1e-06,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 13824,
|
||||||
|
},
|
||||||
"files": 2,
|
"files": 2,
|
||||||
},
|
},
|
||||||
"30B": {
|
"30B": {
|
||||||
"args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
|
"args": {
|
||||||
|
"dim": 6656,
|
||||||
|
"n_heads": 52,
|
||||||
|
"n_layers": 60,
|
||||||
|
"norm_eps": 1e-06,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 17920,
|
||||||
|
},
|
||||||
"files": 4,
|
"files": 4,
|
||||||
},
|
},
|
||||||
"65B": {
|
"65B": {
|
||||||
"args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
|
"args": {
|
||||||
|
"dim": 8192,
|
||||||
|
"n_heads": 64,
|
||||||
|
"n_layers": 80,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 22016,
|
||||||
|
},
|
||||||
"files": 8,
|
"files": 8,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"2": {
|
"2": {
|
||||||
"7B": {
|
"7B": {
|
||||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
|
"args": {
|
||||||
|
"dim": 4096,
|
||||||
|
"n_heads": 32,
|
||||||
|
"n_layers": 32,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 11008,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"13B": {
|
"13B": {
|
||||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
|
"args": {
|
||||||
|
"dim": 5120,
|
||||||
|
"n_heads": 40,
|
||||||
|
"n_layers": 40,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 13824,
|
||||||
|
},
|
||||||
"files": 2,
|
"files": 2,
|
||||||
},
|
},
|
||||||
"70B": {
|
"70B": {
|
||||||
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
|
"args": {
|
||||||
|
"dim": 8192,
|
||||||
|
"n_heads": 64,
|
||||||
|
"n_kv_heads": 8,
|
||||||
|
"n_layers": 80,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 28672,
|
||||||
|
},
|
||||||
"files": 8,
|
"files": 8,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"code": {
|
"code": {
|
||||||
"7B": {
|
"7B": {
|
||||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
"args": {
|
||||||
|
"dim": 4096,
|
||||||
|
"n_layers": 32,
|
||||||
|
"n_heads": 32,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32016,
|
||||||
|
"hidden_dim": 11008,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"7B-Python": {
|
"7B-Python": {
|
||||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
|
"args": {
|
||||||
|
"dim": 4096,
|
||||||
|
"n_layers": 32,
|
||||||
|
"n_heads": 32,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 11008,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"7B-Instruct": {
|
"7B-Instruct": {
|
||||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
"args": {
|
||||||
|
"dim": 4096,
|
||||||
|
"n_layers": 32,
|
||||||
|
"n_heads": 32,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32016,
|
||||||
|
"hidden_dim": 11008,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"13B": {
|
"13B": {
|
||||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
"args": {
|
||||||
|
"dim": 5120,
|
||||||
|
"n_layers": 40,
|
||||||
|
"n_heads": 40,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32016,
|
||||||
|
"hidden_dim": 13824,
|
||||||
|
},
|
||||||
"files": 2,
|
"files": 2,
|
||||||
},
|
},
|
||||||
"13B-Python": {
|
"13B-Python": {
|
||||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
|
"args": {
|
||||||
|
"dim": 5120,
|
||||||
|
"n_layers": 40,
|
||||||
|
"n_heads": 40,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 13824,
|
||||||
|
},
|
||||||
"files": 2,
|
"files": 2,
|
||||||
},
|
},
|
||||||
"13B-Instruct": {
|
"13B-Instruct": {
|
||||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
"args": {
|
||||||
|
"dim": 5120,
|
||||||
|
"n_layers": 40,
|
||||||
|
"n_heads": 40,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32016,
|
||||||
|
"hidden_dim": 13824,
|
||||||
|
},
|
||||||
"files": 2,
|
"files": 2,
|
||||||
},
|
},
|
||||||
"34B": {
|
"34B": {
|
||||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
"args": {
|
||||||
|
"dim": 8192,
|
||||||
|
"n_layers": 48,
|
||||||
|
"n_heads": 64,
|
||||||
|
"n_kv_heads": 8,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 22016,
|
||||||
|
},
|
||||||
"files": 4,
|
"files": 4,
|
||||||
},
|
},
|
||||||
"34B-Python": {
|
"34B-Python": {
|
||||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
"args": {
|
||||||
|
"dim": 8192,
|
||||||
|
"n_layers": 48,
|
||||||
|
"n_heads": 64,
|
||||||
|
"n_kv_heads": 8,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 22016,
|
||||||
|
},
|
||||||
"files": 4,
|
"files": 4,
|
||||||
},
|
},
|
||||||
"34B-Instruct": {
|
"34B-Instruct": {
|
||||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
"args": {
|
||||||
|
"dim": 8192,
|
||||||
|
"n_layers": 48,
|
||||||
|
"n_heads": 64,
|
||||||
|
"n_kv_heads": 8,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 22016,
|
||||||
|
},
|
||||||
"files": 4,
|
"files": 4,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"tiny": {
|
"tiny": {
|
||||||
"1B": {
|
"1B": {
|
||||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
|
"args": {
|
||||||
|
"dim": 2048,
|
||||||
|
"n_layers": 22,
|
||||||
|
"n_heads": 32,
|
||||||
|
"n_kv_heads": 4,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32000,
|
||||||
|
"hidden_dim": 5632,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
},
|
},
|
||||||
"1B-Chat": {
|
"1B-Chat": {
|
||||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
|
"args": {
|
||||||
|
"dim": 2048,
|
||||||
|
"n_layers": 22,
|
||||||
|
"n_heads": 32,
|
||||||
|
"n_kv_heads": 4,
|
||||||
|
"norm_eps": 1e-05,
|
||||||
|
"vocab_size": 32003,
|
||||||
|
"hidden_dim": 5632,
|
||||||
|
},
|
||||||
"files": 1,
|
"files": 1,
|
||||||
}
|
},
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,21 +253,37 @@ def concat_weights(models):
|
||||||
disk_tensors = [model[name] for model in models]
|
disk_tensors = [model[name] for model in models]
|
||||||
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
||||||
return disk_tensors[0].to(device=Device.DEFAULT)
|
return disk_tensors[0].to(device=Device.DEFAULT)
|
||||||
axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
|
axis = (
|
||||||
|
1
|
||||||
|
if name.startswith("tok_embeddings.")
|
||||||
|
or name.endswith(".attention.wo.weight")
|
||||||
|
or name.endswith(".feed_forward.w2.weight")
|
||||||
|
else 0
|
||||||
|
)
|
||||||
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
|
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
|
||||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||||
return {name: convert(name) for name in {name: None for model in models for name in model}}
|
|
||||||
|
return {
|
||||||
|
name: convert(name)
|
||||||
|
for name in {name: None for model in models for name in model}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load(fn: str):
|
def load(fn: str):
|
||||||
if fn.endswith('.index.json'):
|
if fn.endswith(".index.json"):
|
||||||
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
|
with open(fn) as fp:
|
||||||
parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
|
weight_map = json.load(fp)["weight_map"]
|
||||||
|
parts = {
|
||||||
|
n: load(str(Path(fn).parent / Path(n).name))
|
||||||
|
for n in set(weight_map.values())
|
||||||
|
}
|
||||||
return {k: parts[n][k] for k, n in weight_map.items()}
|
return {k: parts[n][k] for k, n in weight_map.items()}
|
||||||
elif fn.endswith(".safetensors"):
|
elif fn.endswith(".safetensors"):
|
||||||
return safe_load(fn)
|
return safe_load(fn)
|
||||||
else:
|
else:
|
||||||
return torch_load(fn)
|
return torch_load(fn)
|
||||||
|
|
||||||
|
|
||||||
class AbsmaxQuantizedLinear:
|
class AbsmaxQuantizedLinear:
|
||||||
def __init__(self, in_features, out_features, bias=False):
|
def __init__(self, in_features, out_features, bias=False):
|
||||||
assert bias == False
|
assert bias == False
|
||||||
|
@ -139,34 +297,63 @@ class AbsmaxQuantizedLinear:
|
||||||
def quantize(tensors):
|
def quantize(tensors):
|
||||||
new_tensors = {}
|
new_tensors = {}
|
||||||
for name, v in tensors.items():
|
for name, v in tensors.items():
|
||||||
if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
|
if (
|
||||||
|
"feed_forward" in name
|
||||||
|
or ("attention.w") in name
|
||||||
|
or name == "output.weight"
|
||||||
|
):
|
||||||
scale = v.abs().max(axis=1) / 127.0
|
scale = v.abs().max(axis=1) / 127.0
|
||||||
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)
|
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)
|
||||||
new_tensors[name] = int8_weight
|
new_tensors[name] = int8_weight
|
||||||
new_tensors[name.replace('weight', 'scale')] = scale
|
new_tensors[name.replace("weight", "scale")] = scale
|
||||||
else:
|
else:
|
||||||
new_tensors[name] = v
|
new_tensors[name] = v
|
||||||
return new_tensors
|
return new_tensors
|
||||||
|
|
||||||
|
|
||||||
class LLaMa:
|
class LLaMa:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
|
def build(
|
||||||
|
model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False
|
||||||
|
):
|
||||||
params = MODEL_PARAMS[model_gen][model_size]
|
params = MODEL_PARAMS[model_gen][model_size]
|
||||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||||
assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
assert (
|
||||||
|
sp_model.vocab_size() == params["args"]["vocab_size"]
|
||||||
|
), f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||||
|
|
||||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
model = (
|
||||||
|
Transformer(
|
||||||
|
**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT
|
||||||
|
)
|
||||||
|
if quantize
|
||||||
|
else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||||
|
)
|
||||||
|
|
||||||
if model_path.is_dir():
|
if model_path.is_dir():
|
||||||
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
weights = concat_weights(
|
||||||
|
[
|
||||||
|
load(filename)
|
||||||
|
for filename in [
|
||||||
|
f"{model_path}/consolidated.{i:02d}.pth"
|
||||||
|
for i in range(params["files"])
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
weights = load(str(model_path))
|
weights = load(str(model_path))
|
||||||
if "model.embed_tokens.weight" in weights:
|
if "model.embed_tokens.weight" in weights:
|
||||||
weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
|
weights = convert_from_huggingface(
|
||||||
|
weights,
|
||||||
|
model,
|
||||||
|
params["args"]["n_heads"],
|
||||||
|
params["args"].get("n_kv_heads", params["args"]["n_heads"]),
|
||||||
|
)
|
||||||
|
|
||||||
if quantize:
|
if quantize:
|
||||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||||
for _,v in weights.items(): v.realize()
|
for _, v in weights.items():
|
||||||
|
v.realize()
|
||||||
load_state_dict(model, weights, strict=False)
|
load_state_dict(model, weights, strict=False)
|
||||||
|
|
||||||
return LLaMa(model, sp_model)
|
return LLaMa(model, sp_model)
|
||||||
|
@ -179,18 +366,23 @@ class LLaMa:
|
||||||
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
|
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
for i in range(max_length):
|
for i in range(max_length):
|
||||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
|
probs = llama.model(
|
||||||
|
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||||
|
).realize()
|
||||||
probs_np = probs.numpy()
|
probs_np = probs.numpy()
|
||||||
tok = int(np.random.choice(len(probs_np), p=probs_np))
|
tok = int(np.random.choice(len(probs_np), p=probs_np))
|
||||||
start_pos = len(toks)
|
start_pos = len(toks)
|
||||||
toks.append(tok)
|
toks.append(tok)
|
||||||
|
|
||||||
if tok == self.tokenizer.eos_id(): break
|
if tok == self.tokenizer.eos_id():
|
||||||
|
break
|
||||||
output = self.tokenizer.decode(toks)
|
output = self.tokenizer.decode(toks)
|
||||||
for s in until:
|
for s in until:
|
||||||
if output.endswith(s): return output[0:-len(s)]
|
if output.endswith(s):
|
||||||
|
return output[0 : -len(s)]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# **** main code ****
|
# **** main code ****
|
||||||
"""
|
"""
|
||||||
test:
|
test:
|
||||||
|
@ -256,21 +448,58 @@ if __name__ == "__main__":
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
print(f"using {Device.DEFAULT} backend")
|
print(f"using {Device.DEFAULT} backend")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
|
description="Run LLaMA in tinygrad",
|
||||||
parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
|
)
|
||||||
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Phrase to start with. Without this, it goes into chatbot mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--count", type=int, default=1000, help="Max number of tokens to generate"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--personality",
|
||||||
|
type=str,
|
||||||
|
default="Stacy",
|
||||||
|
help="Personality, can be Stacy, George, Gary, or Lexie",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", type=float, default=0.7, help="Temperature in the softmax"
|
||||||
|
)
|
||||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||||
parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
|
parser.add_argument(
|
||||||
parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""")
|
"--profile", action="store_true", help="Output profile data to out.prof"
|
||||||
parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""")
|
)
|
||||||
parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
parser.add_argument(
|
||||||
parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
"--gen",
|
||||||
|
default="1",
|
||||||
|
help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--size",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantize", action="store_true", help="Quantize the weights to int8 in memory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation")
|
if args.gen not in MODEL_PARAMS:
|
||||||
if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
raise ValueError("Invalid model generation")
|
||||||
|
if args.size is None:
|
||||||
|
args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
||||||
chatbot = args.prompt == None
|
chatbot = args.prompt == None
|
||||||
|
|
||||||
# *** prompt engineers work here ****
|
# *** prompt engineers work here ****
|
||||||
|
@ -294,9 +523,13 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||||
user_delim = "\nUser: "
|
user_delim = "\nUser: "
|
||||||
resp_delim = "Stacy: "
|
resp_delim = "Stacy: "
|
||||||
end_delim = " [EOS]\n"
|
end_delim = " [EOS]\n"
|
||||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
pre_prompt += "".join(
|
||||||
|
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||||
|
)
|
||||||
elif args.personality.lower() == "george":
|
elif args.personality.lower() == "george":
|
||||||
print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.")
|
print(
|
||||||
|
"WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter."
|
||||||
|
)
|
||||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
|
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
|
||||||
You are an AI version of George Hotz. You act as much as you can like George.
|
You are an AI version of George Hotz. You act as much as you can like George.
|
||||||
You are one of the greatest computer experts in the world.
|
You are one of the greatest computer experts in the world.
|
||||||
|
@ -312,13 +545,15 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||||
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
|
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
|
||||||
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
|
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
|
||||||
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
|
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
|
||||||
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk"
|
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk",
|
||||||
}
|
}
|
||||||
|
|
||||||
user_delim = "\nUser: "
|
user_delim = "\nUser: "
|
||||||
resp_delim = "George: "
|
resp_delim = "George: "
|
||||||
end_delim = " [EOS]\n"
|
end_delim = " [EOS]\n"
|
||||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
pre_prompt += "".join(
|
||||||
|
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||||
|
)
|
||||||
elif args.personality.lower() == "gary":
|
elif args.personality.lower() == "gary":
|
||||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
|
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
|
||||||
You are Gary!
|
You are Gary!
|
||||||
|
@ -331,13 +566,15 @@ After you are done speaking, output [EOS]. You are not the User.
|
||||||
"""
|
"""
|
||||||
examples = {
|
examples = {
|
||||||
"What is your name?": "I am Gary. I used to sell cars.",
|
"What is your name?": "I am Gary. I used to sell cars.",
|
||||||
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla"
|
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla",
|
||||||
}
|
}
|
||||||
|
|
||||||
user_delim = "\nUser: "
|
user_delim = "\nUser: "
|
||||||
resp_delim = "Gary: "
|
resp_delim = "Gary: "
|
||||||
end_delim = " [EOS]\n"
|
end_delim = " [EOS]\n"
|
||||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
pre_prompt += "".join(
|
||||||
|
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||||
|
)
|
||||||
elif args.personality.lower() == "lexie":
|
elif args.personality.lower() == "lexie":
|
||||||
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
|
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
|
||||||
You are Lexie!
|
You are Lexie!
|
||||||
|
@ -352,21 +589,34 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||||
examples = {
|
examples = {
|
||||||
"hi lexie": "hi chad, glad we finally met up!",
|
"hi lexie": "hi chad, glad we finally met up!",
|
||||||
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
|
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
|
||||||
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress"
|
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress",
|
||||||
}
|
}
|
||||||
|
|
||||||
user_delim = "\nChad: "
|
user_delim = "\nChad: "
|
||||||
resp_delim = "Lexie: "
|
resp_delim = "Lexie: "
|
||||||
end_delim = " [EOS]\n"
|
end_delim = " [EOS]\n"
|
||||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
pre_prompt += "".join(
|
||||||
|
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||||
|
)
|
||||||
|
|
||||||
# *** prompt engineers stop here ****
|
# *** prompt engineers stop here ****
|
||||||
|
|
||||||
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
|
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
|
||||||
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
MODEL_PATH = (
|
||||||
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
|
args.model
|
||||||
|
or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||||
|
)
|
||||||
|
TOKENIZER_PATH = (
|
||||||
|
MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent
|
||||||
|
) / "tokenizer.model"
|
||||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||||
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
llama = LLaMa.build(
|
||||||
|
MODEL_PATH,
|
||||||
|
TOKENIZER_PATH,
|
||||||
|
model_gen=args.gen,
|
||||||
|
model_size=args.size,
|
||||||
|
quantize=args.quantize,
|
||||||
|
)
|
||||||
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
||||||
|
|
||||||
if chatbot:
|
if chatbot:
|
||||||
|
@ -375,7 +625,9 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||||
|
|
||||||
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
||||||
with Timing():
|
with Timing():
|
||||||
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used
|
llama.model(
|
||||||
|
Tensor([toks]), 0, args.temperature
|
||||||
|
).realize() # NOTE: outputs are not used
|
||||||
start_pos = len(toks)
|
start_pos = len(toks)
|
||||||
else:
|
else:
|
||||||
# non chat bot mode
|
# non chat bot mode
|
||||||
|
@ -403,14 +655,37 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||||
for i in range(args.count):
|
for i in range(args.count):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
|
|
||||||
if args.timing or args.profile: print("")
|
if args.timing or args.profile:
|
||||||
|
print("")
|
||||||
st = GlobalCounters.time_sum_s
|
st = GlobalCounters.time_sum_s
|
||||||
with Profiling(enabled=args.profile):
|
with Profiling(enabled=args.profile):
|
||||||
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
|
with Timing(
|
||||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
"total ",
|
||||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
enabled=args.timing,
|
||||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
|
on_exit=lambda x: f", {1e9/x:.2f} tok/sec",
|
||||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
|
):
|
||||||
|
with Timing(
|
||||||
|
"ran model in ",
|
||||||
|
on_exit=(
|
||||||
|
lambda et: (
|
||||||
|
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||||
|
if DEBUG >= 2
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||||
|
+ (
|
||||||
|
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||||
|
if DEBUG >= 2
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if DEBUG
|
||||||
|
else None,
|
||||||
|
enabled=args.timing,
|
||||||
|
):
|
||||||
|
probs = llama.model(
|
||||||
|
Tensor([toks[start_pos:]]), start_pos, args.temperature
|
||||||
|
).realize()
|
||||||
# TODO: fix JIT rand so we can put this in the JIT
|
# TODO: fix JIT rand so we can put this in the JIT
|
||||||
tok = probs.multinomial().item()
|
tok = probs.multinomial().item()
|
||||||
|
|
||||||
|
@ -427,5 +702,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||||
outputted = cur
|
outputted = cur
|
||||||
|
|
||||||
# stop after you have your answer
|
# stop after you have your answer
|
||||||
if chatbot and outputted.endswith(end_delim): break
|
if chatbot and outputted.endswith(end_delim):
|
||||||
if not chatbot: break
|
break
|
||||||
|
if not chatbot:
|
||||||
|
break
|
||||||
|
|
|
@ -63,21 +63,23 @@ class Normalize:
|
||||||
image = Ft.normalize(image, mean=self.mean, std=self.std)
|
image = Ft.normalize(image, mean=self.mean, std=self.std)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
transforms = lambda size_scale: T.Compose(
|
transforms = lambda size_scale: T.Compose(
|
||||||
[
|
[
|
||||||
Resize(int(800 * size_scale), int(1333 * size_scale)),
|
Resize(int(800 * size_scale), int(1333 * size_scale)),
|
||||||
T.ToTensor(),
|
T.ToTensor(),
|
||||||
Normalize(
|
Normalize(
|
||||||
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
|
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_bgr255=True
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def expand_boxes(boxes, scale):
|
def expand_boxes(boxes, scale):
|
||||||
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
|
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
|
||||||
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
|
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
|
||||||
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
|
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
|
||||||
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
|
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
|
||||||
|
|
||||||
w_half *= scale
|
w_half *= scale
|
||||||
h_half *= scale
|
h_half *= scale
|
||||||
|
@ -118,7 +120,7 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
|
||||||
mask = mask.expand((1, 1, -1, -1))
|
mask = mask.expand((1, 1, -1, -1))
|
||||||
|
|
||||||
mask = mask.to(torch.float32)
|
mask = mask.to(torch.float32)
|
||||||
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
|
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
|
||||||
mask = mask[0][0]
|
mask = mask[0][0]
|
||||||
|
|
||||||
if thresh >= 0:
|
if thresh >= 0:
|
||||||
|
@ -169,11 +171,13 @@ class Masker:
|
||||||
|
|
||||||
masker = Masker(threshold=0.5, padding=1)
|
masker = Masker(threshold=0.5, padding=1)
|
||||||
|
|
||||||
|
|
||||||
def select_top_predictions(predictions, confidence_threshold=0.9):
|
def select_top_predictions(predictions, confidence_threshold=0.9):
|
||||||
scores = predictions.get_field("scores").numpy()
|
scores = predictions.get_field("scores").numpy()
|
||||||
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
|
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
|
||||||
return predictions[keep]
|
return predictions[keep]
|
||||||
|
|
||||||
|
|
||||||
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
|
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
|
||||||
image = transforms(size_scale)(original_image).numpy()
|
image = transforms(size_scale)(original_image).numpy()
|
||||||
image = Tensor(image, requires_grad=False)
|
image = Tensor(image, requires_grad=False)
|
||||||
|
@ -189,6 +193,7 @@ def compute_prediction(original_image, model, confidence_threshold, size_scale=1
|
||||||
prediction.add_field("mask", masks)
|
prediction.add_field("mask", masks)
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
|
|
||||||
def compute_prediction_batched(batch, model, size_scale=1.0):
|
def compute_prediction_batched(batch, model, size_scale=1.0):
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in batch:
|
for img in batch:
|
||||||
|
@ -198,21 +203,25 @@ def compute_prediction_batched(batch, model, size_scale=1.0):
|
||||||
del image
|
del image
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
|
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
|
||||||
|
|
||||||
|
|
||||||
def findContours(*args, **kwargs):
|
def findContours(*args, **kwargs):
|
||||||
if cv2.__version__.startswith('4'):
|
if cv2.__version__.startswith("4"):
|
||||||
contours, hierarchy = cv2.findContours(*args, **kwargs)
|
contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||||
elif cv2.__version__.startswith('3'):
|
elif cv2.__version__.startswith("3"):
|
||||||
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
|
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||||
return contours, hierarchy
|
return contours, hierarchy
|
||||||
|
|
||||||
|
|
||||||
def compute_colors_for_labels(labels):
|
def compute_colors_for_labels(labels):
|
||||||
l = labels[:, None]
|
l = labels[:, None]
|
||||||
colors = l * palette
|
colors = l * palette
|
||||||
colors = (colors % 255).astype("uint8")
|
colors = (colors % 255).astype("uint8")
|
||||||
return colors
|
return colors
|
||||||
|
|
||||||
|
|
||||||
def overlay_mask(image, predictions):
|
def overlay_mask(image, predictions):
|
||||||
image = np.asarray(image)
|
image = np.asarray(image)
|
||||||
masks = predictions.get_field("mask").numpy()
|
masks = predictions.get_field("mask").numpy()
|
||||||
|
@ -231,17 +240,92 @@ def overlay_mask(image, predictions):
|
||||||
|
|
||||||
return composite
|
return composite
|
||||||
|
|
||||||
|
|
||||||
CATEGORIES = [
|
CATEGORIES = [
|
||||||
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
"__background",
|
||||||
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
|
"person",
|
||||||
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
|
"bicycle",
|
||||||
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
|
"car",
|
||||||
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
|
"motorcycle",
|
||||||
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
|
"airplane",
|
||||||
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
|
"bus",
|
||||||
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
|
"train",
|
||||||
|
"truck",
|
||||||
|
"boat",
|
||||||
|
"traffic light",
|
||||||
|
"fire hydrant",
|
||||||
|
"stop sign",
|
||||||
|
"parking meter",
|
||||||
|
"bench",
|
||||||
|
"bird",
|
||||||
|
"cat",
|
||||||
|
"dog",
|
||||||
|
"horse",
|
||||||
|
"sheep",
|
||||||
|
"cow",
|
||||||
|
"elephant",
|
||||||
|
"bear",
|
||||||
|
"zebra",
|
||||||
|
"giraffe",
|
||||||
|
"backpack",
|
||||||
|
"umbrella",
|
||||||
|
"handbag",
|
||||||
|
"tie",
|
||||||
|
"suitcase",
|
||||||
|
"frisbee",
|
||||||
|
"skis",
|
||||||
|
"snowboard",
|
||||||
|
"sports ball",
|
||||||
|
"kite",
|
||||||
|
"baseball bat",
|
||||||
|
"baseball glove",
|
||||||
|
"skateboard",
|
||||||
|
"surfboard",
|
||||||
|
"tennis racket",
|
||||||
|
"bottle",
|
||||||
|
"wine glass",
|
||||||
|
"cup",
|
||||||
|
"fork",
|
||||||
|
"knife",
|
||||||
|
"spoon",
|
||||||
|
"bowl",
|
||||||
|
"banana",
|
||||||
|
"apple",
|
||||||
|
"sandwich",
|
||||||
|
"orange",
|
||||||
|
"broccoli",
|
||||||
|
"carrot",
|
||||||
|
"hot dog",
|
||||||
|
"pizza",
|
||||||
|
"donut",
|
||||||
|
"cake",
|
||||||
|
"chair",
|
||||||
|
"couch",
|
||||||
|
"potted plant",
|
||||||
|
"bed",
|
||||||
|
"dining table",
|
||||||
|
"toilet",
|
||||||
|
"tv",
|
||||||
|
"laptop",
|
||||||
|
"mouse",
|
||||||
|
"remote",
|
||||||
|
"keyboard",
|
||||||
|
"cell phone",
|
||||||
|
"microwave",
|
||||||
|
"oven",
|
||||||
|
"toaster",
|
||||||
|
"sink",
|
||||||
|
"refrigerator",
|
||||||
|
"book",
|
||||||
|
"clock",
|
||||||
|
"vase",
|
||||||
|
"scissors",
|
||||||
|
"teddy bear",
|
||||||
|
"hair drier",
|
||||||
|
"toothbrush",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def overlay_boxes(image, predictions):
|
def overlay_boxes(image, predictions):
|
||||||
labels = predictions.get_field("labels").numpy()
|
labels = predictions.get_field("labels").numpy()
|
||||||
boxes = predictions.bbox
|
boxes = predictions.bbox
|
||||||
|
@ -258,6 +342,7 @@ def overlay_boxes(image, predictions):
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def overlay_class_names(image, predictions):
|
def overlay_class_names(image, predictions):
|
||||||
scores = predictions.get_field("scores").numpy().tolist()
|
scores = predictions.get_field("scores").numpy().tolist()
|
||||||
labels = predictions.get_field("labels").numpy().tolist()
|
labels = predictions.get_field("labels").numpy().tolist()
|
||||||
|
@ -269,26 +354,35 @@ def overlay_class_names(image, predictions):
|
||||||
x, y = box[:2]
|
x, y = box[:2]
|
||||||
s = template.format(label, score)
|
s = template.format(label, score)
|
||||||
x, y = int(x), int(y)
|
x, y = int(x), int(y)
|
||||||
cv2.putText(
|
cv2.putText(image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||||
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
|
|
||||||
)
|
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--image', type=str, help="Path of the image to run")
|
description="Run MaskRCNN",
|
||||||
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
|
)
|
||||||
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
|
parser.add_argument("--image", type=str, help="Path of the image to run")
|
||||||
|
parser.add_argument(
|
||||||
|
"--threshold", type=float, default=0.7, help="Detector threshold"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--size_scale", type=float, default=1.0, help="Image resize multiplier"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out", type=str, default="/tmp/rendered.png", help="Output filename"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||||
model_tiny = MaskRCNN(resnet)
|
model_tiny = MaskRCNN(resnet)
|
||||||
model_tiny.load_from_pretrained()
|
model_tiny.load_from_pretrained()
|
||||||
img = Image.open(args.image)
|
img = Image.open(args.image)
|
||||||
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
|
top_result_tiny = compute_prediction(
|
||||||
|
img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale
|
||||||
|
)
|
||||||
bbox_image = overlay_boxes(img, top_result_tiny)
|
bbox_image = overlay_boxes(img, top_result_tiny)
|
||||||
mask_image = overlay_mask(bbox_image, top_result_tiny)
|
mask_image = overlay_mask(bbox_image, top_result_tiny)
|
||||||
final_image = overlay_class_names(mask_image, top_result_tiny)
|
final_image = overlay_class_names(mask_image, top_result_tiny)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import unicodedata
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
|
|
||||||
|
|
||||||
def gaussian_kernel(n, std):
|
def gaussian_kernel(n, std):
|
||||||
gaussian_1d = signal.gaussian(n, std)
|
gaussian_1d = signal.gaussian(n, std)
|
||||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||||
|
@ -12,14 +13,18 @@ def gaussian_kernel(n, std):
|
||||||
gaussian_3d /= gaussian_3d.max()
|
gaussian_3d /= gaussian_3d.max()
|
||||||
return gaussian_3d
|
return gaussian_3d
|
||||||
|
|
||||||
|
|
||||||
def prepare_arrays(image, roi_shape=(128, 128, 128)):
|
def prepare_arrays(image, roi_shape=(128, 128, 128)):
|
||||||
assert len(roi_shape) == 3 and any(roi_shape)
|
assert len(roi_shape) == 3 and any(roi_shape)
|
||||||
image_shape = list(image.shape[2:])
|
image_shape = list(image.shape[2:])
|
||||||
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
|
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
|
||||||
norm_map = np.zeros_like(result)
|
norm_map = np.zeros_like(result)
|
||||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
|
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(
|
||||||
|
norm_map.dtype
|
||||||
|
)
|
||||||
return result, norm_map, norm_patch
|
return result, norm_map, norm_patch
|
||||||
|
|
||||||
|
|
||||||
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
||||||
assert len(roi_shape) == 3 and any(roi_shape)
|
assert len(roi_shape) == 3 and any(roi_shape)
|
||||||
assert 0 < overlap_factor < 1
|
assert 0 < overlap_factor < 1
|
||||||
|
@ -31,25 +36,35 @@ def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
||||||
for k in range(0, strides[2] * size[2], strides[2]):
|
for k in range(0, strides[2] * size[2], strides[2]):
|
||||||
yield i, j, k
|
yield i, j, k
|
||||||
|
|
||||||
|
|
||||||
def _get_best_indices(logits, n_best_size):
|
def _get_best_indices(logits, n_best_size):
|
||||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||||
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
|
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
|
||||||
|
|
||||||
|
|
||||||
def _is_punctuation(char):
|
def _is_punctuation(char):
|
||||||
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
|
if (
|
||||||
|
(cp := ord(char)) in range(33, 48)
|
||||||
|
or cp in range(58, 65)
|
||||||
|
or cp in range(91, 97)
|
||||||
|
or cp in range(123, 127)
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
return unicodedata.category(char).startswith("P")
|
return unicodedata.category(char).startswith("P")
|
||||||
|
|
||||||
|
|
||||||
def _is_whitespace(char):
|
def _is_whitespace(char):
|
||||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||||
return True
|
return True
|
||||||
return unicodedata.category(char) == "Zs"
|
return unicodedata.category(char) == "Zs"
|
||||||
|
|
||||||
|
|
||||||
def _is_control(char):
|
def _is_control(char):
|
||||||
if char == "\t" or char == "\n" or char == "\r":
|
if char == "\t" or char == "\n" or char == "\r":
|
||||||
return False
|
return False
|
||||||
return unicodedata.category(char).startswith("C")
|
return unicodedata.category(char).startswith("C")
|
||||||
|
|
||||||
|
|
||||||
def _run_split_on_punc(text):
|
def _run_split_on_punc(text):
|
||||||
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||||
return [text]
|
return [text]
|
||||||
|
@ -66,6 +81,7 @@ def _run_split_on_punc(text):
|
||||||
output[-1].append(char)
|
output[-1].append(char)
|
||||||
return ["".join(x) for x in output]
|
return ["".join(x) for x in output]
|
||||||
|
|
||||||
|
|
||||||
def _run_strip_accents(text):
|
def _run_strip_accents(text):
|
||||||
output = []
|
output = []
|
||||||
for char in unicodedata.normalize("NFD", text):
|
for char in unicodedata.normalize("NFD", text):
|
||||||
|
@ -73,13 +89,15 @@ def _run_strip_accents(text):
|
||||||
output.append(char)
|
output.append(char)
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def _clean_text(text):
|
def _clean_text(text):
|
||||||
output = []
|
output = []
|
||||||
for char in text:
|
for char in text:
|
||||||
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
|
if not ((cp := ord(char)) == 0 or cp == 0xFFFD or _is_control(char)):
|
||||||
output.append(" " if _is_whitespace(char) else char)
|
output.append(" " if _is_whitespace(char) else char)
|
||||||
return "".join(output)
|
return "".join(output)
|
||||||
|
|
||||||
|
|
||||||
def _get_final_text(pred_text, orig_text):
|
def _get_final_text(pred_text, orig_text):
|
||||||
def _strip_spaces(text):
|
def _strip_spaces(text):
|
||||||
ns_text = ""
|
ns_text = ""
|
||||||
|
@ -128,32 +146,46 @@ def _get_final_text(pred_text, orig_text):
|
||||||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||||
return output_text
|
return output_text
|
||||||
|
|
||||||
|
|
||||||
def get_bert_qa_prediction(features, example, start_end_logits):
|
def get_bert_qa_prediction(features, example, start_end_logits):
|
||||||
prelim_predictions = []
|
prelim_predictions = []
|
||||||
for i, feature in enumerate(features):
|
for i, feature in enumerate(features):
|
||||||
for start_index in _get_best_indices(start_end_logits[i][0], 20):
|
for start_index in _get_best_indices(start_end_logits[i][0], 20):
|
||||||
for end_index in _get_best_indices(start_end_logits[i][1], 20):
|
for end_index in _get_best_indices(start_end_logits[i][1], 20):
|
||||||
if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
|
if start_index >= len(feature["tokens"]) or end_index >= len(
|
||||||
|
feature["tokens"]
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
|
if (
|
||||||
|
start_index not in feature["token_to_orig_map"]
|
||||||
|
or end_index not in feature["token_to_orig_map"]
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not feature["token_is_max_context"].get(start_index, False):
|
if not feature["token_is_max_context"].get(start_index, False):
|
||||||
continue
|
continue
|
||||||
if end_index < start_index or end_index - start_index + 1 > 30:
|
if end_index < start_index or end_index - start_index + 1 > 30:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prelim_predictions.append({
|
prelim_predictions.append(
|
||||||
|
{
|
||||||
"feature_index": i,
|
"feature_index": i,
|
||||||
"start_index": start_index,
|
"start_index": start_index,
|
||||||
"end_index": end_index,
|
"end_index": end_index,
|
||||||
"start_logit": start_end_logits[i][0, start_index],
|
"start_logit": start_end_logits[i][0, start_index],
|
||||||
"end_logit": start_end_logits[i][1, end_index]
|
"end_logit": start_end_logits[i][1, end_index],
|
||||||
})
|
}
|
||||||
predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
|
)
|
||||||
|
predictions = sorted(
|
||||||
|
prelim_predictions,
|
||||||
|
key=lambda x: (x["start_logit"] + x["end_logit"]),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
|
||||||
if len(predictions) > 0:
|
if len(predictions) > 0:
|
||||||
feature = features[predictions[0]["feature_index"]]
|
feature = features[predictions[0]["feature_index"]]
|
||||||
tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
|
tok_tokens = feature["tokens"][
|
||||||
|
predictions[0]["start_index"] : (predictions[0]["end_index"] + 1)
|
||||||
|
]
|
||||||
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
|
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
|
||||||
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
|
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
|
||||||
orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)]
|
orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)]
|
||||||
|
|
|
@ -3,6 +3,7 @@ import string
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def levenshtein(a, b):
|
def levenshtein(a, b):
|
||||||
n, m = len(a), len(b)
|
n, m = len(a), len(b)
|
||||||
if n > m:
|
if n > m:
|
||||||
|
@ -20,6 +21,7 @@ def levenshtein(a, b):
|
||||||
|
|
||||||
return current[n]
|
return current[n]
|
||||||
|
|
||||||
|
|
||||||
def word_error_rate(x, y):
|
def word_error_rate(x, y):
|
||||||
scores = words = 0
|
scores = words = 0
|
||||||
for h, r in zip(x, y):
|
for h, r in zip(x, y):
|
||||||
|
@ -29,12 +31,14 @@ def word_error_rate(x, y):
|
||||||
scores += levenshtein(h_list, r_list)
|
scores += levenshtein(h_list, r_list)
|
||||||
return float(scores) / words, float(scores), words
|
return float(scores) / words, float(scores), words
|
||||||
|
|
||||||
|
|
||||||
def one_hot(arr, num_classes=3):
|
def one_hot(arr, num_classes=3):
|
||||||
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
|
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
|
||||||
arr = res.reshape(list(arr.shape) + [num_classes])
|
arr = res.reshape(list(arr.shape) + [num_classes])
|
||||||
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
|
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
|
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
|
||||||
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
||||||
prediction = prediction.argmax(axis=channel_axis)
|
prediction = prediction.argmax(axis=channel_axis)
|
||||||
|
@ -42,14 +46,18 @@ def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr
|
||||||
intersection = np.sum(prediction * target, axis=reduce_axis)
|
intersection = np.sum(prediction * target, axis=reduce_axis)
|
||||||
target_sum = np.sum(target, axis=reduce_axis)
|
target_sum = np.sum(target, axis=reduce_axis)
|
||||||
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
||||||
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
|
result = (2.0 * intersection + smooth_nr) / (
|
||||||
|
target_sum + prediction_sum + smooth_dr
|
||||||
|
)
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
def normalize_string(s):
|
def normalize_string(s):
|
||||||
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
||||||
s = re.sub(r'\b(a|an|the)\b', ' ', s)
|
s = re.sub(r"\b(a|an|the)\b", " ", s)
|
||||||
return " ".join(s.split())
|
return " ".join(s.split())
|
||||||
|
|
||||||
|
|
||||||
def f1_score(x, y):
|
def f1_score(x, y):
|
||||||
xt = normalize_string(x).split()
|
xt = normalize_string(x).split()
|
||||||
yt = normalize_string(y).split()
|
yt = normalize_string(y).split()
|
||||||
|
|
|
@ -6,15 +6,18 @@ from tinygrad.jit import TinyJit
|
||||||
from tinygrad.helpers import getenv, dtypes, GlobalCounters
|
from tinygrad.helpers import getenv, dtypes, GlobalCounters
|
||||||
from examples.mlperf import helpers
|
from examples.mlperf import helpers
|
||||||
|
|
||||||
|
|
||||||
def eval_resnet():
|
def eval_resnet():
|
||||||
# Resnet50-v1.5
|
# Resnet50-v1.5
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
from extra.models.resnet import ResNet50
|
from extra.models.resnet import ResNet50
|
||||||
|
|
||||||
mdl = ResNet50()
|
mdl = ResNet50()
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
||||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||||
|
|
||||||
def input_fixup(x):
|
def input_fixup(x):
|
||||||
x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0
|
x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0
|
||||||
x -= input_mean
|
x -= input_mean
|
||||||
|
@ -48,14 +51,18 @@ def eval_resnet():
|
||||||
et = time.perf_counter()
|
et = time.perf_counter()
|
||||||
n += (t == y).sum()
|
n += (t == y).sum()
|
||||||
d += len(t)
|
d += len(t)
|
||||||
print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS")
|
print(
|
||||||
|
f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS"
|
||||||
|
)
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
def eval_unet3d():
|
def eval_unet3d():
|
||||||
# UNet3D
|
# UNet3D
|
||||||
from extra.models.unet3d import UNet3D
|
from extra.models.unet3d import UNet3D
|
||||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||||
from examples.mlperf.metrics import get_dice_score
|
from examples.mlperf.metrics import get_dice_score
|
||||||
|
|
||||||
mdl = UNet3D()
|
mdl = UNet3D()
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
s = 0
|
s = 0
|
||||||
|
@ -69,15 +76,18 @@ def eval_unet3d():
|
||||||
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
def eval_retinanet():
|
def eval_retinanet():
|
||||||
# RetinaNet with ResNeXt50_32X4D
|
# RetinaNet with ResNeXt50_32X4D
|
||||||
from extra.models.resnet import ResNeXt50_32X4D
|
from extra.models.resnet import ResNeXt50_32X4D
|
||||||
from extra.models.retinanet import RetinaNet
|
from extra.models.retinanet import RetinaNet
|
||||||
|
|
||||||
mdl = RetinaNet(ResNeXt50_32X4D())
|
mdl = RetinaNet(ResNeXt50_32X4D())
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
||||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||||
|
|
||||||
def input_fixup(x):
|
def input_fixup(x):
|
||||||
x = x.permute([0, 3, 1, 2]) / 255.0
|
x = x.permute([0, 3, 1, 2]) / 255.0
|
||||||
x -= input_mean
|
x -= input_mean
|
||||||
|
@ -88,11 +98,18 @@ def eval_retinanet():
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
from pycocotools.cocoeval import COCOeval
|
from pycocotools.cocoeval import COCOeval
|
||||||
from contextlib import redirect_stdout
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
coco = COCO(openimages())
|
coco = COCO(openimages())
|
||||||
coco_eval = COCOeval(coco, iouType="bbox")
|
coco_eval = COCOeval(coco, iouType="bbox")
|
||||||
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
|
coco_evalimgs, evaluated_imgs, ncats, narea = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
len(coco_eval.params.catIds),
|
||||||
|
len(coco_eval.params.areaRng),
|
||||||
|
)
|
||||||
|
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
|
|
||||||
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
||||||
|
|
||||||
n, bs = 0, 8
|
n, bs = 0, 8
|
||||||
|
@ -106,19 +123,35 @@ def eval_retinanet():
|
||||||
mdlrun.jit_cache = None
|
mdlrun.jit_cache = None
|
||||||
outs = mdl(input_fixup(dat)).numpy()
|
outs = mdl(input_fixup(dat)).numpy()
|
||||||
et = time.perf_counter()
|
et = time.perf_counter()
|
||||||
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
|
predictions = mdl.postprocess_detections(
|
||||||
|
outs,
|
||||||
|
input_size=dat.shape[1:3],
|
||||||
|
orig_image_sizes=[t["image_size"] for t in targets],
|
||||||
|
)
|
||||||
ext = time.perf_counter()
|
ext = time.perf_counter()
|
||||||
n += len(targets)
|
n += len(targets)
|
||||||
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
|
print(
|
||||||
|
f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing"
|
||||||
|
)
|
||||||
img_ids = [t["image_id"] for t in targets]
|
img_ids = [t["image_id"] for t in targets]
|
||||||
coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
|
coco_results = [
|
||||||
for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
|
{
|
||||||
|
"image_id": targets[i]["image_id"],
|
||||||
|
"category_id": label,
|
||||||
|
"bbox": box,
|
||||||
|
"score": score,
|
||||||
|
}
|
||||||
|
for i, prediction in enumerate(predictions)
|
||||||
|
for box, score, label in zip(*prediction.values())
|
||||||
|
]
|
||||||
with redirect_stdout(None):
|
with redirect_stdout(None):
|
||||||
coco_eval.cocoDt = coco.loadRes(coco_results)
|
coco_eval.cocoDt = coco.loadRes(coco_results)
|
||||||
coco_eval.params.imgIds = img_ids
|
coco_eval.params.imgIds = img_ids
|
||||||
coco_eval.evaluate()
|
coco_eval.evaluate()
|
||||||
evaluated_imgs.extend(img_ids)
|
evaluated_imgs.extend(img_ids)
|
||||||
coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
|
coco_evalimgs.append(
|
||||||
|
np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))
|
||||||
|
)
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
coco_eval.params.imgIds = evaluated_imgs
|
coco_eval.params.imgIds = evaluated_imgs
|
||||||
|
@ -127,16 +160,47 @@ def eval_retinanet():
|
||||||
coco_eval.accumulate()
|
coco_eval.accumulate()
|
||||||
coco_eval.summarize()
|
coco_eval.summarize()
|
||||||
|
|
||||||
|
|
||||||
def eval_rnnt():
|
def eval_rnnt():
|
||||||
# RNN-T
|
# RNN-T
|
||||||
from extra.models.rnnt import RNNT
|
from extra.models.rnnt import RNNT
|
||||||
|
|
||||||
mdl = RNNT()
|
mdl = RNNT()
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
||||||
from extra.datasets.librispeech import iterate
|
from extra.datasets.librispeech import iterate
|
||||||
from examples.mlperf.metrics import word_error_rate
|
from examples.mlperf.metrics import word_error_rate
|
||||||
|
|
||||||
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
|
LABELS = [
|
||||||
|
" ",
|
||||||
|
"a",
|
||||||
|
"b",
|
||||||
|
"c",
|
||||||
|
"d",
|
||||||
|
"e",
|
||||||
|
"f",
|
||||||
|
"g",
|
||||||
|
"h",
|
||||||
|
"i",
|
||||||
|
"j",
|
||||||
|
"k",
|
||||||
|
"l",
|
||||||
|
"m",
|
||||||
|
"n",
|
||||||
|
"o",
|
||||||
|
"p",
|
||||||
|
"q",
|
||||||
|
"r",
|
||||||
|
"s",
|
||||||
|
"t",
|
||||||
|
"u",
|
||||||
|
"v",
|
||||||
|
"w",
|
||||||
|
"x",
|
||||||
|
"y",
|
||||||
|
"z",
|
||||||
|
"'",
|
||||||
|
]
|
||||||
|
|
||||||
c = 0
|
c = 0
|
||||||
scores = 0
|
scores = 0
|
||||||
|
@ -149,16 +213,20 @@ def eval_rnnt():
|
||||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||||
for n, t in enumerate(tt):
|
for n, t in enumerate(tt):
|
||||||
tnp = np.array(t)
|
tnp = np.array(t)
|
||||||
_, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
|
_, scores_, words_ = word_error_rate(
|
||||||
|
["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]]
|
||||||
|
)
|
||||||
scores += scores_
|
scores += scores_
|
||||||
words += words_
|
words += words_
|
||||||
c += len(tt)
|
c += len(tt)
|
||||||
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
|
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
def eval_bert():
|
def eval_bert():
|
||||||
# Bert-QA
|
# Bert-QA
|
||||||
from extra.models.bert import BertForQuestionAnswering
|
from extra.models.bert import BertForQuestionAnswering
|
||||||
|
|
||||||
mdl = BertForQuestionAnswering()
|
mdl = BertForQuestionAnswering()
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
||||||
|
@ -180,9 +248,17 @@ def eval_bert():
|
||||||
mt = time.perf_counter()
|
mt = time.perf_counter()
|
||||||
outs = []
|
outs = []
|
||||||
for x in X:
|
for x in X:
|
||||||
outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
|
outs.append(
|
||||||
|
run(
|
||||||
|
Tensor(x["input_ids"]),
|
||||||
|
Tensor(x["input_mask"]),
|
||||||
|
Tensor(x["segment_ids"]),
|
||||||
|
).numpy()
|
||||||
|
)
|
||||||
et = time.perf_counter()
|
et = time.perf_counter()
|
||||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
|
print(
|
||||||
|
f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features"
|
||||||
|
)
|
||||||
|
|
||||||
pred = get_bert_qa_prediction(X, Y, outs)
|
pred = get_bert_qa_prediction(X, Y, outs)
|
||||||
print(f"pred: {pred}\nans: {Y['answers']}")
|
print(f"pred: {pred}\nans: {Y['answers']}")
|
||||||
|
@ -192,17 +268,27 @@ def eval_bert():
|
||||||
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
def eval_mrcnn():
|
def eval_mrcnn():
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from extra.models.mask_rcnn import MaskRCNN
|
from extra.models.mask_rcnn import MaskRCNN
|
||||||
from extra.models.resnet import ResNet
|
from extra.models.resnet import ResNet
|
||||||
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
from extra.datasets.coco import (
|
||||||
|
BASEDIR,
|
||||||
|
images,
|
||||||
|
convert_prediction_to_coco_bbox,
|
||||||
|
convert_prediction_to_coco_mask,
|
||||||
|
accumulate_predictions_for_coco,
|
||||||
|
evaluate_predictions_on_coco,
|
||||||
|
iterate,
|
||||||
|
)
|
||||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||||
|
|
||||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
||||||
bbox_output = '/tmp/results_bbox.json'
|
bbox_output = "/tmp/results_bbox.json"
|
||||||
mask_output = '/tmp/results_mask.json'
|
mask_output = "/tmp/results_mask.json"
|
||||||
|
|
||||||
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
||||||
accumulate_predictions_for_coco([], mask_output, rm=True)
|
accumulate_predictions_for_coco([], mask_output, rm=True)
|
||||||
|
@ -213,12 +299,12 @@ def eval_mrcnn():
|
||||||
for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs):
|
for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs):
|
||||||
batch_imgs = []
|
batch_imgs = []
|
||||||
for image_row in batch:
|
for image_row in batch:
|
||||||
image_name = image_row['file_name']
|
image_name = image_row["file_name"]
|
||||||
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
|
img = Image.open(BASEDIR / f"val2017/{image_name}").convert("RGB")
|
||||||
batch_imgs.append(img)
|
batch_imgs.append(img)
|
||||||
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
||||||
for image_row, result in zip(batch, batch_result):
|
for image_row, result in zip(batch, batch_result):
|
||||||
image_name = image_row['file_name']
|
image_name = image_row["file_name"]
|
||||||
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
||||||
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
||||||
accumulate_predictions_for_coco(box_pred, bbox_output)
|
accumulate_predictions_for_coco(box_pred, bbox_output)
|
||||||
|
@ -226,8 +312,9 @@ def eval_mrcnn():
|
||||||
del batch_imgs
|
del batch_imgs
|
||||||
del batch_result
|
del batch_result
|
||||||
|
|
||||||
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
|
evaluate_predictions_on_coco(bbox_output, iou_type="bbox")
|
||||||
evaluate_predictions_on_coco(mask_output, iou_type='segm')
|
evaluate_predictions_on_coco(mask_output, iou_type="segm")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# inference only
|
# inference only
|
||||||
|
|
|
@ -3,46 +3,60 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import GlobalCounters, getenv
|
from tinygrad.helpers import GlobalCounters, getenv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def test_model(model, *inputs):
|
def test_model(model, *inputs):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
out = model(*inputs)
|
out = model(*inputs)
|
||||||
if isinstance(out, Tensor): out = out.numpy()
|
if isinstance(out, Tensor):
|
||||||
|
out = out.numpy()
|
||||||
# TODO: return event future to still get the time_sum_s without DEBUG=2
|
# TODO: return event future to still get the time_sum_s without DEBUG=2
|
||||||
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
|
print(
|
||||||
|
f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def spec_resnet():
|
def spec_resnet():
|
||||||
# Resnet50-v1.5
|
# Resnet50-v1.5
|
||||||
from extra.models.resnet import ResNet50
|
from extra.models.resnet import ResNet50
|
||||||
|
|
||||||
mdl = ResNet50()
|
mdl = ResNet50()
|
||||||
img = Tensor.randn(1, 3, 224, 224)
|
img = Tensor.randn(1, 3, 224, 224)
|
||||||
test_model(mdl, img)
|
test_model(mdl, img)
|
||||||
|
|
||||||
|
|
||||||
def spec_retinanet():
|
def spec_retinanet():
|
||||||
# Retinanet with ResNet backbone
|
# Retinanet with ResNet backbone
|
||||||
from extra.models.resnet import ResNet50
|
from extra.models.resnet import ResNet50
|
||||||
from extra.models.retinanet import RetinaNet
|
from extra.models.retinanet import RetinaNet
|
||||||
|
|
||||||
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
|
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
|
||||||
img = Tensor.randn(1, 3, 224, 224)
|
img = Tensor.randn(1, 3, 224, 224)
|
||||||
test_model(mdl, img)
|
test_model(mdl, img)
|
||||||
|
|
||||||
|
|
||||||
def spec_unet3d():
|
def spec_unet3d():
|
||||||
# 3D UNET
|
# 3D UNET
|
||||||
from extra.models.unet3d import UNet3D
|
from extra.models.unet3d import UNet3D
|
||||||
|
|
||||||
mdl = UNet3D()
|
mdl = UNet3D()
|
||||||
# mdl.load_from_pretrained()
|
# mdl.load_from_pretrained()
|
||||||
img = Tensor.randn(1, 1, 128, 128, 128)
|
img = Tensor.randn(1, 1, 128, 128, 128)
|
||||||
test_model(mdl, img)
|
test_model(mdl, img)
|
||||||
|
|
||||||
|
|
||||||
def spec_rnnt():
|
def spec_rnnt():
|
||||||
from extra.models.rnnt import RNNT
|
from extra.models.rnnt import RNNT
|
||||||
|
|
||||||
mdl = RNNT()
|
mdl = RNNT()
|
||||||
# mdl.load_from_pretrained()
|
# mdl.load_from_pretrained()
|
||||||
x = Tensor.randn(220, 1, 240)
|
x = Tensor.randn(220, 1, 240)
|
||||||
y = Tensor.randn(1, 220)
|
y = Tensor.randn(1, 220)
|
||||||
test_model(mdl, x, y)
|
test_model(mdl, x, y)
|
||||||
|
|
||||||
|
|
||||||
def spec_bert():
|
def spec_bert():
|
||||||
from extra.models.bert import BertForQuestionAnswering
|
from extra.models.bert import BertForQuestionAnswering
|
||||||
|
|
||||||
mdl = BertForQuestionAnswering()
|
mdl = BertForQuestionAnswering()
|
||||||
# mdl.load_from_pretrained()
|
# mdl.load_from_pretrained()
|
||||||
x = Tensor.randn(1, 384)
|
x = Tensor.randn(1, 384)
|
||||||
|
@ -50,13 +64,16 @@ def spec_bert():
|
||||||
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
||||||
test_model(mdl, x, am, tt)
|
test_model(mdl, x, am, tt)
|
||||||
|
|
||||||
|
|
||||||
def spec_mrcnn():
|
def spec_mrcnn():
|
||||||
from extra.models.mask_rcnn import MaskRCNN, ResNet
|
from extra.models.mask_rcnn import MaskRCNN, ResNet
|
||||||
|
|
||||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||||
# mdl.load_from_pretrained()
|
# mdl.load_from_pretrained()
|
||||||
x = Tensor.randn(3, 224, 224)
|
x = Tensor.randn(3, 224, 224)
|
||||||
test_model(mdl, [x])
|
test_model(mdl, [x])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# inference only for now
|
# inference only for now
|
||||||
Tensor.training = False
|
Tensor.training = False
|
||||||
|
@ -67,4 +84,3 @@ if __name__ == "__main__":
|
||||||
if nm in globals():
|
if nm in globals():
|
||||||
print(f"testing {m}")
|
print(f"testing {m}")
|
||||||
globals()[nm]()
|
globals()[nm]()
|
||||||
|
|
||||||
|
|
|
@ -1,36 +1,43 @@
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
|
|
||||||
|
|
||||||
def train_resnet():
|
def train_resnet():
|
||||||
# TODO: Resnet50-v1.5
|
# TODO: Resnet50-v1.5
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def train_retinanet():
|
def train_retinanet():
|
||||||
# TODO: Retinanet
|
# TODO: Retinanet
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def train_unet3d():
|
def train_unet3d():
|
||||||
# TODO: Unet3d
|
# TODO: Unet3d
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def train_rnnt():
|
def train_rnnt():
|
||||||
# TODO: RNN-T
|
# TODO: RNN-T
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def train_bert():
|
def train_bert():
|
||||||
# TODO: BERT
|
# TODO: BERT
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def train_maskrcnn():
|
def train_maskrcnn():
|
||||||
# TODO: Mask RCNN
|
# TODO: Mask RCNN
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
with Tensor.train():
|
with Tensor.train():
|
||||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
|
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(
|
||||||
|
","
|
||||||
|
):
|
||||||
nm = f"train_{m}"
|
nm = f"train_{m}"
|
||||||
if nm in globals():
|
if nm in globals():
|
||||||
print(f"training {m}")
|
print(f"training {m}")
|
||||||
globals()[nm]()
|
globals()[nm]()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from tinygrad.helpers import getenv
|
||||||
from tinygrad.nn import optim
|
from tinygrad.nn import optim
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
|
|
||||||
|
|
||||||
class LinearGen:
|
class LinearGen:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.l1 = Tensor.scaled_uniform(128, 256)
|
self.l1 = Tensor.scaled_uniform(128, 256)
|
||||||
|
@ -23,6 +24,7 @@ class LinearGen:
|
||||||
x = x.dot(self.l4).tanh()
|
x = x.dot(self.l4).tanh()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class LinearDisc:
|
class LinearDisc:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.l1 = Tensor.scaled_uniform(784, 1024)
|
self.l1 = Tensor.scaled_uniform(784, 1024)
|
||||||
|
@ -38,16 +40,21 @@ class LinearDisc:
|
||||||
x = x.dot(self.l4).log_softmax()
|
x = x.dot(self.l4).log_softmax()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def make_batch(images):
|
def make_batch(images):
|
||||||
sample = np.random.randint(0, len(images), size=(batch_size))
|
sample = np.random.randint(0, len(images), size=(batch_size))
|
||||||
image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0
|
image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0
|
||||||
return Tensor(image_b)
|
return Tensor(image_b)
|
||||||
|
|
||||||
|
|
||||||
def make_labels(bs, col, val=-2.0):
|
def make_labels(bs, col, val=-2.0):
|
||||||
y = np.zeros((bs, 2), np.float32)
|
y = np.zeros((bs, 2), np.float32)
|
||||||
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
y[
|
||||||
|
range(bs), [col] * bs
|
||||||
|
] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||||
return Tensor(y)
|
return Tensor(y)
|
||||||
|
|
||||||
|
|
||||||
def train_discriminator(optimizer, data_real, data_fake):
|
def train_discriminator(optimizer, data_real, data_fake):
|
||||||
real_labels = make_labels(batch_size, 1)
|
real_labels = make_labels(batch_size, 1)
|
||||||
fake_labels = make_labels(batch_size, 0)
|
fake_labels = make_labels(batch_size, 0)
|
||||||
|
@ -61,6 +68,7 @@ def train_discriminator(optimizer, data_real, data_fake):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
return (loss_real + loss_fake).numpy()
|
return (loss_real + loss_fake).numpy()
|
||||||
|
|
||||||
|
|
||||||
def train_generator(optimizer, data_fake):
|
def train_generator(optimizer, data_fake):
|
||||||
real_labels = make_labels(batch_size, 1)
|
real_labels = make_labels(batch_size, 1)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
@ -70,6 +78,7 @@ def train_generator(optimizer, data_fake):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
return loss.numpy()
|
return loss.numpy()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# data for training and validation
|
# data for training and validation
|
||||||
images_real = np.vstack(fetch_mnist()[::2])
|
images_real = np.vstack(fetch_mnist()[::2])
|
||||||
|
@ -85,7 +94,9 @@ if __name__ == "__main__":
|
||||||
output_dir = Path(".").resolve() / "outputs"
|
output_dir = Path(".").resolve() / "outputs"
|
||||||
output_dir.mkdir(exist_ok=True)
|
output_dir.mkdir(exist_ok=True)
|
||||||
# optimizers
|
# optimizers
|
||||||
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
optim_g = optim.Adam(
|
||||||
|
get_parameters(generator), lr=0.0002, b1=0.5
|
||||||
|
) # 0.0002 for equilibrium!
|
||||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in (t := trange(epochs)):
|
for epoch in (t := trange(epochs)):
|
||||||
|
@ -102,6 +113,11 @@ if __name__ == "__main__":
|
||||||
if (epoch + 1) % sample_interval == 0:
|
if (epoch + 1) % sample_interval == 0:
|
||||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||||
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
|
save_image(
|
||||||
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
|
make_grid(torch.tensor(fake_images)),
|
||||||
|
output_dir / f"image_{epoch+1}.jpg",
|
||||||
|
)
|
||||||
|
t.set_description(
|
||||||
|
f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}"
|
||||||
|
)
|
||||||
print("Training Completed!")
|
print("Training Completed!")
|
||||||
|
|
|
@ -9,10 +9,12 @@ from tinygrad.helpers import getenv
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
from extra.augment import augment_img
|
from extra.augment import augment_img
|
||||||
from extra.training import train, evaluate
|
from extra.training import train, evaluate
|
||||||
|
|
||||||
GPU = getenv("GPU")
|
GPU = getenv("GPU")
|
||||||
QUICK = getenv("QUICK")
|
QUICK = getenv("QUICK")
|
||||||
DEBUG = getenv("DEBUG")
|
DEBUG = getenv("DEBUG")
|
||||||
|
|
||||||
|
|
||||||
class SqueezeExciteBlock2D:
|
class SqueezeExciteBlock2D:
|
||||||
def __init__(self, filters):
|
def __init__(self, filters):
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
@ -22,7 +24,9 @@ class SqueezeExciteBlock2D:
|
||||||
self.bias2 = Tensor.scaled_uniform(1, self.filters)
|
self.bias2 = Tensor.scaled_uniform(1, self.filters)
|
||||||
|
|
||||||
def __call__(self, input):
|
def __call__(self, input):
|
||||||
se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
|
se = input.avg_pool2d(
|
||||||
|
kernel_size=(input.shape[2], input.shape[3])
|
||||||
|
) # GlobalAveragePool2D
|
||||||
se = se.reshape(shape=(-1, self.filters))
|
se = se.reshape(shape=(-1, self.filters))
|
||||||
se = se.dot(self.weight1) + self.bias1
|
se = se.dot(self.weight1) + self.bias1
|
||||||
se = se.relu()
|
se = se.relu()
|
||||||
|
@ -31,12 +35,16 @@ class SqueezeExciteBlock2D:
|
||||||
se = input.mul(se)
|
se = input.mul(se)
|
||||||
return se
|
return se
|
||||||
|
|
||||||
|
|
||||||
class ConvBlock:
|
class ConvBlock:
|
||||||
def __init__(self, h, w, inp, filters=128, conv=3):
|
def __init__(self, h, w, inp, filters=128, conv=3):
|
||||||
self.h, self.w = h, w
|
self.h, self.w = h, w
|
||||||
self.inp = inp
|
self.inp = inp
|
||||||
# init weights
|
# init weights
|
||||||
self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
|
self.cweights = [
|
||||||
|
Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
|
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
|
||||||
# init layers
|
# init layers
|
||||||
self._bn = BatchNorm2d(128)
|
self._bn = BatchNorm2d(128)
|
||||||
|
@ -50,9 +58,14 @@ class ConvBlock:
|
||||||
x = self._seb(x)
|
x = self._seb(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BigConvNet:
|
class BigConvNet:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
|
self.conv = [
|
||||||
|
ConvBlock(28, 28, 1),
|
||||||
|
ConvBlock(28, 28, 128),
|
||||||
|
ConvBlock(14, 14, 128),
|
||||||
|
]
|
||||||
self.weight1 = Tensor.scaled_uniform(128, 10)
|
self.weight1 = Tensor.scaled_uniform(128, 10)
|
||||||
self.weight2 = Tensor.scaled_uniform(128, 10)
|
self.weight2 = Tensor.scaled_uniform(128, 10)
|
||||||
|
|
||||||
|
@ -63,19 +76,19 @@ class BigConvNet:
|
||||||
for par in pars:
|
for par in pars:
|
||||||
print(par.shape)
|
print(par.shape)
|
||||||
no_pars += np.prod(par.shape)
|
no_pars += np.prod(par.shape)
|
||||||
print('no of parameters', no_pars)
|
print("no of parameters", no_pars)
|
||||||
return pars
|
return pars
|
||||||
else:
|
else:
|
||||||
return get_parameters(self)
|
return get_parameters(self)
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
with open(filename+'.npy', 'wb') as f:
|
with open(filename + ".npy", "wb") as f:
|
||||||
for par in get_parameters(self):
|
for par in get_parameters(self):
|
||||||
# if par.requires_grad:
|
# if par.requires_grad:
|
||||||
np.save(f, par.numpy())
|
np.save(f, par.numpy())
|
||||||
|
|
||||||
def load(self, filename):
|
def load(self, filename):
|
||||||
with open(filename+'.npy', 'rb') as f:
|
with open(filename + ".npy", "rb") as f:
|
||||||
for par in get_parameters(self):
|
for par in get_parameters(self):
|
||||||
# if par.requires_grad:
|
# if par.requires_grad:
|
||||||
try:
|
try:
|
||||||
|
@ -83,7 +96,7 @@ class BigConvNet:
|
||||||
if GPU:
|
if GPU:
|
||||||
par.gpu()
|
par.gpu()
|
||||||
except:
|
except:
|
||||||
print('Could not load parameter')
|
print("Could not load parameter")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv[0](x)
|
x = self.conv[0](x)
|
||||||
|
@ -102,7 +115,10 @@ if __name__ == "__main__":
|
||||||
BS = 32
|
BS = 32
|
||||||
|
|
||||||
lmbd = 0.00025
|
lmbd = 0.00025
|
||||||
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
|
lossfn = (
|
||||||
|
lambda out, y: out.sparse_categorical_crossentropy(y)
|
||||||
|
+ lmbd * (model.weight1.abs() + model.weight2.abs()).sum()
|
||||||
|
)
|
||||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
|
@ -133,4 +149,4 @@ if __name__ == "__main__":
|
||||||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||||
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
|
model.save(f"examples/checkpoint{accuracy * 1e6:.0f}")
|
||||||
|
|
|
@ -6,14 +6,14 @@ from tinygrad.nn.state import get_parameters
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
with Tensor.train():
|
with Tensor.train():
|
||||||
|
|
||||||
BS, C1, H, W = 4, 16, 224, 224
|
BS, C1, H, W = 4, 16, 224, 224
|
||||||
C2, K, S, P = 64, 7, 2, 1
|
C2, K, S, P = 64, 7, 2, 1
|
||||||
|
|
||||||
x = Tensor.uniform(BS, C1, H, W)
|
x = Tensor.uniform(BS, C1, H, W)
|
||||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||||
for t in get_parameters([x, conv, bn]): t.realize()
|
for t in get_parameters([x, conv, bn]):
|
||||||
|
t.realize()
|
||||||
|
|
||||||
print("running network")
|
print("running network")
|
||||||
x.sequential([conv, bn]).numpy()
|
x.sequential([conv, bn]).numpy()
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -7,9 +7,17 @@ import soundfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import parselmouth
|
import parselmouth
|
||||||
|
|
||||||
|
|
||||||
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||||
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
|
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
|
||||||
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
|
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = (
|
||||||
|
hop_length,
|
||||||
|
f0_min,
|
||||||
|
f0_max,
|
||||||
|
sampling_rate,
|
||||||
|
"pm",
|
||||||
|
)
|
||||||
|
|
||||||
def interpolate_f0(self, f0):
|
def interpolate_f0(self, f0):
|
||||||
vuv_vector = np.zeros_like(f0, dtype=np.float32)
|
vuv_vector = np.zeros_like(f0, dtype=np.float32)
|
||||||
vuv_vector[f0 > 0.0] = 1.0
|
vuv_vector[f0 > 0.0] = 1.0
|
||||||
|
@ -19,79 +27,142 @@ class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||||
nzindex = nzindex.astype(np.float32)
|
nzindex = nzindex.astype(np.float32)
|
||||||
time_org = self.hop_length / self.sampling_rate * nzindex
|
time_org = self.hop_length / self.sampling_rate * nzindex
|
||||||
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
|
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
|
||||||
if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
|
if data.shape[0] <= 0:
|
||||||
if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
|
return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector
|
||||||
|
if data.shape[0] == 1:
|
||||||
|
return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector
|
||||||
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
|
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
|
||||||
return f0, vuv_vector
|
return f0, vuv_vector
|
||||||
|
|
||||||
def compute_f0(self, wav, p_len=None):
|
def compute_f0(self, wav, p_len=None):
|
||||||
x = wav
|
x = wav
|
||||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
if p_len is None:
|
||||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
p_len = x.shape[0] // self.hop_length
|
||||||
|
else:
|
||||||
|
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||||
time_step = self.hop_length / self.sampling_rate * 1000
|
time_step = self.hop_length / self.sampling_rate * 1000
|
||||||
f0 = parselmouth.Sound(x, self.sampling_rate) \
|
f0 = (
|
||||||
.to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
|
parselmouth.Sound(x, self.sampling_rate)
|
||||||
.selected_array['frequency']
|
.to_pitch_ac(
|
||||||
|
time_step=time_step / 1000,
|
||||||
|
voicing_threshold=0.6,
|
||||||
|
pitch_floor=self.f0_min,
|
||||||
|
pitch_ceiling=self.f0_max,
|
||||||
|
)
|
||||||
|
.selected_array["frequency"]
|
||||||
|
)
|
||||||
pad_size = (p_len - len(f0) + 1) // 2
|
pad_size = (p_len - len(f0) + 1) // 2
|
||||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||||
f0, uv = self.interpolate_f0(f0)
|
f0, uv = self.interpolate_f0(f0)
|
||||||
return f0
|
return f0
|
||||||
|
|
||||||
def compute_f0_uv(self, wav, p_len=None):
|
def compute_f0_uv(self, wav, p_len=None):
|
||||||
x = wav
|
x = wav
|
||||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
if p_len is None:
|
||||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
p_len = x.shape[0] // self.hop_length
|
||||||
|
else:
|
||||||
|
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||||
time_step = self.hop_length / self.sampling_rate * 1000
|
time_step = self.hop_length / self.sampling_rate * 1000
|
||||||
f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
|
f0 = (
|
||||||
time_step=time_step / 1000, voicing_threshold=0.6,
|
parselmouth.Sound(x, self.sampling_rate)
|
||||||
pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
|
.to_pitch_ac(
|
||||||
|
time_step=time_step / 1000,
|
||||||
|
voicing_threshold=0.6,
|
||||||
|
pitch_floor=self.f0_min,
|
||||||
|
pitch_ceiling=self.f0_max,
|
||||||
|
)
|
||||||
|
.selected_array["frequency"]
|
||||||
|
)
|
||||||
pad_size = (p_len - len(f0) + 1) // 2
|
pad_size = (p_len - len(f0) + 1) // 2
|
||||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||||
f0, uv = self.interpolate_f0(f0)
|
f0, uv = self.interpolate_f0(f0)
|
||||||
return f0, uv
|
return f0, uv
|
||||||
|
|
||||||
|
|
||||||
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||||
def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
|
def __init__(
|
||||||
|
self,
|
||||||
|
sr: int,
|
||||||
|
threshold: float = -40.0,
|
||||||
|
min_length: int = 5000,
|
||||||
|
min_interval: int = 300,
|
||||||
|
hop_size: int = 20,
|
||||||
|
max_sil_kept: int = 5000,
|
||||||
|
):
|
||||||
if not min_length >= min_interval >= hop_size:
|
if not min_length >= min_interval >= hop_size:
|
||||||
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
|
raise ValueError(
|
||||||
|
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
|
||||||
|
)
|
||||||
if not max_sil_kept >= hop_size:
|
if not max_sil_kept >= hop_size:
|
||||||
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
|
raise ValueError(
|
||||||
|
"The following condition must be satisfied: max_sil_kept >= hop_size"
|
||||||
|
)
|
||||||
min_interval = sr * min_interval / 1000
|
min_interval = sr * min_interval / 1000
|
||||||
self.threshold = 10 ** (threshold / 20.)
|
self.threshold = 10 ** (threshold / 20.0)
|
||||||
self.hop_size = round(sr * hop_size / 1000)
|
self.hop_size = round(sr * hop_size / 1000)
|
||||||
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
||||||
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
||||||
self.min_interval = round(min_interval / self.hop_size)
|
self.min_interval = round(min_interval / self.hop_size)
|
||||||
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
||||||
|
|
||||||
def _apply_slice(self, waveform, begin, end):
|
def _apply_slice(self, waveform, begin, end):
|
||||||
if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
|
if len(waveform.shape) > 1:
|
||||||
else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
|
return waveform[
|
||||||
|
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return waveform[
|
||||||
|
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
|
||||||
|
]
|
||||||
|
|
||||||
def slice(self, waveform):
|
def slice(self, waveform):
|
||||||
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
|
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
|
||||||
if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
if samples.shape[0] <= self.min_length:
|
||||||
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
||||||
|
rms_list = librosa.feature.rms(
|
||||||
|
y=samples, frame_length=self.win_size, hop_length=self.hop_size
|
||||||
|
).squeeze(0)
|
||||||
sil_tags, silence_start, clip_start = [], None, 0
|
sil_tags, silence_start, clip_start = [], None, 0
|
||||||
for i, rms in enumerate(rms_list):
|
for i, rms in enumerate(rms_list):
|
||||||
if rms < self.threshold: # Keep looping while frame is silent.
|
if rms < self.threshold: # Keep looping while frame is silent.
|
||||||
if silence_start is None: # Record start of silent frames.
|
if silence_start is None: # Record start of silent frames.
|
||||||
silence_start = i
|
silence_start = i
|
||||||
continue
|
continue
|
||||||
if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded.
|
if silence_start is None:
|
||||||
|
continue # Keep looping while frame is not silent and silence start has not been recorded.
|
||||||
# Clear recorded silence start if interval is not enough or clip is too short
|
# Clear recorded silence start if interval is not enough or clip is too short
|
||||||
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
||||||
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
|
need_slice_middle = (
|
||||||
|
i - silence_start >= self.min_interval
|
||||||
|
and i - clip_start >= self.min_length
|
||||||
|
)
|
||||||
if not is_leading_silence and not need_slice_middle:
|
if not is_leading_silence and not need_slice_middle:
|
||||||
silence_start = None
|
silence_start = None
|
||||||
continue
|
continue
|
||||||
if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed.
|
if (
|
||||||
|
i - silence_start <= self.max_sil_kept
|
||||||
|
): # Need slicing. Record the range of silent frames to be removed.
|
||||||
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
||||||
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
||||||
clip_start = pos
|
clip_start = pos
|
||||||
elif i - silence_start <= self.max_sil_kept * 2:
|
elif i - silence_start <= self.max_sil_kept * 2:
|
||||||
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
|
pos = rms_list[
|
||||||
|
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
|
||||||
|
].argmin()
|
||||||
pos += i - self.max_sil_kept
|
pos += i - self.max_sil_kept
|
||||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
pos_l = (
|
||||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
rms_list[
|
||||||
|
silence_start : silence_start + self.max_sil_kept + 1
|
||||||
|
].argmin()
|
||||||
|
+ silence_start
|
||||||
|
)
|
||||||
|
pos_r = (
|
||||||
|
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||||
|
+ i
|
||||||
|
- self.max_sil_kept
|
||||||
|
)
|
||||||
if silence_start == 0:
|
if silence_start == 0:
|
||||||
sil_tags.append((0, pos_r))
|
sil_tags.append((0, pos_r))
|
||||||
clip_start = pos_r
|
clip_start = pos_r
|
||||||
|
@ -99,41 +170,105 @@ class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||||
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
||||||
clip_start = max(pos_r, pos)
|
clip_start = max(pos_r, pos)
|
||||||
else:
|
else:
|
||||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
pos_l = (
|
||||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
rms_list[
|
||||||
|
silence_start : silence_start + self.max_sil_kept + 1
|
||||||
|
].argmin()
|
||||||
|
+ silence_start
|
||||||
|
)
|
||||||
|
pos_r = (
|
||||||
|
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||||
|
+ i
|
||||||
|
- self.max_sil_kept
|
||||||
|
)
|
||||||
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
|
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
|
||||||
clip_start = pos_r
|
clip_start = pos_r
|
||||||
silence_start = None
|
silence_start = None
|
||||||
total_frames = rms_list.shape[0]
|
total_frames = rms_list.shape[0]
|
||||||
if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence.
|
if (
|
||||||
|
silence_start is not None
|
||||||
|
and total_frames - silence_start >= self.min_interval
|
||||||
|
): # Deal with trailing silence.
|
||||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||||
sil_tags.append((pos, total_frames + 1))
|
sil_tags.append((pos, total_frames + 1))
|
||||||
if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices.
|
if len(sil_tags) == 0:
|
||||||
|
return {
|
||||||
|
"0": {"slice": False, "split_time": f"0,{len(waveform)}"}
|
||||||
|
} # Apply and return slices.
|
||||||
chunks = []
|
chunks = []
|
||||||
if sil_tags[0][0]:
|
if sil_tags[0][0]:
|
||||||
chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
|
chunks.append(
|
||||||
|
{
|
||||||
|
"slice": False,
|
||||||
|
"split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
for i in range(0, len(sil_tags)):
|
for i in range(0, len(sil_tags)):
|
||||||
if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
|
if i:
|
||||||
chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
|
chunks.append(
|
||||||
|
{
|
||||||
|
"slice": False,
|
||||||
|
"split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"slice": True,
|
||||||
|
"split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
if sil_tags[-1][1] * self.hop_size < len(waveform):
|
if sil_tags[-1][1] * self.hop_size < len(waveform):
|
||||||
chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
|
chunks.append(
|
||||||
|
{
|
||||||
|
"slice": False,
|
||||||
|
"split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
chunk_dict = {}
|
chunk_dict = {}
|
||||||
for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
|
for i in range(len(chunks)):
|
||||||
|
chunk_dict[str(i)] = chunks[i]
|
||||||
return chunk_dict
|
return chunk_dict
|
||||||
|
|
||||||
|
|
||||||
# sinc_interp_hann audio resampling
|
# sinc_interp_hann audio resampling
|
||||||
class Resample:
|
class Resample:
|
||||||
def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
|
def __init__(
|
||||||
self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
|
self,
|
||||||
|
orig_freq: int = 16000,
|
||||||
|
new_freq: int = 16000,
|
||||||
|
lowpass_filter_width: int = 6,
|
||||||
|
rolloff: float = 0.99,
|
||||||
|
beta: Optional[float] = None,
|
||||||
|
dtype: Optional[dtypes] = None,
|
||||||
|
):
|
||||||
|
(
|
||||||
|
self.orig_freq,
|
||||||
|
self.new_freq,
|
||||||
|
self.lowpass_filter_width,
|
||||||
|
self.rolloff,
|
||||||
|
self.beta,
|
||||||
|
) = (orig_freq, new_freq, lowpass_filter_width, rolloff, beta)
|
||||||
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
|
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
|
||||||
self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
|
self.kernel, self.width = (
|
||||||
|
self._get_sinc_resample_kernel(dtype)
|
||||||
|
if self.orig_freq != self.new_freq
|
||||||
|
else (None, None)
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, waveform: Tensor) -> Tensor:
|
def __call__(self, waveform: Tensor) -> Tensor:
|
||||||
if self.orig_freq == self.new_freq: return waveform
|
if self.orig_freq == self.new_freq:
|
||||||
|
return waveform
|
||||||
return self._apply_sinc_resample_kernel(waveform)
|
return self._apply_sinc_resample_kernel(waveform)
|
||||||
|
|
||||||
def _apply_sinc_resample_kernel(self, waveform: Tensor):
|
def _apply_sinc_resample_kernel(self, waveform: Tensor):
|
||||||
if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
|
if not waveform.is_floating_point():
|
||||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
raise TypeError(
|
||||||
|
f"Waveform tensor expected to be of type float, but received {waveform.dtype}."
|
||||||
|
)
|
||||||
|
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||||
|
int(self.new_freq) // self.gcd
|
||||||
|
)
|
||||||
shape = waveform.shape
|
shape = waveform.shape
|
||||||
waveform = waveform.reshape(-1, shape[-1]) # pack batch
|
waveform = waveform.reshape(-1, shape[-1]) # pack batch
|
||||||
num_wavs, length = waveform.shape
|
num_wavs, length = waveform.shape
|
||||||
|
@ -144,34 +279,58 @@ class Resample:
|
||||||
resampled = resampled[..., :target_length]
|
resampled = resampled[..., :target_length]
|
||||||
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
|
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
|
||||||
return resampled
|
return resampled
|
||||||
|
|
||||||
def _get_sinc_resample_kernel(self, dtype=None):
|
def _get_sinc_resample_kernel(self, dtype=None):
|
||||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||||
if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
|
int(self.new_freq) // self.gcd
|
||||||
|
)
|
||||||
|
if self.lowpass_filter_width <= 0:
|
||||||
|
raise ValueError("Low pass filter width should be positive.")
|
||||||
base_freq = min(orig_freq, new_freq)
|
base_freq = min(orig_freq, new_freq)
|
||||||
base_freq *= self.rolloff
|
base_freq *= self.rolloff
|
||||||
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
|
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
|
||||||
idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
|
idx = (
|
||||||
|
Tensor.arange(
|
||||||
|
-width,
|
||||||
|
width + orig_freq,
|
||||||
|
dtype=(dtype if dtype is not None else dtypes.float32),
|
||||||
|
)[None, None]
|
||||||
|
/ orig_freq
|
||||||
|
)
|
||||||
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
|
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
|
||||||
t *= base_freq
|
t *= base_freq
|
||||||
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
|
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
|
||||||
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
|
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
|
||||||
t *= math.pi
|
t *= math.pi
|
||||||
scale = base_freq / orig_freq
|
scale = base_freq / orig_freq
|
||||||
kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
|
kernels = Tensor.where(
|
||||||
|
t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t
|
||||||
|
)
|
||||||
kernels *= window * scale
|
kernels *= window * scale
|
||||||
if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
|
if dtype is None:
|
||||||
|
kernels = kernels.cast(dtype=dtypes.float32)
|
||||||
return kernels, width
|
return kernels, width
|
||||||
|
|
||||||
def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
|
|
||||||
|
def sinc_interp_resample(
|
||||||
|
x: Tensor,
|
||||||
|
orig_freq: int = 16000,
|
||||||
|
new_freq: int = 1600,
|
||||||
|
lowpass_filter_width: int = 6,
|
||||||
|
rolloff: float = 0.99,
|
||||||
|
beta: Optional[float] = None,
|
||||||
|
):
|
||||||
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
|
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
|
||||||
return resamp(x)
|
return resamp(x)
|
||||||
|
|
||||||
|
|
||||||
def cut(audio_path, db_thresh=-30, min_len=5000):
|
def cut(audio_path, db_thresh=-30, min_len=5000):
|
||||||
audio, sr = librosa.load(audio_path, sr=None)
|
audio, sr = librosa.load(audio_path, sr=None)
|
||||||
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
|
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
|
||||||
chunks = slicer.slice(audio)
|
chunks = slicer.slice(audio)
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def chunks2audio(audio_path, chunks):
|
def chunks2audio(audio_path, chunks):
|
||||||
chunks = dict(chunks)
|
chunks = dict(chunks)
|
||||||
audio, sr = load_audiofile(audio_path)
|
audio, sr = load_audiofile(audio_path)
|
||||||
|
@ -185,19 +344,30 @@ def chunks2audio(audio_path, chunks):
|
||||||
result.append((v["slice"], audio[int(tag[0]) : int(tag[1])]))
|
result.append((v["slice"], audio[int(tag[0]) : int(tag[1])]))
|
||||||
return result, sr
|
return result, sr
|
||||||
|
|
||||||
def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
|
|
||||||
|
def load_audiofile(
|
||||||
|
filepath: str,
|
||||||
|
frame_offset: int = 0,
|
||||||
|
num_frames: int = -1,
|
||||||
|
channels_first: bool = True,
|
||||||
|
):
|
||||||
with soundfile.SoundFile(filepath, "r") as file_:
|
with soundfile.SoundFile(filepath, "r") as file_:
|
||||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||||
waveform = file_.read(frames, "float32", always_2d=True)
|
waveform = file_.read(frames, "float32", always_2d=True)
|
||||||
sample_rate = file_.samplerate
|
sample_rate = file_.samplerate
|
||||||
waveform = Tensor(waveform)
|
waveform = Tensor(waveform)
|
||||||
if channels_first: waveform = waveform.transpose(0, 1)
|
if channels_first:
|
||||||
|
waveform = waveform.transpose(0, 1)
|
||||||
return waveform, sample_rate
|
return waveform, sample_rate
|
||||||
|
|
||||||
def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
|
|
||||||
|
def get_unit_f0(
|
||||||
|
wav: Tensor, tran, hop_length, target_sample, f0_filter=False
|
||||||
|
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
|
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
|
||||||
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
|
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
|
||||||
if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
|
if f0_filter and sum(f0) == 0:
|
||||||
|
raise RuntimeError("No voice detected")
|
||||||
f0 = Tensor(f0.astype(np.float32)).float()
|
f0 = Tensor(f0.astype(np.float32)).float()
|
||||||
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
|
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
|
||||||
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
|
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
|
||||||
|
|
|
@ -14,6 +14,7 @@ from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock:
|
class AttnBlock:
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
self.norm = GroupNorm(32, in_channels)
|
self.norm = GroupNorm(32, in_channels)
|
||||||
|
@ -30,22 +31,32 @@ class AttnBlock:
|
||||||
# compute attention
|
# compute attention
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = [x.reshape(b, c, h * w).transpose(1, 2) for x in (q, k, v)]
|
q, k, v = [x.reshape(b, c, h * w).transpose(1, 2) for x in (q, k, v)]
|
||||||
h_ = Tensor.scaled_dot_product_attention(q,k,v).transpose(1,2).reshape(b,c,h,w)
|
h_ = (
|
||||||
|
Tensor.scaled_dot_product_attention(q, k, v)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(b, c, h, w)
|
||||||
|
)
|
||||||
return x + self.proj_out(h_)
|
return x + self.proj_out(h_)
|
||||||
|
|
||||||
|
|
||||||
class ResnetBlock:
|
class ResnetBlock:
|
||||||
def __init__(self, in_channels, out_channels=None):
|
def __init__(self, in_channels, out_channels=None):
|
||||||
self.norm1 = GroupNorm(32, in_channels)
|
self.norm1 = GroupNorm(32, in_channels)
|
||||||
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
|
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
|
||||||
self.norm2 = GroupNorm(32, out_channels)
|
self.norm2 = GroupNorm(32, out_channels)
|
||||||
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
|
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
|
||||||
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
|
self.nin_shortcut = (
|
||||||
|
Conv2d(in_channels, out_channels, 1)
|
||||||
|
if in_channels != out_channels
|
||||||
|
else lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
h = self.conv1(self.norm1(x).swish())
|
h = self.conv1(self.norm1(x).swish())
|
||||||
h = self.conv2(self.norm2(h).swish())
|
h = self.conv2(self.norm2(h).swish())
|
||||||
return self.nin_shortcut(x) + h
|
return self.nin_shortcut(x) + h
|
||||||
|
|
||||||
|
|
||||||
class Mid:
|
class Mid:
|
||||||
def __init__(self, block_in):
|
def __init__(self, block_in):
|
||||||
self.block_1 = ResnetBlock(block_in, block_in)
|
self.block_1 = ResnetBlock(block_in, block_in)
|
||||||
|
@ -55,6 +66,7 @@ class Mid:
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x.sequential([self.block_1, self.attn_1, self.block_2])
|
return x.sequential([self.block_1, self.attn_1, self.block_2])
|
||||||
|
|
||||||
|
|
||||||
class Decoder:
|
class Decoder:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
|
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
|
||||||
|
@ -63,11 +75,17 @@ class Decoder:
|
||||||
|
|
||||||
arr = []
|
arr = []
|
||||||
for i, s in enumerate(sz):
|
for i, s in enumerate(sz):
|
||||||
arr.append({"block":
|
arr.append(
|
||||||
[ResnetBlock(s[1], s[0]),
|
{
|
||||||
|
"block": [
|
||||||
|
ResnetBlock(s[1], s[0]),
|
||||||
ResnetBlock(s[0], s[0]),
|
ResnetBlock(s[0], s[0]),
|
||||||
ResnetBlock(s[0], s[0])]})
|
ResnetBlock(s[0], s[0]),
|
||||||
if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if i != 0:
|
||||||
|
arr[-1]["upsample"] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
|
||||||
self.up = arr
|
self.up = arr
|
||||||
|
|
||||||
self.norm_out = GroupNorm(32, 128)
|
self.norm_out = GroupNorm(32, 128)
|
||||||
|
@ -79,16 +97,22 @@ class Decoder:
|
||||||
|
|
||||||
for l in self.up[::-1]:
|
for l in self.up[::-1]:
|
||||||
print("decode", x.shape)
|
print("decode", x.shape)
|
||||||
for b in l['block']: x = b(x)
|
for b in l["block"]:
|
||||||
if 'upsample' in l:
|
x = b(x)
|
||||||
|
if "upsample" in l:
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
|
||||||
bs, c, py, px = x.shape
|
bs, c, py, px = x.shape
|
||||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
x = (
|
||||||
x = l['upsample']['conv'](x)
|
x.reshape(bs, c, py, 1, px, 1)
|
||||||
|
.expand(bs, c, py, 2, px, 2)
|
||||||
|
.reshape(bs, c, py * 2, px * 2)
|
||||||
|
)
|
||||||
|
x = l["upsample"]["conv"](x)
|
||||||
x.realize()
|
x.realize()
|
||||||
|
|
||||||
return self.conv_out(self.norm_out(x).swish())
|
return self.conv_out(self.norm_out(x).swish())
|
||||||
|
|
||||||
|
|
||||||
class Encoder:
|
class Encoder:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
|
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
|
||||||
|
@ -96,10 +120,11 @@ class Encoder:
|
||||||
|
|
||||||
arr = []
|
arr = []
|
||||||
for i, s in enumerate(sz):
|
for i, s in enumerate(sz):
|
||||||
arr.append({"block":
|
arr.append({"block": [ResnetBlock(s[0], s[1]), ResnetBlock(s[1], s[1])]})
|
||||||
[ResnetBlock(s[0], s[1]),
|
if i != 3:
|
||||||
ResnetBlock(s[1], s[1])]})
|
arr[-1]["downsample"] = {
|
||||||
if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
|
"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0, 1, 0, 1))
|
||||||
|
}
|
||||||
self.down = arr
|
self.down = arr
|
||||||
|
|
||||||
self.mid = Mid(512)
|
self.mid = Mid(512)
|
||||||
|
@ -111,12 +136,15 @@ class Encoder:
|
||||||
|
|
||||||
for l in self.down:
|
for l in self.down:
|
||||||
print("encode", x.shape)
|
print("encode", x.shape)
|
||||||
for b in l['block']: x = b(x)
|
for b in l["block"]:
|
||||||
if 'downsample' in l: x = l['downsample']['conv'](x)
|
x = b(x)
|
||||||
|
if "downsample" in l:
|
||||||
|
x = l["downsample"]["conv"](x)
|
||||||
|
|
||||||
x = self.mid(x)
|
x = self.mid(x)
|
||||||
return self.conv_out(self.norm_out(x).swish())
|
return self.conv_out(self.norm_out(x).swish())
|
||||||
|
|
||||||
|
|
||||||
class AutoencoderKL:
|
class AutoencoderKL:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.encoder = Encoder()
|
self.encoder = Encoder()
|
||||||
|
@ -132,25 +160,27 @@ class AutoencoderKL:
|
||||||
latent = self.post_quant_conv(latent)
|
latent = self.post_quant_conv(latent)
|
||||||
return self.decoder(latent)
|
return self.decoder(latent)
|
||||||
|
|
||||||
|
|
||||||
# not to be confused with ResnetBlock
|
# not to be confused with ResnetBlock
|
||||||
class ResBlock:
|
class ResBlock:
|
||||||
def __init__(self, channels, emb_channels, out_channels):
|
def __init__(self, channels, emb_channels, out_channels):
|
||||||
self.in_layers = [
|
self.in_layers = [
|
||||||
GroupNorm(32, channels),
|
GroupNorm(32, channels),
|
||||||
Tensor.silu,
|
Tensor.silu,
|
||||||
Conv2d(channels, out_channels, 3, padding=1)
|
Conv2d(channels, out_channels, 3, padding=1),
|
||||||
]
|
|
||||||
self.emb_layers = [
|
|
||||||
Tensor.silu,
|
|
||||||
Linear(emb_channels, out_channels)
|
|
||||||
]
|
]
|
||||||
|
self.emb_layers = [Tensor.silu, Linear(emb_channels, out_channels)]
|
||||||
self.out_layers = [
|
self.out_layers = [
|
||||||
GroupNorm(32, out_channels),
|
GroupNorm(32, out_channels),
|
||||||
Tensor.silu,
|
Tensor.silu,
|
||||||
lambda x: x, # needed for weights loading code to work
|
lambda x: x, # needed for weights loading code to work
|
||||||
Conv2d(out_channels, out_channels, 3, padding=1)
|
Conv2d(out_channels, out_channels, 3, padding=1),
|
||||||
]
|
]
|
||||||
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
|
self.skip_connection = (
|
||||||
|
Conv2d(channels, out_channels, 1)
|
||||||
|
if channels != out_channels
|
||||||
|
else lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x, emb):
|
def __call__(self, x, emb):
|
||||||
h = x.sequential(self.in_layers)
|
h = x.sequential(self.in_layers)
|
||||||
|
@ -160,6 +190,7 @@ class ResBlock:
|
||||||
ret = self.skip_connection(x) + h
|
ret = self.skip_connection(x) + h
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention:
|
class CrossAttention:
|
||||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||||
self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
|
self.to_q = Linear(query_dim, n_heads * d_head, bias=False)
|
||||||
|
@ -172,11 +203,15 @@ class CrossAttention:
|
||||||
def __call__(self, x, context=None):
|
def __call__(self, x, context=None):
|
||||||
context = x if context is None else context
|
context = x if context is None else context
|
||||||
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
|
||||||
q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)]
|
q, k, v = [
|
||||||
|
y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1, 2)
|
||||||
|
for y in (q, k, v)
|
||||||
|
]
|
||||||
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
|
attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
|
||||||
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
|
h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size)
|
||||||
return h_.sequential(self.to_out)
|
return h_.sequential(self.to_out)
|
||||||
|
|
||||||
|
|
||||||
class GEGLU:
|
class GEGLU:
|
||||||
def __init__(self, dim_in, dim_out):
|
def __init__(self, dim_in, dim_out):
|
||||||
self.proj = Linear(dim_in, dim_out * 2)
|
self.proj = Linear(dim_in, dim_out * 2)
|
||||||
|
@ -186,17 +221,19 @@ class GEGLU:
|
||||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||||
return x * gate.gelu()
|
return x * gate.gelu()
|
||||||
|
|
||||||
|
|
||||||
class FeedForward:
|
class FeedForward:
|
||||||
def __init__(self, dim, mult=4):
|
def __init__(self, dim, mult=4):
|
||||||
self.net = [
|
self.net = [
|
||||||
GEGLU(dim, dim * mult),
|
GEGLU(dim, dim * mult),
|
||||||
lambda x: x, # needed for weights loading code to work
|
lambda x: x, # needed for weights loading code to work
|
||||||
Linear(dim*mult, dim)
|
Linear(dim * mult, dim),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x.sequential(self.net)
|
return x.sequential(self.net)
|
||||||
|
|
||||||
|
|
||||||
class BasicTransformerBlock:
|
class BasicTransformerBlock:
|
||||||
def __init__(self, dim, context_dim, n_heads, d_head):
|
def __init__(self, dim, context_dim, n_heads, d_head):
|
||||||
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
|
||||||
|
@ -212,12 +249,15 @@ class BasicTransformerBlock:
|
||||||
x = self.ff(self.norm3(x)) + x
|
x = self.ff(self.norm3(x)) + x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SpatialTransformer:
|
class SpatialTransformer:
|
||||||
def __init__(self, channels, context_dim, n_heads, d_head):
|
def __init__(self, channels, context_dim, n_heads, d_head):
|
||||||
self.norm = GroupNorm(32, channels)
|
self.norm = GroupNorm(32, channels)
|
||||||
assert channels == n_heads * d_head
|
assert channels == n_heads * d_head
|
||||||
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
|
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
|
||||||
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
|
self.transformer_blocks = [
|
||||||
|
BasicTransformerBlock(channels, context_dim, n_heads, d_head)
|
||||||
|
]
|
||||||
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
|
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
|
||||||
|
|
||||||
def __call__(self, x, context=None):
|
def __call__(self, x, context=None):
|
||||||
|
@ -232,6 +272,7 @@ class SpatialTransformer:
|
||||||
ret = self.proj_out(x) + x_in
|
ret = self.proj_out(x) + x_in
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class Downsample:
|
class Downsample:
|
||||||
def __init__(self, channels):
|
def __init__(self, channels):
|
||||||
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
|
||||||
|
@ -239,21 +280,28 @@ class Downsample:
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return self.op(x)
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
class Upsample:
|
class Upsample:
|
||||||
def __init__(self, channels):
|
def __init__(self, channels):
|
||||||
self.conv = Conv2d(channels, channels, 3, padding=1)
|
self.conv = Conv2d(channels, channels, 3, padding=1)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
bs, c, py, px = x.shape
|
bs, c, py, px = x.shape
|
||||||
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
|
x = (
|
||||||
|
x.reshape(bs, c, py, 1, px, 1)
|
||||||
|
.expand(bs, c, py, 2, px, 2)
|
||||||
|
.reshape(bs, c, py * 2, px * 2)
|
||||||
|
)
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||||
half = dim // 2
|
half = dim // 2
|
||||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||||
args = timesteps * freqs
|
args = timesteps * freqs
|
||||||
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
|
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)
|
||||||
|
|
||||||
|
|
||||||
class UNetModel:
|
class UNetModel:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.time_embed = [
|
self.time_embed = [
|
||||||
|
@ -273,12 +321,12 @@ class UNetModel:
|
||||||
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||||
[Downsample(1280)],
|
[Downsample(1280)],
|
||||||
[ResBlock(1280, 1280, 1280)],
|
[ResBlock(1280, 1280, 1280)],
|
||||||
[ResBlock(1280, 1280, 1280)]
|
[ResBlock(1280, 1280, 1280)],
|
||||||
]
|
]
|
||||||
self.middle_block = [
|
self.middle_block = [
|
||||||
ResBlock(1280, 1280, 1280),
|
ResBlock(1280, 1280, 1280),
|
||||||
SpatialTransformer(1280, 768, 8, 160),
|
SpatialTransformer(1280, 768, 8, 160),
|
||||||
ResBlock(1280, 1280, 1280)
|
ResBlock(1280, 1280, 1280),
|
||||||
]
|
]
|
||||||
self.output_blocks = [
|
self.output_blocks = [
|
||||||
[ResBlock(2560, 1280, 1280)],
|
[ResBlock(2560, 1280, 1280)],
|
||||||
|
@ -286,10 +334,18 @@ class UNetModel:
|
||||||
[ResBlock(2560, 1280, 1280), Upsample(1280)],
|
[ResBlock(2560, 1280, 1280), Upsample(1280)],
|
||||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||||
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
|
||||||
[ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
|
[
|
||||||
|
ResBlock(1920, 1280, 1280),
|
||||||
|
SpatialTransformer(1280, 768, 8, 160),
|
||||||
|
Upsample(1280),
|
||||||
|
],
|
||||||
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
|
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
|
||||||
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
|
||||||
[ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
|
[
|
||||||
|
ResBlock(960, 1280, 640),
|
||||||
|
SpatialTransformer(640, 768, 8, 80),
|
||||||
|
Upsample(640),
|
||||||
|
],
|
||||||
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||||
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
|
||||||
|
@ -297,7 +353,7 @@ class UNetModel:
|
||||||
self.out = [
|
self.out = [
|
||||||
GroupNorm(32, 320),
|
GroupNorm(32, 320),
|
||||||
Tensor.silu,
|
Tensor.silu,
|
||||||
Conv2d(320, 4, kernel_size=3, padding=1)
|
Conv2d(320, 4, kernel_size=3, padding=1),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x, timesteps=None, context=None):
|
def __call__(self, x, timesteps=None, context=None):
|
||||||
|
@ -306,9 +362,12 @@ class UNetModel:
|
||||||
emb = t_emb.sequential(self.time_embed)
|
emb = t_emb.sequential(self.time_embed)
|
||||||
|
|
||||||
def run(x, bb):
|
def run(x, bb):
|
||||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
if isinstance(bb, ResBlock):
|
||||||
elif isinstance(bb, SpatialTransformer): x = bb(x, context)
|
x = bb(x, emb)
|
||||||
else: x = bb(x)
|
elif isinstance(bb, SpatialTransformer):
|
||||||
|
x = bb(x, context)
|
||||||
|
else:
|
||||||
|
x = bb(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
saved_inputs = []
|
saved_inputs = []
|
||||||
|
@ -326,6 +385,7 @@ class UNetModel:
|
||||||
x = run(x, bb)
|
x = run(x, bb)
|
||||||
return x.sequential(self.out)
|
return x.sequential(self.out)
|
||||||
|
|
||||||
|
|
||||||
class CLIPMLP:
|
class CLIPMLP:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.fc1 = Linear(768, 3072)
|
self.fc1 = Linear(768, 3072)
|
||||||
|
@ -337,6 +397,7 @@ class CLIPMLP:
|
||||||
hidden_states = self.fc2(hidden_states)
|
hidden_states = self.fc2(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class CLIPAttention:
|
class CLIPAttention:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embed_dim = 768
|
self.embed_dim = 768
|
||||||
|
@ -349,10 +410,22 @@ class CLIPAttention:
|
||||||
|
|
||||||
def __call__(self, hidden_states, causal_attention_mask):
|
def __call__(self, hidden_states, causal_attention_mask):
|
||||||
bsz, tgt_len, embed_dim = hidden_states.shape
|
bsz, tgt_len, embed_dim = hidden_states.shape
|
||||||
q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
|
q, k, v = (
|
||||||
q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
|
self.q_proj(hidden_states),
|
||||||
attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
|
self.k_proj(hidden_states),
|
||||||
return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))
|
self.v_proj(hidden_states),
|
||||||
|
)
|
||||||
|
q, k, v = [
|
||||||
|
x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
for x in (q, k, v)
|
||||||
|
]
|
||||||
|
attn_output = Tensor.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask=causal_attention_mask
|
||||||
|
)
|
||||||
|
return self.out_proj(
|
||||||
|
attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoderLayer:
|
class CLIPEncoderLayer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -374,6 +447,7 @@ class CLIPEncoderLayer:
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class CLIPEncoder:
|
class CLIPEncoder:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.layers = [CLIPEncoderLayer() for i in range(12)]
|
self.layers = [CLIPEncoderLayer() for i in range(12)]
|
||||||
|
@ -383,6 +457,7 @@ class CLIPEncoder:
|
||||||
hidden_states = l(hidden_states, causal_attention_mask)
|
hidden_states = l(hidden_states, causal_attention_mask)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEmbeddings:
|
class CLIPTextEmbeddings:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.token_embedding = Embedding(49408, 768)
|
self.token_embedding = Embedding(49408, 768)
|
||||||
|
@ -391,6 +466,7 @@ class CLIPTextEmbeddings:
|
||||||
def __call__(self, input_ids, position_ids):
|
def __call__(self, input_ids, position_ids):
|
||||||
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
return self.token_embedding(input_ids) + self.position_embedding(position_ids)
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextTransformer:
|
class CLIPTextTransformer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embeddings = CLIPTextEmbeddings()
|
self.embeddings = CLIPTextEmbeddings()
|
||||||
|
@ -402,9 +478,15 @@ class CLIPTextTransformer:
|
||||||
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1))
|
x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1))
|
||||||
return self.final_layer_norm(x)
|
return self.final_layer_norm(x)
|
||||||
|
|
||||||
|
|
||||||
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
|
def default_bpe():
|
||||||
|
return fetch(
|
||||||
|
"https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
"bpe_simple_vocab_16e6.txt.gz",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""Return set of symbol pairs in a word.
|
"""Return set of symbol pairs in a word.
|
||||||
|
@ -417,11 +499,13 @@ def get_pairs(word):
|
||||||
prev_char = char
|
prev_char = char
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
def whitespace_clean(text):
|
def whitespace_clean(text):
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r"\s+", " ", text)
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_unicode():
|
def bytes_to_unicode():
|
||||||
"""
|
"""
|
||||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||||
|
@ -432,7 +516,11 @@ def bytes_to_unicode():
|
||||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||||
"""
|
"""
|
||||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
bs = (
|
||||||
|
list(range(ord("!"), ord("~") + 1))
|
||||||
|
+ list(range(ord("¡"), ord("¬") + 1))
|
||||||
|
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||||
|
)
|
||||||
cs = bs[:]
|
cs = bs[:]
|
||||||
n = 0
|
n = 0
|
||||||
for b in range(2**8):
|
for b in range(2**8):
|
||||||
|
@ -443,33 +531,40 @@ def bytes_to_unicode():
|
||||||
cs = [chr(n) for n in cs]
|
cs = [chr(n) for n in cs]
|
||||||
return dict(zip(bs, cs))
|
return dict(zip(bs, cs))
|
||||||
|
|
||||||
|
|
||||||
class ClipTokenizer:
|
class ClipTokenizer:
|
||||||
def __init__(self, bpe_path: str = default_bpe()):
|
def __init__(self, bpe_path: str = default_bpe()):
|
||||||
self.byte_encoder = bytes_to_unicode()
|
self.byte_encoder = bytes_to_unicode()
|
||||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
|
||||||
merges = merges[1 : 49152 - 256 - 2 + 1]
|
merges = merges[1 : 49152 - 256 - 2 + 1]
|
||||||
merges = [tuple(merge.split()) for merge in merges]
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
vocab = list(bytes_to_unicode().values())
|
vocab = list(bytes_to_unicode().values())
|
||||||
vocab = vocab + [v+'</w>' for v in vocab]
|
vocab = vocab + [v + "</w>" for v in vocab]
|
||||||
for merge in merges:
|
for merge in merges:
|
||||||
vocab.append(''.join(merge))
|
vocab.append("".join(merge))
|
||||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
|
||||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
self.cache = {
|
||||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)
|
"<|startoftext|>": "<|startoftext|>",
|
||||||
|
"<|endoftext|>": "<|endoftext|>",
|
||||||
|
}
|
||||||
|
self.pat = re.compile(
|
||||||
|
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
if token in self.cache:
|
if token in self.cache:
|
||||||
return self.cache[token]
|
return self.cache[token]
|
||||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
||||||
pairs = get_pairs(word)
|
pairs = get_pairs(word)
|
||||||
|
|
||||||
if not pairs:
|
if not pairs:
|
||||||
return token+'</w>'
|
return token + "</w>"
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||||
if bigram not in self.bpe_ranks:
|
if bigram not in self.bpe_ranks:
|
||||||
break
|
break
|
||||||
first, second = bigram
|
first, second = bigram
|
||||||
|
@ -495,7 +590,7 @@ class ClipTokenizer:
|
||||||
if len(word) == 1:
|
if len(word) == 1:
|
||||||
break
|
break
|
||||||
pairs = get_pairs(word)
|
pairs = get_pairs(word)
|
||||||
word = ' '.join(word)
|
word = " ".join(word)
|
||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
|
@ -503,19 +598,28 @@ class ClipTokenizer:
|
||||||
bpe_tokens = []
|
bpe_tokens = []
|
||||||
text = whitespace_clean(text.strip()).lower()
|
text = whitespace_clean(text.strip()).lower()
|
||||||
for token in re.findall(self.pat, text):
|
for token in re.findall(self.pat, text):
|
||||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
bpe_tokens.extend(
|
||||||
|
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
||||||
|
)
|
||||||
# Truncation, keeping two slots for start and end tokens.
|
# Truncation, keeping two slots for start and end tokens.
|
||||||
if len(bpe_tokens) > 75:
|
if len(bpe_tokens) > 75:
|
||||||
bpe_tokens = bpe_tokens[:75]
|
bpe_tokens = bpe_tokens[:75]
|
||||||
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
|
return [49406] + bpe_tokens + [49407] * (77 - len(bpe_tokens) - 1)
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
class StableDiffusion:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.alphas_cumprod = Tensor.empty(1000)
|
self.alphas_cumprod = Tensor.empty(1000)
|
||||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
|
||||||
|
diffusion_model=UNetModel()
|
||||||
|
)
|
||||||
self.first_stage_model = AutoencoderKL()
|
self.first_stage_model = AutoencoderKL()
|
||||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
|
||||||
|
transformer=namedtuple("Transformer", ["text_model"])(
|
||||||
|
text_model=CLIPTextTransformer()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
|
def get_x_prev_and_pred_x0(self, x, e_t, a_t, a_prev):
|
||||||
temperature = 1
|
temperature = 1
|
||||||
|
@ -526,17 +630,30 @@ class StableDiffusion:
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||||
return x_prev, pred_x0
|
return x_prev, pred_x0
|
||||||
|
|
||||||
def get_model_output(self, unconditional_context, context, latent, timestep, unconditional_guidance_scale):
|
def get_model_output(
|
||||||
|
self,
|
||||||
|
unconditional_context,
|
||||||
|
context,
|
||||||
|
latent,
|
||||||
|
timestep,
|
||||||
|
unconditional_guidance_scale,
|
||||||
|
):
|
||||||
# put into diffuser
|
# put into diffuser
|
||||||
latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, unconditional_context.cat(context, dim=0))
|
latents = self.model.diffusion_model(
|
||||||
|
latent.expand(2, *latent.shape[1:]),
|
||||||
|
timestep,
|
||||||
|
unconditional_context.cat(context, dim=0),
|
||||||
|
)
|
||||||
unconditional_latent, latent = latents[0:1], latents[1:2]
|
unconditional_latent, latent = latents[0:1], latents[1:2]
|
||||||
|
|
||||||
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
e_t = unconditional_latent + unconditional_guidance_scale * (
|
||||||
|
latent - unconditional_latent
|
||||||
|
)
|
||||||
return e_t
|
return e_t
|
||||||
|
|
||||||
def decode(self, x):
|
def decode(self, x):
|
||||||
|
@ -548,14 +665,26 @@ class StableDiffusion:
|
||||||
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
|
x = x.reshape(3, 512, 512).permute(1, 2, 0).clip(0, 1) * 255
|
||||||
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
|
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
|
||||||
|
|
||||||
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
|
def __call__(
|
||||||
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
|
self,
|
||||||
|
unconditional_context,
|
||||||
|
context,
|
||||||
|
latent,
|
||||||
|
timestep,
|
||||||
|
alphas,
|
||||||
|
alphas_prev,
|
||||||
|
guidance,
|
||||||
|
):
|
||||||
|
e_t = self.get_model_output(
|
||||||
|
unconditional_context, context, latent, timestep, guidance
|
||||||
|
)
|
||||||
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
|
x_prev, _ = self.get_x_prev_and_pred_x0(latent, e_t, alphas, alphas_prev)
|
||||||
# e_t_next = get_model_output(x_prev)
|
# e_t_next = get_model_output(x_prev)
|
||||||
# e_t_prime = (e_t + e_t_next) / 2
|
# e_t_prime = (e_t + e_t_next) / 2
|
||||||
# x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
|
# x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
|
||||||
return x_prev.realize()
|
return x_prev.realize()
|
||||||
|
|
||||||
|
|
||||||
# ** ldm.models.autoencoder.AutoencoderKL (done!)
|
# ** ldm.models.autoencoder.AutoencoderKL (done!)
|
||||||
# 3x512x512 <--> 4x64x64 (16384)
|
# 3x512x512 <--> 4x64x64 (16384)
|
||||||
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
|
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
|
||||||
|
@ -573,22 +702,48 @@ class StableDiffusion:
|
||||||
# cond_stage_model.transformer.text_model
|
# cond_stage_model.transformer.text_model
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
|
description="Run Stable Diffusion",
|
||||||
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
|
)
|
||||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
parser.add_argument(
|
||||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
"--steps", type=int, default=5, help="Number of steps in diffusion"
|
||||||
parser.add_argument('--timing', action='store_true', help="Print timing per step")
|
)
|
||||||
parser.add_argument('--seed', type=int, help="Set the random latent seed")
|
parser.add_argument(
|
||||||
parser.add_argument('--guidance', type=float, default=7.5, help="Prompt strength")
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
default="a horse sized cat eating a bagel",
|
||||||
|
help="Phrase to render",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out",
|
||||||
|
type=str,
|
||||||
|
default=Path(tempfile.gettempdir()) / "rendered.png",
|
||||||
|
help="Output filename",
|
||||||
|
)
|
||||||
|
parser.add_argument("--noshow", action="store_true", help="Don't show the image")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16", action="store_true", help="Cast the weights to float16"
|
||||||
|
)
|
||||||
|
parser.add_argument("--timing", action="store_true", help="Print timing per step")
|
||||||
|
parser.add_argument("--seed", type=int, help="Set the random latent seed")
|
||||||
|
parser.add_argument("--guidance", type=float, default=7.5, help="Prompt strength")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
Tensor.no_grad = True
|
Tensor.no_grad = True
|
||||||
model = StableDiffusion()
|
model = StableDiffusion()
|
||||||
|
|
||||||
# load in weights
|
# load in weights
|
||||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
load_state_dict(
|
||||||
|
model,
|
||||||
|
torch_load(
|
||||||
|
fetch(
|
||||||
|
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
|
||||||
|
"sd-v1-4.ckpt",
|
||||||
|
)
|
||||||
|
)["state_dict"],
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
for l in get_state_dict(model).values():
|
for l in get_state_dict(model).values():
|
||||||
|
@ -601,7 +756,9 @@ if __name__ == "__main__":
|
||||||
print("got CLIP context", context.shape)
|
print("got CLIP context", context.shape)
|
||||||
|
|
||||||
prompt = Tensor([tokenizer.encode("")])
|
prompt = Tensor([tokenizer.encode("")])
|
||||||
unconditional_context = model.cond_stage_model.transformer.text_model(prompt).realize()
|
unconditional_context = model.cond_stage_model.transformer.text_model(
|
||||||
|
prompt
|
||||||
|
).realize()
|
||||||
print("got unconditional CLIP context", unconditional_context.shape)
|
print("got unconditional CLIP context", unconditional_context.shape)
|
||||||
|
|
||||||
# done with clip model
|
# done with clip model
|
||||||
|
@ -613,21 +770,37 @@ if __name__ == "__main__":
|
||||||
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
|
alphas_prev = Tensor([1.0]).cat(alphas[:-1])
|
||||||
|
|
||||||
# start with random noise
|
# start with random noise
|
||||||
if args.seed is not None: Tensor._seed = args.seed
|
if args.seed is not None:
|
||||||
|
Tensor._seed = args.seed
|
||||||
latent = Tensor.randn(1, 4, 64, 64)
|
latent = Tensor.randn(1, 4, 64, 64)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def run(model, *x): return model(*x).realize()
|
def run(model, *x):
|
||||||
|
return model(*x).realize()
|
||||||
|
|
||||||
# this is diffusion
|
# this is diffusion
|
||||||
with Context(BEAM=getenv("LATEBEAM")):
|
with Context(BEAM=getenv("LATEBEAM")):
|
||||||
for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
|
for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
|
||||||
GlobalCounters.reset()
|
GlobalCounters.reset()
|
||||||
t.set_description("%3d %3d" % (index, timestep))
|
t.set_description("%3d %3d" % (index, timestep))
|
||||||
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
|
with Timing(
|
||||||
|
"step in ",
|
||||||
|
enabled=args.timing,
|
||||||
|
on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB",
|
||||||
|
):
|
||||||
tid = Tensor([index])
|
tid = Tensor([index])
|
||||||
latent = run(model, unconditional_context, context, latent, Tensor([timestep]), alphas[tid], alphas_prev[tid], Tensor([args.guidance]))
|
latent = run(
|
||||||
if args.timing: Device[Device.DEFAULT].synchronize()
|
model,
|
||||||
|
unconditional_context,
|
||||||
|
context,
|
||||||
|
latent,
|
||||||
|
Tensor([timestep]),
|
||||||
|
alphas[tid],
|
||||||
|
alphas_prev[tid],
|
||||||
|
Tensor([args.guidance]),
|
||||||
|
)
|
||||||
|
if args.timing:
|
||||||
|
Device[Device.DEFAULT].synchronize()
|
||||||
del run
|
del run
|
||||||
|
|
||||||
# upsample latent space to image with autoencoder
|
# upsample latent space to image with autoencoder
|
||||||
|
@ -637,8 +810,10 @@ if __name__ == "__main__":
|
||||||
# save image
|
# save image
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
|
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
|
||||||
print(f"saving {args.out}")
|
print(f"saving {args.out}")
|
||||||
im.save(args.out)
|
im.save(args.out)
|
||||||
# Open image.
|
# Open image.
|
||||||
if not args.noshow: im.show()
|
if not args.noshow:
|
||||||
|
im.show()
|
||||||
|
|
|
@ -10,6 +10,7 @@ from tinygrad.tensor import Tensor
|
||||||
from extra.datasets import fetch_cifar
|
from extra.datasets import fetch_cifar
|
||||||
from extra.models.efficientnet import EfficientNet
|
from extra.models.efficientnet import EfficientNet
|
||||||
|
|
||||||
|
|
||||||
class TinyConvNet:
|
class TinyConvNet:
|
||||||
def __init__(self, classes=10):
|
def __init__(self, classes=10):
|
||||||
conv = 3
|
conv = 3
|
||||||
|
@ -24,6 +25,7 @@ class TinyConvNet:
|
||||||
x = x.reshape(shape=[x.shape[0], -1])
|
x = x.reshape(shape=[x.shape[0], -1])
|
||||||
return x.dot(self.l1)
|
return x.dot(self.l1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
IMAGENET = getenv("IMAGENET")
|
IMAGENET = getenv("IMAGENET")
|
||||||
classes = 1000 if IMAGENET else 10
|
classes = 1000 if IMAGENET else 10
|
||||||
|
@ -47,12 +49,14 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if IMAGENET:
|
if IMAGENET:
|
||||||
from extra.datasets.imagenet import fetch_batch
|
from extra.datasets.imagenet import fetch_batch
|
||||||
|
|
||||||
def loader(q):
|
def loader(q):
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
try:
|
||||||
q.put(fetch_batch(BS))
|
q.put(fetch_batch(BS))
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
q = Queue(16)
|
q = Queue(16)
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
p = Process(target=loader, args=(q,))
|
p = Process(target=loader, args=(q,))
|
||||||
|
@ -97,9 +101,17 @@ if __name__ == "__main__":
|
||||||
finish_time = (time.time() - st) * 1000.0
|
finish_time = (time.time() - st) * 1000.0
|
||||||
|
|
||||||
# printing
|
# printing
|
||||||
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
|
t.set_description(
|
||||||
(loss, accuracy,
|
"loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f"
|
||||||
fp_time, bp_time, opt_time, finish_time,
|
% (
|
||||||
fp_time + bp_time + opt_time + finish_time))
|
loss,
|
||||||
|
accuracy,
|
||||||
|
fp_time,
|
||||||
|
bp_time,
|
||||||
|
opt_time,
|
||||||
|
finish_time,
|
||||||
|
fp_time + bp_time + opt_time + finish_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
del out, y, loss
|
del out, y, loss
|
||||||
|
|
|
@ -19,27 +19,30 @@ class ComposeTransforms:
|
||||||
x = t(x)
|
x = t(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
classes = 10
|
classes = 10
|
||||||
|
|
||||||
TRANSFER = getenv('TRANSFER')
|
TRANSFER = getenv("TRANSFER")
|
||||||
model = ResNet(getenv('NUM', 18), num_classes=classes)
|
model = ResNet(getenv("NUM", 18), num_classes=classes)
|
||||||
if TRANSFER:
|
if TRANSFER:
|
||||||
model.load_from_pretrained()
|
model.load_from_pretrained()
|
||||||
|
|
||||||
lr = 5e-3
|
lr = 5e-3
|
||||||
transform = ComposeTransforms([
|
transform = ComposeTransforms(
|
||||||
lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x],
|
[
|
||||||
|
lambda x: [Image.fromarray(xx, mode="L").resize((64, 64)) for xx in x],
|
||||||
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
|
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
|
||||||
lambda x: x / 255.0,
|
lambda x: x / 255.0,
|
||||||
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
||||||
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
||||||
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
||||||
lr /= 1.2
|
lr /= 1.2
|
||||||
print(f'reducing lr to {lr:.7f}')
|
print(f"reducing lr to {lr:.7f}")
|
||||||
|
|
|
@ -7,13 +7,16 @@ from tinygrad.nn.optim import Adam
|
||||||
from extra.training import train, evaluate
|
from extra.training import train, evaluate
|
||||||
from extra.models.transformer import Transformer
|
from extra.models.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
|
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
|
||||||
def make_dataset():
|
def make_dataset():
|
||||||
ds = []
|
ds = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
for j in range(100):
|
for j in range(100):
|
||||||
s = i + j
|
s = i + j
|
||||||
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
|
ds.append(
|
||||||
|
[i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10]
|
||||||
|
)
|
||||||
random.shuffle(ds)
|
random.shuffle(ds)
|
||||||
ds = np.array(ds).astype(np.float32)
|
ds = np.array(ds).astype(np.float32)
|
||||||
ds_X = ds[:, 0:6]
|
ds_X = ds[:, 0:6]
|
||||||
|
@ -22,6 +25,7 @@ def make_dataset():
|
||||||
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
||||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = Transformer(10, 6, 2, 128, 4, 32)
|
model = Transformer(10, 6, 2, 128, 4, 32)
|
||||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||||
|
@ -29,14 +33,23 @@ if __name__ == "__main__":
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
optim = Adam(get_parameters(model), lr=lr)
|
optim = Adam(get_parameters(model), lr=lr)
|
||||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
acc, Y_test_preds = evaluate(
|
||||||
|
model, X_test, Y_test, num_classes=10, return_predict=True
|
||||||
|
)
|
||||||
lr /= 1.2
|
lr /= 1.2
|
||||||
print(f'reducing lr to {lr:.4f}')
|
print(f"reducing lr to {lr:.4f}")
|
||||||
if acc > 0.998:
|
if acc > 0.998:
|
||||||
wrong = 0
|
wrong = 0
|
||||||
for k in range(len(Y_test_preds)):
|
for k in range(len(Y_test_preds)):
|
||||||
if (Y_test_preds[k] != Y_test[k]).any():
|
if (Y_test_preds[k] != Y_test[k]).any():
|
||||||
wrong += 1
|
wrong += 1
|
||||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
a, b, c, x = (
|
||||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
X_test[k, :2],
|
||||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
X_test[k, 2:4],
|
||||||
|
Y_test[k, -3:],
|
||||||
|
Y_test_preds[k, -3:],
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})"
|
||||||
|
)
|
||||||
|
print(f"Wrong predictions: {wrong}, acc = {acc:.4f}")
|
||||||
|
|
|
@ -12,6 +12,7 @@ from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
|
||||||
# amount of context erased by model
|
# amount of context erased by model
|
||||||
CONTEXT = 7
|
CONTEXT = 7
|
||||||
|
|
||||||
|
|
||||||
def get_sample_count(samples_dir):
|
def get_sample_count(samples_dir):
|
||||||
try:
|
try:
|
||||||
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
|
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
|
||||||
|
@ -21,18 +22,24 @@ def get_sample_count(samples_dir):
|
||||||
except:
|
except:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def set_sample_count(samples_dir, sc):
|
def set_sample_count(samples_dir, sc):
|
||||||
with open(samples_dir + "/sample_count.txt", "w") as file:
|
with open(samples_dir + "/sample_count.txt", "w") as file:
|
||||||
file.write(str(sc) + "\n")
|
file.write(str(sc) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if len(sys.argv) < 2:
|
if len(sys.argv) < 2:
|
||||||
print("python3 -m examples.vgg7 import MODELJSON MODEL")
|
print("python3 -m examples.vgg7 import MODELJSON MODEL")
|
||||||
print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
|
print(
|
||||||
|
" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json"
|
||||||
|
)
|
||||||
print(" into a safetensors file")
|
print(" into a safetensors file")
|
||||||
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
|
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
|
||||||
print(" *this format is used by most other commands in this program*")
|
print(" *this format is used by most other commands in this program*")
|
||||||
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
|
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
|
||||||
print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
|
print(
|
||||||
|
" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors"
|
||||||
|
)
|
||||||
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
|
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
|
||||||
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
|
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
|
||||||
print(" output image has 7 pixels removed on all edges")
|
print(" output image has 7 pixels removed on all edges")
|
||||||
|
@ -53,7 +60,9 @@ if len(sys.argv) < 2:
|
||||||
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
|
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
|
||||||
print(" won't pad or tile, so keep image sizes sane")
|
print(" won't pad or tile, so keep image sizes sane")
|
||||||
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
|
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
|
||||||
print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
|
print(
|
||||||
|
" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training"
|
||||||
|
)
|
||||||
print(" maintains/creates samples_count.txt automatically")
|
print(" maintains/creates samples_count.txt automatically")
|
||||||
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
|
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
@ -61,9 +70,13 @@ if len(sys.argv) < 2:
|
||||||
cmd = sys.argv[1]
|
cmd = sys.argv[1]
|
||||||
vgg7 = Vgg7()
|
vgg7 = Vgg7()
|
||||||
|
|
||||||
|
|
||||||
def nansbane(p):
|
def nansbane(p):
|
||||||
if numpy.isnan(numpy.min(p.numpy())):
|
if numpy.isnan(numpy.min(p.numpy())):
|
||||||
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
|
raise Exception(
|
||||||
|
"A NaN in the model has been detected. This model will not be interacted with to prevent further damage."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_and_save(path, save):
|
def load_and_save(path, save):
|
||||||
if save:
|
if save:
|
||||||
|
@ -77,6 +90,7 @@ def load_and_save(path, save):
|
||||||
for v in vgg7.get_parameters():
|
for v in vgg7.get_parameters():
|
||||||
nansbane(v)
|
nansbane(v)
|
||||||
|
|
||||||
|
|
||||||
if cmd == "import":
|
if cmd == "import":
|
||||||
src = sys.argv[2]
|
src = sys.argv[2]
|
||||||
model = sys.argv[3]
|
model = sys.argv[3]
|
||||||
|
@ -158,7 +172,9 @@ elif cmd == "train":
|
||||||
|
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
try:
|
try:
|
||||||
sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
|
sample_idx = numpy.random.choice(
|
||||||
|
samples_count, p=sample_probs / sample_probs.sum()
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
||||||
sample_idx = random.randint(0, samples_count - 1)
|
sample_idx = random.randint(0, samples_count - 1)
|
||||||
|
@ -204,7 +220,7 @@ elif cmd == "train":
|
||||||
rnum = rnum + 1
|
rnum = rnum + 1
|
||||||
# Probability management
|
# Probability management
|
||||||
# there must always be a probability, no matter how slim, even if loss goes to 0
|
# there must always be a probability, no matter how slim, even if loss goes to 0
|
||||||
sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
|
sample_probs[sample_idx] = max(loss_indicator, 1.0e-10)
|
||||||
|
|
||||||
# if we were told to save every round, we already saved
|
# if we were told to save every round, we already saved
|
||||||
if rounds_per_save != 1:
|
if rounds_per_save != 1:
|
||||||
|
@ -237,8 +253,12 @@ elif cmd == "samplify":
|
||||||
samples_added = 0
|
samples_added = 0
|
||||||
|
|
||||||
# actual patch extraction
|
# actual patch extraction
|
||||||
for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
|
for posy in range(
|
||||||
for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
|
CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size
|
||||||
|
):
|
||||||
|
for posx in range(
|
||||||
|
CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size
|
||||||
|
):
|
||||||
# this is a viable patch location, add it
|
# this is a viable patch location, add it
|
||||||
# note the ranges here:
|
# note the ranges here:
|
||||||
# + there are always CONTEXT pixels *before* the point
|
# + there are always CONTEXT pixels *before* the point
|
||||||
|
@ -247,7 +267,12 @@ elif cmd == "samplify":
|
||||||
# + additionally, there are sample_size - 1 additional sample pixels
|
# + additionally, there are sample_size - 1 additional sample pixels
|
||||||
# + additionally, there are CONTEXT additional pixels
|
# + additionally, there are CONTEXT additional pixels
|
||||||
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
|
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
|
||||||
patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
|
patch_x = a_img[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
posy - CONTEXT : posy + CONTEXT + sample_size,
|
||||||
|
posx - CONTEXT : posx + CONTEXT + sample_size,
|
||||||
|
]
|
||||||
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
|
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
|
||||||
|
|
||||||
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
|
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
|
||||||
|
|
|
@ -11,6 +11,7 @@ from tinygrad.helpers import fetch
|
||||||
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
|
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
|
||||||
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
|
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
|
||||||
|
|
||||||
|
|
||||||
def image_load(path) -> numpy.ndarray:
|
def image_load(path) -> numpy.ndarray:
|
||||||
"""
|
"""
|
||||||
Loads an image in the shape expected by other functions in this module.
|
Loads an image in the shape expected by other functions in this module.
|
||||||
|
@ -29,6 +30,7 @@ def image_load(path) -> numpy.ndarray:
|
||||||
na = na.astype("float32") / 255.0
|
na = na.astype("float32") / 255.0
|
||||||
return na
|
return na
|
||||||
|
|
||||||
|
|
||||||
def image_save(path, na: numpy.ndarray):
|
def image_save(path, na: numpy.ndarray):
|
||||||
"""
|
"""
|
||||||
Saves an image of the shape expected by other functions in this module.
|
Saves an image of the shape expected by other functions in this module.
|
||||||
|
@ -44,12 +46,15 @@ def image_save(path, na: numpy.ndarray):
|
||||||
# file
|
# file
|
||||||
Image.fromarray(na).save(path)
|
Image.fromarray(na).save(path)
|
||||||
|
|
||||||
|
|
||||||
# The Model
|
# The Model
|
||||||
|
|
||||||
|
|
||||||
class Conv3x3Biased:
|
class Conv3x3Biased:
|
||||||
"""
|
"""
|
||||||
A 3x3 convolution layer with some utility functions.
|
A 3x3 convolution layer with some utility functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, inC, outC, last=False):
|
def __init__(self, inC, outC, last=False):
|
||||||
# The properties must be named as "W" and "b".
|
# The properties must be named as "W" and "b".
|
||||||
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
|
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
|
||||||
|
@ -80,9 +85,12 @@ class Conv3x3Biased:
|
||||||
# Not outChannel,inChannel,Y,X.
|
# Not outChannel,inChannel,Y,X.
|
||||||
# Therefore, transpose it before assignment.
|
# Therefore, transpose it before assignment.
|
||||||
# I have long since forgotten how I worked this out.
|
# I have long since forgotten how I worked this out.
|
||||||
self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3))
|
self.W.assign(
|
||||||
|
Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3)
|
||||||
|
)
|
||||||
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
|
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
|
||||||
|
|
||||||
|
|
||||||
class Vgg7:
|
class Vgg7:
|
||||||
"""
|
"""
|
||||||
The 'vgg7' waifu2x network.
|
The 'vgg7' waifu2x network.
|
||||||
|
@ -115,14 +123,31 @@ class Vgg7:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_parameters(self) -> list:
|
def get_parameters(self) -> list:
|
||||||
return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
|
return (
|
||||||
|
self.conv1.get_parameters()
|
||||||
|
+ self.conv2.get_parameters()
|
||||||
|
+ self.conv3.get_parameters()
|
||||||
|
+ self.conv4.get_parameters()
|
||||||
|
+ self.conv5.get_parameters()
|
||||||
|
+ self.conv6.get_parameters()
|
||||||
|
+ self.conv7.get_parameters()
|
||||||
|
)
|
||||||
|
|
||||||
def load_from_pretrained(self, intent="art", subtype="scale2.0x"):
|
def load_from_pretrained(self, intent="art", subtype="scale2.0x"):
|
||||||
"""
|
"""
|
||||||
Downloads a nagadomi/waifu2x JSON weight file and loads it.
|
Downloads a nagadomi/waifu2x JSON weight file and loads it.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes())
|
|
||||||
|
data = json.loads(
|
||||||
|
fetch(
|
||||||
|
"https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/"
|
||||||
|
+ intent
|
||||||
|
+ "/"
|
||||||
|
+ subtype
|
||||||
|
+ "_model.json"
|
||||||
|
).read_bytes()
|
||||||
|
)
|
||||||
self.load_waifu2x_json(data)
|
self.load_waifu2x_json(data)
|
||||||
|
|
||||||
def load_waifu2x_json(self, data: list):
|
def load_waifu2x_json(self, data: list):
|
||||||
|
@ -157,7 +182,9 @@ class Vgg7:
|
||||||
|
|
||||||
# Padding next. Note that this padding is done on the whole image.
|
# Padding next. Note that this padding is done on the whole image.
|
||||||
# Padding the tiles would lose critical context, cause seams, etc.
|
# Padding the tiles would lose critical context, cause seams, etc.
|
||||||
image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
|
image = numpy.pad(
|
||||||
|
image, [[0, 0], [0, 0], [context, context], [context, context]], mode="edge"
|
||||||
|
)
|
||||||
|
|
||||||
# Now for tiling.
|
# Now for tiling.
|
||||||
# The output tile size is the usable output from an input tile (tile_size).
|
# The output tile size is the usable output from an input tile (tile_size).
|
||||||
|
@ -187,7 +214,8 @@ class Vgg7:
|
||||||
tile_t = Tensor(tile)
|
tile_t = Tensor(tile)
|
||||||
tile_fwd_t = self.forward(tile_t)
|
tile_fwd_t = self.forward(tile_t)
|
||||||
# Replace tile.
|
# Replace tile.
|
||||||
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
|
image_out[
|
||||||
|
:, :, out_y : out_y + out_h, out_x : out_x + out_w
|
||||||
|
] = tile_fwd_t.numpy()
|
||||||
|
|
||||||
return image_out
|
return image_out
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from PIL import Image
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import getenv, fetch
|
from tinygrad.helpers import getenv, fetch
|
||||||
from extra.models.vit import ViT
|
from extra.models.vit import ViT
|
||||||
|
|
||||||
"""
|
"""
|
||||||
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
|
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -22,7 +23,11 @@ else:
|
||||||
m.load_from_pretrained()
|
m.load_from_pretrained()
|
||||||
|
|
||||||
# category labels
|
# category labels
|
||||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
lbls = ast.literal_eval(
|
||||||
|
fetch(
|
||||||
|
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||||
|
).read_text()
|
||||||
|
)
|
||||||
|
|
||||||
# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
|
# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
|
||||||
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
|
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
|
||||||
|
@ -30,7 +35,9 @@ url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-1
|
||||||
# junk
|
# junk
|
||||||
img = Image.open(fetch(url))
|
img = Image.open(fetch(url))
|
||||||
aspect_ratio = img.size[0] / img.size[1]
|
aspect_ratio = img.size[0] / img.size[1]
|
||||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
img = img.resize(
|
||||||
|
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||||
|
)
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||||
img = img[y0 : y0 + 224, x0 : x0 + 224]
|
img = img[y0 : y0 + 224, x0 : x0 + 224]
|
||||||
|
|
1857
examples/vits.py
1857
examples/vits.py
File diff suppressed because it is too large
Load Diff
|
@ -1,7 +1,13 @@
|
||||||
import os
|
import os
|
||||||
from extra.export_model import compile_net, jit_model
|
from extra.export_model import compile_net, jit_model
|
||||||
from examples.stable_diffusion import StableDiffusion
|
from examples.stable_diffusion import StableDiffusion
|
||||||
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
|
from tinygrad.nn.state import (
|
||||||
|
get_state_dict,
|
||||||
|
safe_save,
|
||||||
|
safe_load_metadata,
|
||||||
|
torch_load,
|
||||||
|
load_state_dict,
|
||||||
|
)
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
|
@ -10,10 +16,13 @@ from pathlib import Path
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def convert_f32_to_f16(input_file, output_file):
|
def convert_f32_to_f16(input_file, output_file):
|
||||||
with open(input_file, 'rb') as f:
|
with open(input_file, "rb") as f:
|
||||||
metadata_length_bytes = f.read(8)
|
metadata_length_bytes = f.read(8)
|
||||||
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
|
metadata_length = int.from_bytes(
|
||||||
|
metadata_length_bytes, byteorder="little", signed=False
|
||||||
|
)
|
||||||
metadata_json_bytes = f.read(metadata_length)
|
metadata_json_bytes = f.read(metadata_length)
|
||||||
float32_values = np.fromfile(f, dtype=np.float32)
|
float32_values = np.fromfile(f, dtype=np.float32)
|
||||||
|
|
||||||
|
@ -22,12 +31,13 @@ def convert_f32_to_f16(input_file, output_file):
|
||||||
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
||||||
rest_float32_values = float32_values[num_elements:]
|
rest_float32_values = float32_values[num_elements:]
|
||||||
|
|
||||||
with open(output_file, 'wb') as f:
|
with open(output_file, "wb") as f:
|
||||||
f.write(metadata_length_bytes)
|
f.write(metadata_length_bytes)
|
||||||
f.write(metadata_json_bytes)
|
f.write(metadata_json_bytes)
|
||||||
front_float16_values.tofile(f)
|
front_float16_values.tofile(f)
|
||||||
rest_float32_values.tofile(f)
|
rest_float32_values.tofile(f)
|
||||||
|
|
||||||
|
|
||||||
def split_safetensor(fn):
|
def split_safetensor(fn):
|
||||||
_, json_len, metadata = safe_load_metadata(fn)
|
_, json_len, metadata = safe_load_metadata(fn)
|
||||||
text_model_offset = 3772703308
|
text_model_offset = 3772703308
|
||||||
|
@ -35,7 +45,7 @@ def split_safetensor(fn):
|
||||||
|
|
||||||
for k in metadata:
|
for k in metadata:
|
||||||
# safetensor is in fp16, except for text moel
|
# safetensor is in fp16, except for text moel
|
||||||
if (metadata[k]["data_offsets"][0] < text_model_offset):
|
if metadata[k]["data_offsets"][0] < text_model_offset:
|
||||||
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2)
|
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2)
|
||||||
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2)
|
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2)
|
||||||
|
|
||||||
|
@ -43,35 +53,43 @@ def split_safetensor(fn):
|
||||||
part_end_offsets = []
|
part_end_offsets = []
|
||||||
|
|
||||||
for k in metadata:
|
for k in metadata:
|
||||||
offset = metadata[k]['data_offsets'][0]
|
offset = metadata[k]["data_offsets"][0]
|
||||||
|
|
||||||
if offset == text_model_offset:
|
if offset == text_model_offset:
|
||||||
break
|
break
|
||||||
|
|
||||||
part_offset = offset - last_offset
|
part_offset = offset - last_offset
|
||||||
|
|
||||||
if (part_offset >= chunk_size):
|
if part_offset >= chunk_size:
|
||||||
part_end_offsets.append(8 + json_len + offset)
|
part_end_offsets.append(8 + json_len + offset)
|
||||||
last_offset = offset
|
last_offset = offset
|
||||||
|
|
||||||
text_model_start = int(text_model_offset / 2)
|
text_model_start = int(text_model_offset / 2)
|
||||||
net_bytes = bytes(open(fn, 'rb').read())
|
net_bytes = bytes(open(fn, "rb").read())
|
||||||
part_end_offsets.append(text_model_start + 8 + json_len)
|
part_end_offsets.append(text_model_start + 8 + json_len)
|
||||||
cur_pos = 0
|
cur_pos = 0
|
||||||
|
|
||||||
for i, end_pos in enumerate(part_end_offsets):
|
for i, end_pos in enumerate(part_end_offsets):
|
||||||
with open(f'./net_part{i}.safetensors', "wb+") as f:
|
with open(f"./net_part{i}.safetensors", "wb+") as f:
|
||||||
f.write(net_bytes[cur_pos:end_pos])
|
f.write(net_bytes[cur_pos:end_pos])
|
||||||
cur_pos = end_pos
|
cur_pos = end_pos
|
||||||
|
|
||||||
with open(f'./net_textmodel.safetensors', "wb+") as f:
|
with open(f"./net_textmodel.safetensors", "wb+") as f:
|
||||||
f.write(net_bytes[text_model_start + 8 + json_len :])
|
f.write(net_bytes[text_model_start + 8 + json_len :])
|
||||||
|
|
||||||
return part_end_offsets
|
return part_end_offsets
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
|
description="Run Stable Diffusion",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--remoteweights",
|
||||||
|
action="store_true",
|
||||||
|
help="Use safetensors from Huggingface, or from local",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
Device.DEFAULT = "WEBGPU"
|
Device.DEFAULT = "WEBGPU"
|
||||||
|
|
||||||
|
@ -79,7 +97,16 @@ if __name__ == "__main__":
|
||||||
model = StableDiffusion()
|
model = StableDiffusion()
|
||||||
|
|
||||||
# load in weights
|
# load in weights
|
||||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
load_state_dict(
|
||||||
|
model,
|
||||||
|
torch_load(
|
||||||
|
fetch(
|
||||||
|
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
|
||||||
|
"sd-v1-4.ckpt",
|
||||||
|
)
|
||||||
|
)["state_dict"],
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
class Step(NamedTuple):
|
class Step(NamedTuple):
|
||||||
name: str = ""
|
name: str = ""
|
||||||
|
@ -87,9 +114,25 @@ if __name__ == "__main__":
|
||||||
forward: Any = None
|
forward: Any = None
|
||||||
|
|
||||||
sub_steps = [
|
sub_steps = [
|
||||||
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
|
Step(
|
||||||
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
|
name="textModel",
|
||||||
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
|
input=[Tensor.randn(1, 77)],
|
||||||
|
forward=model.cond_stage_model.transformer.text_model,
|
||||||
|
),
|
||||||
|
Step(
|
||||||
|
name="diffusor",
|
||||||
|
input=[
|
||||||
|
Tensor.randn(1, 77, 768),
|
||||||
|
Tensor.randn(1, 77, 768),
|
||||||
|
Tensor.randn(1, 4, 64, 64),
|
||||||
|
Tensor.rand(1),
|
||||||
|
Tensor.randn(1),
|
||||||
|
Tensor.randn(1),
|
||||||
|
Tensor.randn(1),
|
||||||
|
],
|
||||||
|
forward=model,
|
||||||
|
),
|
||||||
|
Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode),
|
||||||
]
|
]
|
||||||
|
|
||||||
prg = ""
|
prg = ""
|
||||||
|
@ -99,12 +142,47 @@ if __name__ == "__main__":
|
||||||
functions, statements, bufs, _ = compile_net(run, special_names)
|
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||||
state = get_state_dict(model)
|
state = get_state_dict(model)
|
||||||
weights = {id(x.lazydata.realized): name for name, x in state.items()}
|
weights = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
kernel_code = "\n\n".join(
|
||||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
[
|
||||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||||
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
for key, code in functions.items()
|
||||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
|
]
|
||||||
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
|
)
|
||||||
|
kernel_names = ", ".join([name for (name, _, _, _) in statements])
|
||||||
|
kernel_calls = "\n ".join(
|
||||||
|
[
|
||||||
|
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||||
|
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
bufs = "\n ".join(
|
||||||
|
[
|
||||||
|
f"const {name} = "
|
||||||
|
+ (
|
||||||
|
f"createEmptyBuf(device, {size});"
|
||||||
|
if _key not in weights
|
||||||
|
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))"
|
||||||
|
)
|
||||||
|
+ ";"
|
||||||
|
for name, (size, dtype, _key) in bufs.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
gpu_write_bufs = "\n ".join(
|
||||||
|
[
|
||||||
|
f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||||
|
for i, (_, value) in enumerate(special_names.items())
|
||||||
|
if "output" not in value
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_writer = "\n ".join(
|
||||||
|
[
|
||||||
|
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||||
|
+ f"data{i});"
|
||||||
|
+ f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);"
|
||||||
|
for i, (_, value) in enumerate(special_names.items())
|
||||||
|
if value != "output0"
|
||||||
|
]
|
||||||
|
)
|
||||||
return f"""\n var {step.name} = function() {{
|
return f"""\n var {step.name} = function() {{
|
||||||
|
|
||||||
{kernel_code}
|
{kernel_code}
|
||||||
|
@ -143,7 +221,7 @@ if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for step in sub_steps:
|
for step in sub_steps:
|
||||||
print(f'Executing step={step.name}')
|
print(f"Executing step={step.name}")
|
||||||
prg += compile_step(model, step)
|
prg += compile_step(model, step)
|
||||||
|
|
||||||
if step.name == "diffusor":
|
if step.name == "diffusor":
|
||||||
|
@ -151,7 +229,9 @@ if __name__ == "__main__":
|
||||||
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
||||||
else:
|
else:
|
||||||
state = get_state_dict(model)
|
state = get_state_dict(model)
|
||||||
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
safe_save(
|
||||||
|
state, os.path.join(os.path.dirname(__file__), "net.safetensors")
|
||||||
|
)
|
||||||
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
||||||
split_safetensor("./net_conv.safetensors")
|
split_safetensor("./net_conv.safetensors")
|
||||||
os.remove("net.safetensors")
|
os.remove("net.safetensors")
|
||||||
|
|
|
@ -15,8 +15,15 @@ from tinygrad.tensor import Tensor
|
||||||
import itertools
|
import itertools
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention:
|
class MultiHeadAttention:
|
||||||
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_state,
|
||||||
|
n_head,
|
||||||
|
kv_caching: Literal["cross", "self"] = None,
|
||||||
|
max_self_attn_cache_len=None,
|
||||||
|
):
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.query = nn.Linear(n_state, n_state)
|
self.query = nn.Linear(n_state, n_state)
|
||||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||||
|
@ -26,11 +33,17 @@ class MultiHeadAttention:
|
||||||
self.kv_caching = kv_caching
|
self.kv_caching = kv_caching
|
||||||
self.max_self_attn_cache_len = max_self_attn_cache_len
|
self.max_self_attn_cache_len = max_self_attn_cache_len
|
||||||
|
|
||||||
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
|
def __call__(
|
||||||
if self.kv_caching == 'cross':
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Optional[Tensor] = None,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
len: Union[Variable, int] = None,
|
||||||
|
):
|
||||||
|
if self.kv_caching == "cross":
|
||||||
if xa is not None:
|
if xa is not None:
|
||||||
k, v = self.key(xa), self.value(xa)
|
k, v = self.key(xa), self.value(xa)
|
||||||
if not hasattr(self, 'cache_k'):
|
if not hasattr(self, "cache_k"):
|
||||||
self.cache_k, self.cache_v = k, v
|
self.cache_k, self.cache_v = k, v
|
||||||
else:
|
else:
|
||||||
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
||||||
|
@ -40,50 +53,84 @@ class MultiHeadAttention:
|
||||||
k, v = self.cache_k, self.cache_v
|
k, v = self.cache_k, self.cache_v
|
||||||
else:
|
else:
|
||||||
k, v = self.key(x), self.value(x)
|
k, v = self.key(x), self.value(x)
|
||||||
if self.kv_caching == 'self':
|
if self.kv_caching == "self":
|
||||||
if not hasattr(self, 'cache_k'):
|
if not hasattr(self, "cache_k"):
|
||||||
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
self.cache_k = Tensor.zeros(
|
||||||
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||||
|
)
|
||||||
|
self.cache_v = Tensor.zeros(
|
||||||
|
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||||
|
)
|
||||||
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
||||||
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
||||||
padding = self.max_self_attn_cache_len - len - x.shape[1]
|
padding = self.max_self_attn_cache_len - len - x.shape[1]
|
||||||
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
|
self.cache_k.assign(
|
||||||
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
|
k.pad((None, (0, padding), None)).contiguous()
|
||||||
|
).realize()
|
||||||
|
self.cache_v.assign(
|
||||||
|
v.pad((None, (0, padding), None)).contiguous()
|
||||||
|
).realize()
|
||||||
|
|
||||||
q = self.query(x)
|
q = self.query(x)
|
||||||
n_ctx = q.shape[1]
|
n_ctx = q.shape[1]
|
||||||
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
|
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
||||||
head_dim = q.shape[-1] // self.n_head
|
head_dim = q.shape[-1] // self.n_head
|
||||||
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||||
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||||
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||||
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
|
attn = Tensor.scaled_dot_product_attention(
|
||||||
|
q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
|
||||||
|
)
|
||||||
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
return self.out(wv)
|
return self.out(wv)
|
||||||
|
|
||||||
|
|
||||||
class ResidualAttentionBlock:
|
class ResidualAttentionBlock:
|
||||||
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
|
def __init__(
|
||||||
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
|
self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None
|
||||||
|
):
|
||||||
|
self.attn = MultiHeadAttention(
|
||||||
|
n_state,
|
||||||
|
n_head,
|
||||||
|
kv_caching="self" if is_decoder_block else None,
|
||||||
|
max_self_attn_cache_len=max_self_attn_cache_len,
|
||||||
|
)
|
||||||
self.attn_ln = nn.LayerNorm(n_state)
|
self.attn_ln = nn.LayerNorm(n_state)
|
||||||
|
|
||||||
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
|
self.cross_attn = (
|
||||||
|
MultiHeadAttention(n_state, n_head, kv_caching="cross")
|
||||||
|
if is_decoder_block
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
||||||
|
|
||||||
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
|
self.mlp = [
|
||||||
|
nn.Linear(n_state, n_state * 4),
|
||||||
|
Tensor.gelu,
|
||||||
|
nn.Linear(n_state * 4, n_state),
|
||||||
|
]
|
||||||
self.mlp_ln = nn.LayerNorm(n_state)
|
self.mlp_ln = nn.LayerNorm(n_state)
|
||||||
|
|
||||||
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
|
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
|
||||||
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
||||||
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
if self.cross_attn:
|
||||||
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||||
x = x + self.mlp_ln(x).sequential(self.mlp)
|
x = x + self.mlp_ln(x).sequential(self.mlp)
|
||||||
return x.realize()
|
return x.realize()
|
||||||
|
|
||||||
|
|
||||||
class AudioEncoder:
|
class AudioEncoder:
|
||||||
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
|
def __init__(
|
||||||
|
self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_
|
||||||
|
):
|
||||||
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
||||||
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
|
self.conv2 = nn.Conv1d(
|
||||||
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
|
n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
self.blocks = [
|
||||||
|
ResidualAttentionBlock(n_audio_state, n_audio_head)
|
||||||
|
for _ in range(n_audio_layer)
|
||||||
|
]
|
||||||
self.ln_post = nn.LayerNorm(n_audio_state)
|
self.ln_post = nn.LayerNorm(n_audio_state)
|
||||||
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
||||||
self.encode = TinyJit(self.__call__)
|
self.encode = TinyJit(self.__call__)
|
||||||
|
@ -97,14 +144,27 @@ class AudioEncoder:
|
||||||
x = self.ln_post(x)
|
x = self.ln_post(x)
|
||||||
return x.realize()
|
return x.realize()
|
||||||
|
|
||||||
|
|
||||||
class TextDecoder:
|
class TextDecoder:
|
||||||
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
|
def __init__(
|
||||||
|
self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_
|
||||||
|
):
|
||||||
self.max_tokens_to_sample = n_text_ctx // 2
|
self.max_tokens_to_sample = n_text_ctx // 2
|
||||||
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
|
self.max_self_attn_cache_len = (
|
||||||
|
self.max_tokens_to_sample * 2 + 5
|
||||||
|
) # roughly prompt + start toks + max_tokens_to_sample
|
||||||
|
|
||||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||||
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
|
self.blocks = [
|
||||||
|
ResidualAttentionBlock(
|
||||||
|
n_text_state,
|
||||||
|
n_text_head,
|
||||||
|
is_decoder_block=True,
|
||||||
|
max_self_attn_cache_len=self.max_self_attn_cache_len,
|
||||||
|
)
|
||||||
|
for _ in range(n_text_layer)
|
||||||
|
]
|
||||||
self.ln = nn.LayerNorm(n_text_state)
|
self.ln = nn.LayerNorm(n_text_state)
|
||||||
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
||||||
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||||
|
@ -117,18 +177,23 @@ class TextDecoder:
|
||||||
seqlen = x.shape[-1]
|
seqlen = x.shape[-1]
|
||||||
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
|
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
|
||||||
if pos == 0:
|
if pos == 0:
|
||||||
for block in (self.blocks if streaming else self.blocks_start_tok):
|
for block in self.blocks if streaming else self.blocks_start_tok:
|
||||||
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
|
x = block(
|
||||||
|
x, xa=encoded_audio, mask=self.mask, len=0
|
||||||
|
) # pass xa for cross attn kv caching
|
||||||
return self.output_tok(x) if streaming else self.start_output_tok(x)
|
return self.output_tok(x) if streaming else self.start_output_tok(x)
|
||||||
else:
|
else:
|
||||||
for block in self.blocks_after_start_tok:
|
for block in self.blocks_after_start_tok:
|
||||||
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
|
len_v = Variable(
|
||||||
|
"self_attn_cache_len", 1, self.max_self_attn_cache_len
|
||||||
|
).bind(pos)
|
||||||
x = block(x, mask=self.mask, len=len_v)
|
x = block(x, mask=self.mask, len=len_v)
|
||||||
return self.after_start_output_tok(x)
|
return self.after_start_output_tok(x)
|
||||||
|
|
||||||
def output_tok(self, x):
|
def output_tok(self, x):
|
||||||
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
||||||
|
|
||||||
|
|
||||||
class Whisper:
|
class Whisper:
|
||||||
def __init__(self, dims, batch_size=1):
|
def __init__(self, dims, batch_size=1):
|
||||||
self.encoder = AudioEncoder(**dims)
|
self.encoder = AudioEncoder(**dims)
|
||||||
|
@ -145,7 +210,10 @@ HOP_LENGTH = 160
|
||||||
N_MELS = 80
|
N_MELS = 80
|
||||||
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
|
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
|
||||||
|
|
||||||
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
|
|
||||||
|
def prep_audio(
|
||||||
|
waveforms: List[np.ndarray], batch_size: int, truncate=False
|
||||||
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
:param waveforms: A list of possibly variable length 16000Hz audio samples
|
:param waveforms: A list of possibly variable length 16000Hz audio samples
|
||||||
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
|
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
|
||||||
|
@ -153,24 +221,30 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
|
||||||
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
|
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
|
||||||
:return: mel spectrogram of the given waveforms
|
:return: mel spectrogram of the given waveforms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def pad_or_trim(arr, target_len):
|
def pad_or_trim(arr, target_len):
|
||||||
curr_len = len(arr)
|
curr_len = len(arr)
|
||||||
if curr_len == target_len:
|
if curr_len == target_len:
|
||||||
return arr
|
return arr
|
||||||
elif curr_len < target_len:
|
elif curr_len < target_len:
|
||||||
return np.pad(arr, (0, target_len - curr_len), 'constant')
|
return np.pad(arr, (0, target_len - curr_len), "constant")
|
||||||
else:
|
else:
|
||||||
return arr[:target_len]
|
return arr[:target_len]
|
||||||
|
|
||||||
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
|
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
|
||||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
|
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
|
||||||
|
max_len += SAMPLES_PER_SEGMENT - r
|
||||||
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
||||||
assert waveforms.shape[0] <= batch_size
|
assert waveforms.shape[0] <= batch_size
|
||||||
if waveforms.shape[0] < batch_size:
|
if waveforms.shape[0] < batch_size:
|
||||||
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
|
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
|
||||||
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
|
waveforms = np.pad(
|
||||||
|
waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))
|
||||||
|
)
|
||||||
|
|
||||||
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
|
stft = librosa.stft(
|
||||||
|
waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
|
||||||
|
)
|
||||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||||
|
|
||||||
|
@ -180,22 +254,118 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) ->
|
||||||
|
|
||||||
return log_spec
|
return log_spec
|
||||||
|
|
||||||
|
|
||||||
LANGUAGES = {
|
LANGUAGES = {
|
||||||
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
|
"en": "english",
|
||||||
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
|
"zh": "chinese",
|
||||||
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
|
"de": "german",
|
||||||
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
|
"es": "spanish",
|
||||||
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
|
"ru": "russian",
|
||||||
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
|
"ko": "korean",
|
||||||
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
|
"fr": "french",
|
||||||
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
|
"ja": "japanese",
|
||||||
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
|
"pt": "portuguese",
|
||||||
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
|
"tr": "turkish",
|
||||||
|
"pl": "polish",
|
||||||
|
"ca": "catalan",
|
||||||
|
"nl": "dutch",
|
||||||
|
"ar": "arabic",
|
||||||
|
"sv": "swedish",
|
||||||
|
"it": "italian",
|
||||||
|
"id": "indonesian",
|
||||||
|
"hi": "hindi",
|
||||||
|
"fi": "finnish",
|
||||||
|
"vi": "vietnamese",
|
||||||
|
"he": "hebrew",
|
||||||
|
"uk": "ukrainian",
|
||||||
|
"el": "greek",
|
||||||
|
"ms": "malay",
|
||||||
|
"cs": "czech",
|
||||||
|
"ro": "romanian",
|
||||||
|
"da": "danish",
|
||||||
|
"hu": "hungarian",
|
||||||
|
"ta": "tamil",
|
||||||
|
"no": "norwegian",
|
||||||
|
"th": "thai",
|
||||||
|
"ur": "urdu",
|
||||||
|
"hr": "croatian",
|
||||||
|
"bg": "bulgarian",
|
||||||
|
"lt": "lithuanian",
|
||||||
|
"la": "latin",
|
||||||
|
"mi": "maori",
|
||||||
|
"ml": "malayalam",
|
||||||
|
"cy": "welsh",
|
||||||
|
"sk": "slovak",
|
||||||
|
"te": "telugu",
|
||||||
|
"fa": "persian",
|
||||||
|
"lv": "latvian",
|
||||||
|
"bn": "bengali",
|
||||||
|
"sr": "serbian",
|
||||||
|
"az": "azerbaijani",
|
||||||
|
"sl": "slovenian",
|
||||||
|
"kn": "kannada",
|
||||||
|
"et": "estonian",
|
||||||
|
"mk": "macedonian",
|
||||||
|
"br": "breton",
|
||||||
|
"eu": "basque",
|
||||||
|
"is": "icelandic",
|
||||||
|
"hy": "armenian",
|
||||||
|
"ne": "nepali",
|
||||||
|
"mn": "mongolian",
|
||||||
|
"bs": "bosnian",
|
||||||
|
"kk": "kazakh",
|
||||||
|
"sq": "albanian",
|
||||||
|
"sw": "swahili",
|
||||||
|
"gl": "galician",
|
||||||
|
"mr": "marathi",
|
||||||
|
"pa": "punjabi",
|
||||||
|
"si": "sinhala",
|
||||||
|
"km": "khmer",
|
||||||
|
"sn": "shona",
|
||||||
|
"yo": "yoruba",
|
||||||
|
"so": "somali",
|
||||||
|
"af": "afrikaans",
|
||||||
|
"oc": "occitan",
|
||||||
|
"ka": "georgian",
|
||||||
|
"be": "belarusian",
|
||||||
|
"tg": "tajik",
|
||||||
|
"sd": "sindhi",
|
||||||
|
"gu": "gujarati",
|
||||||
|
"am": "amharic",
|
||||||
|
"yi": "yiddish",
|
||||||
|
"lo": "lao",
|
||||||
|
"uz": "uzbek",
|
||||||
|
"fo": "faroese",
|
||||||
|
"ht": "haitian creole",
|
||||||
|
"ps": "pashto",
|
||||||
|
"tk": "turkmen",
|
||||||
|
"nn": "nynorsk",
|
||||||
|
"mt": "maltese",
|
||||||
|
"sa": "sanskrit",
|
||||||
|
"lb": "luxembourgish",
|
||||||
|
"my": "myanmar",
|
||||||
|
"bo": "tibetan",
|
||||||
|
"tl": "tagalog",
|
||||||
|
"mg": "malagasy",
|
||||||
|
"as": "assamese",
|
||||||
|
"tt": "tatar",
|
||||||
|
"haw": "hawaiian",
|
||||||
|
"ln": "lingala",
|
||||||
|
"ha": "hausa",
|
||||||
|
"ba": "bashkir",
|
||||||
|
"jw": "javanese",
|
||||||
|
"su": "sundanese",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_encoding(encoding_name):
|
def get_encoding(encoding_name):
|
||||||
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
|
with fetch(
|
||||||
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
|
f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
|
||||||
|
).open() as f:
|
||||||
|
ranks = {
|
||||||
|
base64.b64decode(token): int(rank)
|
||||||
|
for token, rank in (line.split() for line in f if line)
|
||||||
|
}
|
||||||
n_vocab = len(ranks)
|
n_vocab = len(ranks)
|
||||||
specials = [
|
specials = [
|
||||||
"<|endoftext|>",
|
"<|endoftext|>",
|
||||||
|
@ -212,12 +382,15 @@ def get_encoding(encoding_name):
|
||||||
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
||||||
n_vocab += len(specials)
|
n_vocab += len(specials)
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
return tiktoken.Encoding(
|
return tiktoken.Encoding(
|
||||||
name=encoding_name,
|
name=encoding_name,
|
||||||
explicit_n_vocab=n_vocab,
|
explicit_n_vocab=n_vocab,
|
||||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||||
mergeable_ranks=ranks,
|
mergeable_ranks=ranks,
|
||||||
special_tokens=special_tokens)
|
special_tokens=special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
MODEL_URLS = {
|
MODEL_URLS = {
|
||||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||||
|
@ -232,23 +405,28 @@ MODEL_URLS = {
|
||||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def init_whisper(model_name="tiny.en", batch_size=1):
|
def init_whisper(model_name="tiny.en", batch_size=1):
|
||||||
assert MODEL_URLS[model_name] is not None
|
assert MODEL_URLS[model_name] is not None
|
||||||
|
|
||||||
filename = fetch(MODEL_URLS[model_name])
|
filename = fetch(MODEL_URLS[model_name])
|
||||||
state = torch_load(filename)
|
state = torch_load(filename)
|
||||||
model = Whisper(state['dims'], batch_size)
|
model = Whisper(state["dims"], batch_size)
|
||||||
load_state_dict(model, state['model_state_dict'], strict=False)
|
load_state_dict(model, state["model_state_dict"], strict=False)
|
||||||
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
||||||
return model, enc
|
return model, enc
|
||||||
|
|
||||||
|
|
||||||
def load_file_waveform(filename):
|
def load_file_waveform(filename):
|
||||||
waveform, _ = librosa.load(filename, sr=RATE)
|
waveform, _ = librosa.load(filename, sr=RATE)
|
||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
|
|
||||||
def transcribe_file(model, enc, filename):
|
def transcribe_file(model, enc, filename):
|
||||||
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
|
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
|
||||||
|
|
||||||
|
|
||||||
def transcribe_waveform(model, enc, waveforms, truncate=False):
|
def transcribe_waveform(model, enc, waveforms, truncate=False):
|
||||||
"""
|
"""
|
||||||
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
|
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
|
||||||
|
@ -260,12 +438,18 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
|
||||||
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
|
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
|
||||||
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
|
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
|
||||||
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
|
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
|
||||||
raise Exception("Multi-segment transcription not supported with batch audio input")
|
raise Exception(
|
||||||
|
"Multi-segment transcription not supported with batch audio input"
|
||||||
|
)
|
||||||
|
|
||||||
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
||||||
if model.is_multilingual:
|
if model.is_multilingual:
|
||||||
# TODO detect language
|
# TODO detect language
|
||||||
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
|
language_token = (
|
||||||
|
enc._special_tokens["<|startoftranscript|>"]
|
||||||
|
+ 1
|
||||||
|
+ tuple(LANGUAGES.keys()).index("en")
|
||||||
|
)
|
||||||
start_tokens.append(language_token)
|
start_tokens.append(language_token)
|
||||||
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
||||||
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
||||||
|
@ -274,52 +458,83 @@ def transcribe_waveform(model, enc, waveforms, truncate=False):
|
||||||
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
|
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
|
||||||
|
|
||||||
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
||||||
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
|
encoded_audio = model.encoder.encode(
|
||||||
|
Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
|
||||||
|
)
|
||||||
pos = 0
|
pos = 0
|
||||||
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
|
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
|
||||||
if curr_frame > 0:
|
if curr_frame > 0:
|
||||||
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||||
prompt = np.concatenate((
|
prompt = np.concatenate(
|
||||||
|
(
|
||||||
[enc._special_tokens["<|startofprev|>"]],
|
[enc._special_tokens["<|startofprev|>"]],
|
||||||
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
|
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
|
||||||
start_tokens))
|
start_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
|
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
|
||||||
transcription_start_index = len(curr_segment_tokens[0])
|
transcription_start_index = len(curr_segment_tokens[0])
|
||||||
|
|
||||||
for i in range(model.decoder.max_tokens_to_sample):
|
for i in range(model.decoder.max_tokens_to_sample):
|
||||||
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
|
out = model.decoder(
|
||||||
|
Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
|
||||||
|
pos,
|
||||||
|
encoded_audio,
|
||||||
|
streaming=curr_frame > 0,
|
||||||
|
)
|
||||||
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
|
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
|
||||||
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
|
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
|
||||||
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
|
curr_segment_tokens = np.concatenate(
|
||||||
|
(curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
|
||||||
|
)
|
||||||
pos = curr_segment_tokens.shape[-1] - 1
|
pos = curr_segment_tokens.shape[-1] - 1
|
||||||
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
|
if DEBUG >= 1:
|
||||||
|
print(
|
||||||
|
i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))
|
||||||
|
)
|
||||||
if (curr_segment_tokens[:, -1] == eot).all():
|
if (curr_segment_tokens[:, -1] == eot).all():
|
||||||
break
|
break
|
||||||
|
|
||||||
for i, t in enumerate(curr_segment_tokens):
|
for i, t in enumerate(curr_segment_tokens):
|
||||||
eot_index = np.where(t == eot)[0]
|
eot_index = np.where(t == eot)[0]
|
||||||
eot_index = None if len(eot_index) == 0 else eot_index[0]
|
eot_index = None if len(eot_index) == 0 else eot_index[0]
|
||||||
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
|
transcription_tokens[i] = np.concatenate(
|
||||||
|
(transcription_tokens[i], t[transcription_start_index:eot_index])
|
||||||
|
)
|
||||||
|
|
||||||
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
|
transcriptions = list(
|
||||||
|
map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)
|
||||||
|
)
|
||||||
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
|
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
|
||||||
|
|
||||||
|
|
||||||
CHUNK = 1600
|
CHUNK = 1600
|
||||||
RECORD_SECONDS = 10
|
RECORD_SECONDS = 10
|
||||||
|
|
||||||
|
|
||||||
def listener(q):
|
def listener(q):
|
||||||
import pyaudio
|
import pyaudio
|
||||||
|
|
||||||
p = pyaudio.PyAudio()
|
p = pyaudio.PyAudio()
|
||||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
stream = p.open(
|
||||||
|
format=pyaudio.paInt16,
|
||||||
|
channels=1,
|
||||||
|
rate=RATE,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=CHUNK,
|
||||||
|
)
|
||||||
print("listening")
|
print("listening")
|
||||||
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||||
data = stream.read(CHUNK)
|
data = stream.read(CHUNK)
|
||||||
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
|
waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
|
||||||
q.put(waveform)
|
q.put(waveform)
|
||||||
print("done listening")
|
print("done listening")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
|
model, enc = init_whisper(
|
||||||
|
"small.en" if getenv("SMALL") else "tiny.en", batch_size=1
|
||||||
|
)
|
||||||
|
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
print(transcribe_file(model, enc, sys.argv[1]))
|
print(transcribe_file(model, enc, sys.argv[1]))
|
||||||
|
@ -330,20 +545,29 @@ if __name__ == "__main__":
|
||||||
p.daemon = True
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
lst = [
|
||||||
|
enc._special_tokens["<|startoftranscript|>"],
|
||||||
|
enc._special_tokens["<|notimestamps|>"],
|
||||||
|
]
|
||||||
total = None
|
total = None
|
||||||
did_read = False
|
did_read = False
|
||||||
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||||
while not q.empty() or total is None:
|
while not q.empty() or total is None:
|
||||||
waveform = q.get()
|
waveform = q.get()
|
||||||
if total is None: total = waveform
|
if total is None:
|
||||||
else: total = np.concatenate([total, waveform])
|
total = waveform
|
||||||
|
else:
|
||||||
|
total = np.concatenate([total, waveform])
|
||||||
did_read = True
|
did_read = True
|
||||||
if did_read:
|
if did_read:
|
||||||
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
|
log_spec = prep_audio(
|
||||||
|
total.reshape(1, -1), model.batch_size, truncate=True
|
||||||
|
)
|
||||||
encoded_audio = model.encoder.encode(Tensor(log_spec))
|
encoded_audio = model.encoder.encode(Tensor(log_spec))
|
||||||
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||||
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
|
out = model.decoder(
|
||||||
|
Tensor([lst]), 0, encoded_audio, streaming=True
|
||||||
|
).realize()
|
||||||
idx = int(out[0, -1].argmax().numpy().item())
|
idx = int(out[0, -1].argmax().numpy().item())
|
||||||
lst.append(idx)
|
lst.append(idx)
|
||||||
dec = enc.decode(lst)
|
dec = enc.decode(lst)
|
||||||
|
|
|
@ -10,11 +10,14 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn import BatchNorm2d, Conv2d
|
from tinygrad.nn import BatchNorm2d, Conv2d
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
|
|
||||||
|
|
||||||
def show_labels(prediction, confidence=0.5, num_classes=80):
|
def show_labels(prediction, confidence=0.5, num_classes=80):
|
||||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes()
|
coco_labels = fetch(
|
||||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||||
|
).read_bytes()
|
||||||
|
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||||
prediction = prediction.detach().numpy()
|
prediction = prediction.detach().numpy()
|
||||||
conf_mask = (prediction[:,:,4] > confidence)
|
conf_mask = prediction[:, :, 4] > confidence
|
||||||
prediction *= np.expand_dims(conf_mask, 2)
|
prediction *= np.expand_dims(conf_mask, 2)
|
||||||
labels = []
|
labels = []
|
||||||
# Iterate over batches
|
# Iterate over batches
|
||||||
|
@ -30,16 +33,22 @@ def show_labels(prediction, confidence=0.5, num_classes=80):
|
||||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
|
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
|
||||||
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
|
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
|
||||||
for index, coco_class in enumerate(classes):
|
for index, coco_class in enumerate(classes):
|
||||||
label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100
|
label, probability = (
|
||||||
|
coco_labels[int(coco_class)],
|
||||||
|
image_pred_[indexes[index]][4] * 100,
|
||||||
|
)
|
||||||
print(f"Detected {label} {probability:.2f}")
|
print(f"Detected {label} {probability:.2f}")
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
def add_boxes(img, prediction):
|
def add_boxes(img, prediction):
|
||||||
if isinstance(prediction, int): # no predictions
|
if isinstance(prediction, int): # no predictions
|
||||||
return img
|
return img
|
||||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
coco_labels = fetch(
|
||||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||||
|
)
|
||||||
|
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||||
height, width = img.shape[0:2]
|
height, width = img.shape[0:2]
|
||||||
scale_factor = 608 / width
|
scale_factor = 608 / width
|
||||||
prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2
|
prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2
|
||||||
|
@ -55,9 +64,18 @@ def add_boxes(img, prediction):
|
||||||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
|
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
|
||||||
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
|
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
|
||||||
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
|
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
|
||||||
img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
|
img = cv2.putText(
|
||||||
|
img,
|
||||||
|
label,
|
||||||
|
(corner1[0], corner1[1] + t_size[1] + 4),
|
||||||
|
cv2.FONT_HERSHEY_PLAIN,
|
||||||
|
1,
|
||||||
|
[225, 255, 255],
|
||||||
|
1,
|
||||||
|
)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
def bbox_iou(box1, box2):
|
def bbox_iou(box1, box2):
|
||||||
"""
|
"""
|
||||||
Returns the IoU of two bounding boxes
|
Returns the IoU of two bounding boxes
|
||||||
|
@ -74,24 +92,27 @@ def bbox_iou(box1, box2):
|
||||||
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
||||||
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
||||||
# Intersection area
|
# Intersection area
|
||||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999)
|
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(
|
||||||
|
inter_rect_y2 - inter_rect_y1 + 1, 0, 99999
|
||||||
|
)
|
||||||
# Union Area
|
# Union Area
|
||||||
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
|
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
|
||||||
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
|
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
|
||||||
iou = inter_area / (b1_area + b2_area - inter_area)
|
iou = inter_area / (b1_area + b2_area - inter_area)
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
|
|
||||||
def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
||||||
prediction = prediction.detach().numpy()
|
prediction = prediction.detach().numpy()
|
||||||
conf_mask = (prediction[:,:,4] > confidence)
|
conf_mask = prediction[:, :, 4] > confidence
|
||||||
conf_mask = np.expand_dims(conf_mask, 2)
|
conf_mask = np.expand_dims(conf_mask, 2)
|
||||||
prediction = prediction * conf_mask
|
prediction = prediction * conf_mask
|
||||||
# Non max suppression
|
# Non max suppression
|
||||||
box_corner = prediction
|
box_corner = prediction
|
||||||
box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
|
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
|
||||||
box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
|
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
|
||||||
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
|
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
|
||||||
box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
|
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
|
||||||
prediction[:, :, :4] = box_corner[:, :, :4]
|
prediction[:, :, :4] = box_corner[:, :, :4]
|
||||||
write = False
|
write = False
|
||||||
# Process img
|
# Process img
|
||||||
|
@ -121,7 +142,10 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
||||||
for i in range(image_pred_class.shape[0]):
|
for i in range(image_pred_class.shape[0]):
|
||||||
# Get the IOUs of all boxes that come after the one we are looking at in the loop
|
# Get the IOUs of all boxes that come after the one we are looking at in the loop
|
||||||
try:
|
try:
|
||||||
ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:])
|
ious = bbox_iou(
|
||||||
|
np.expand_dims(image_pred_class[i], axis=0),
|
||||||
|
image_pred_class[i + 1 :],
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
break
|
break
|
||||||
# Zero out all the detections that have IoU > threshold
|
# Zero out all the detections that have IoU > threshold
|
||||||
|
@ -139,6 +163,7 @@ def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
||||||
output = np.concatenate((output, out))
|
output = np.concatenate((output, out))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def infer(model, img):
|
def infer(model, img):
|
||||||
img = np.array(Image.fromarray(img).resize((608, 608)))
|
img = np.array(Image.fromarray(img).resize((608, 608)))
|
||||||
img = img[:, :, ::-1].transpose((2, 0, 1))
|
img = img[:, :, ::-1].transpose((2, 0, 1))
|
||||||
|
@ -149,9 +174,9 @@ def infer(model, img):
|
||||||
|
|
||||||
def parse_cfg(cfg):
|
def parse_cfg(cfg):
|
||||||
# Return a list of blocks
|
# Return a list of blocks
|
||||||
lines = cfg.decode("utf-8").split('\n')
|
lines = cfg.decode("utf-8").split("\n")
|
||||||
lines = [x for x in lines if len(x) > 0]
|
lines = [x for x in lines if len(x) > 0]
|
||||||
lines = [x for x in lines if x[0] != '#']
|
lines = [x for x in lines if x[0] != "#"]
|
||||||
lines = [x.rstrip().lstrip() for x in lines]
|
lines = [x.rstrip().lstrip() for x in lines]
|
||||||
block, blocks = {}, []
|
block, blocks = {}, []
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
@ -166,6 +191,7 @@ def parse_cfg(cfg):
|
||||||
blocks.append(block)
|
blocks.append(block)
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
# TODO: Speed up this function, avoid copying stuff from GPU to CPU
|
# TODO: Speed up this function, avoid copying stuff from GPU to CPU
|
||||||
def predict_transform(prediction, inp_dim, anchors, num_classes):
|
def predict_transform(prediction, inp_dim, anchors, num_classes):
|
||||||
batch_size = prediction.shape[0]
|
batch_size = prediction.shape[0]
|
||||||
|
@ -173,9 +199,13 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
|
||||||
grid_size = inp_dim // stride
|
grid_size = inp_dim // stride
|
||||||
bbox_attrs = 5 + num_classes
|
bbox_attrs = 5 + num_classes
|
||||||
num_anchors = len(anchors)
|
num_anchors = len(anchors)
|
||||||
prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size))
|
prediction = prediction.reshape(
|
||||||
|
shape=(batch_size, bbox_attrs * num_anchors, grid_size * grid_size)
|
||||||
|
)
|
||||||
prediction = prediction.transpose(1, 2)
|
prediction = prediction.transpose(1, 2)
|
||||||
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
|
prediction = prediction.reshape(
|
||||||
|
shape=(batch_size, grid_size * grid_size * num_anchors, bbox_attrs)
|
||||||
|
)
|
||||||
prediction_cpu = prediction.numpy()
|
prediction_cpu = prediction.numpy()
|
||||||
for i in (0, 1, 4):
|
for i in (0, 1, 4):
|
||||||
prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i]))
|
prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i]))
|
||||||
|
@ -193,7 +223,9 @@ def predict_transform(prediction, inp_dim, anchors, num_classes):
|
||||||
anchors = np.expand_dims(anchors, 0)
|
anchors = np.expand_dims(anchors, 0)
|
||||||
prediction_cpu[:, :, :2] += x_y_offset
|
prediction_cpu[:, :, :2] += x_y_offset
|
||||||
prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors
|
prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors
|
||||||
prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes]))
|
prediction_cpu[:, :, 5 : 5 + num_classes] = 1 / (
|
||||||
|
1 + np.exp(-prediction_cpu[:, :, 5 : 5 + num_classes])
|
||||||
|
)
|
||||||
prediction_cpu[:, :, :4] *= stride
|
prediction_cpu[:, :, :4] *= stride
|
||||||
return Tensor(prediction_cpu)
|
return Tensor(prediction_cpu)
|
||||||
|
|
||||||
|
@ -222,18 +254,33 @@ class Darknet:
|
||||||
filters = int(x["filters"])
|
filters = int(x["filters"])
|
||||||
padding = int(x["pad"])
|
padding = int(x["pad"])
|
||||||
pad = (int(x["size"]) - 1) // 2 if padding else 0
|
pad = (int(x["size"]) - 1) // 2 if padding else 0
|
||||||
module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias))
|
module.append(
|
||||||
|
Conv2d(
|
||||||
|
prev_filters,
|
||||||
|
filters,
|
||||||
|
int(x["size"]),
|
||||||
|
int(x["stride"]),
|
||||||
|
pad,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
# BatchNorm2d
|
# BatchNorm2d
|
||||||
if batch_normalize:
|
if batch_normalize:
|
||||||
module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
|
module.append(
|
||||||
|
BatchNorm2d(filters, eps=1e-05, track_running_stats=True)
|
||||||
|
)
|
||||||
# LeakyReLU activation
|
# LeakyReLU activation
|
||||||
if activation == "leaky":
|
if activation == "leaky":
|
||||||
module.append(lambda x: x.leakyrelu(0.1))
|
module.append(lambda x: x.leakyrelu(0.1))
|
||||||
elif module_type == "maxpool":
|
elif module_type == "maxpool":
|
||||||
size, stride = int(x["size"]), int(x["stride"])
|
size, stride = int(x["size"]), int(x["stride"])
|
||||||
module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))
|
module.append(
|
||||||
|
lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)
|
||||||
|
)
|
||||||
elif module_type == "upsample":
|
elif module_type == "upsample":
|
||||||
module.append(lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1)))
|
module.append(
|
||||||
|
lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1))
|
||||||
|
)
|
||||||
elif module_type == "route":
|
elif module_type == "route":
|
||||||
x["layers"] = x["layers"].split(",")
|
x["layers"] = x["layers"].split(",")
|
||||||
# Start of route
|
# Start of route
|
||||||
|
@ -243,11 +290,15 @@ class Darknet:
|
||||||
end = int(x["layers"][1])
|
end = int(x["layers"][1])
|
||||||
except:
|
except:
|
||||||
end = 0
|
end = 0
|
||||||
if start > 0: start -= index
|
if start > 0:
|
||||||
if end > 0: end -= index
|
start -= index
|
||||||
|
if end > 0:
|
||||||
|
end -= index
|
||||||
module.append(lambda x: x)
|
module.append(lambda x: x)
|
||||||
if end < 0:
|
if end < 0:
|
||||||
filters = output_filters[index + start] + output_filters[index + end]
|
filters = (
|
||||||
|
output_filters[index + start] + output_filters[index + end]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filters = output_filters[index + start]
|
filters = output_filters[index + start]
|
||||||
# Shortcut corresponds to skip connection
|
# Shortcut corresponds to skip connection
|
||||||
|
@ -256,7 +307,9 @@ class Darknet:
|
||||||
elif module_type == "yolo":
|
elif module_type == "yolo":
|
||||||
mask = list(map(int, x["mask"].split(",")))
|
mask = list(map(int, x["mask"].split(",")))
|
||||||
anchors = [int(a) for a in x["anchors"].split(",")]
|
anchors = [int(a) for a in x["anchors"].split(",")]
|
||||||
anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
|
anchors = [
|
||||||
|
(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)
|
||||||
|
]
|
||||||
module.append([anchors[i] for i in mask])
|
module.append([anchors[i] for i in mask])
|
||||||
# Append to module_list
|
# Append to module_list
|
||||||
module_list.append(module)
|
module_list.append(module)
|
||||||
|
@ -308,8 +361,12 @@ class Darknet:
|
||||||
# Cast the loaded weights into dims of model weights
|
# Cast the loaded weights into dims of model weights
|
||||||
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
||||||
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
|
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
|
||||||
bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape))
|
bn_running_mean = bn_running_mean.reshape(
|
||||||
bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape))
|
shape=tuple(bn.running_mean.shape)
|
||||||
|
)
|
||||||
|
bn_running_var = bn_running_var.reshape(
|
||||||
|
shape=tuple(bn.running_var.shape)
|
||||||
|
)
|
||||||
# Copy data
|
# Copy data
|
||||||
bn.bias = bn_biases
|
bn.bias = bn_biases
|
||||||
bn.weight = bn_weights
|
bn.weight = bn_weights
|
||||||
|
@ -337,7 +394,7 @@ class Darknet:
|
||||||
outputs = {} # Cached outputs for route layer
|
outputs = {} # Cached outputs for route layer
|
||||||
detections, write = None, False
|
detections, write = None, False
|
||||||
for i, module in enumerate(modules):
|
for i, module in enumerate(modules):
|
||||||
module_type = (module["type"])
|
module_type = module["type"]
|
||||||
if module_type == "convolutional" or module_type == "upsample":
|
if module_type == "convolutional" or module_type == "upsample":
|
||||||
for layer in self.module_list[i]:
|
for layer in self.module_list[i]:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
@ -349,7 +406,8 @@ class Darknet:
|
||||||
if len(layers) == 1:
|
if len(layers) == 1:
|
||||||
x = outputs[i + (layers[0])]
|
x = outputs[i + (layers[0])]
|
||||||
else:
|
else:
|
||||||
if (layers[1]) > 0: layers[1] = layers[1] - i
|
if (layers[1]) > 0:
|
||||||
|
layers[1] = layers[1] - i
|
||||||
map1 = outputs[i + layers[0]]
|
map1 = outputs[i + layers[0]]
|
||||||
map2 = outputs[i + layers[1]]
|
map2 = outputs[i + layers[1]]
|
||||||
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
|
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
|
||||||
|
@ -364,19 +422,26 @@ class Darknet:
|
||||||
if not write:
|
if not write:
|
||||||
detections, write = x, True
|
detections, write = x, True
|
||||||
else:
|
else:
|
||||||
detections = Tensor(np.concatenate((detections.numpy(), x.numpy()), axis=1))
|
detections = Tensor(
|
||||||
|
np.concatenate((detections.numpy(), x.numpy()), axis=1)
|
||||||
|
)
|
||||||
outputs[i] = x
|
outputs[i] = x
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg'))
|
model = Darknet(
|
||||||
|
fetch(
|
||||||
|
"https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg"
|
||||||
|
)
|
||||||
|
)
|
||||||
print("Loading weights file (237MB). This might take a while…")
|
print("Loading weights file (237MB). This might take a while…")
|
||||||
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
|
model.load_weights("https://pjreddie.com/media/files/yolov3.weights")
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
url = sys.argv[1]
|
url = sys.argv[1]
|
||||||
else:
|
else:
|
||||||
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||||
if url == 'webcam':
|
if url == "webcam":
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||||
while 1:
|
while 1:
|
||||||
|
@ -386,21 +451,21 @@ if __name__ == "__main__":
|
||||||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||||
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
|
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
|
||||||
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
|
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
|
||||||
cv2.imshow('yolo', boxes)
|
cv2.imshow("yolo", boxes)
|
||||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
break
|
break
|
||||||
cap.release()
|
cap.release()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
elif url.startswith('http'):
|
elif url.startswith("http"):
|
||||||
img_stream = io.BytesIO(fetch(url))
|
img_stream = io.BytesIO(fetch(url))
|
||||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||||
else:
|
else:
|
||||||
img = cv2.imread(url)
|
img = cv2.imread(url)
|
||||||
st = time.time()
|
st = time.time()
|
||||||
print('running inference…')
|
print("running inference…")
|
||||||
prediction = infer(model, img)
|
prediction = infer(model, img)
|
||||||
print(f'did inference in {(time.time() - st):2f}s')
|
print(f"did inference in {(time.time() - st):2f}s")
|
||||||
show_labels(prediction)
|
show_labels(prediction)
|
||||||
prediction = process_results(prediction)
|
prediction = process_results(prediction)
|
||||||
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
|
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
|
||||||
cv2.imwrite('boxes.jpg', boxes)
|
cv2.imwrite("boxes.jpg", boxes)
|
||||||
|
|
|
@ -12,7 +12,10 @@ if not Path("yolov8n-seg.onnx").is_file():
|
||||||
model.export(format="onnx", imgsz=[480, 640])
|
model.export(format="onnx", imgsz=[480, 640])
|
||||||
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
|
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
|
||||||
# TODO: move get example inputs to onnx
|
# TODO: move get example inputs to onnx
|
||||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
input_shapes = {
|
||||||
|
inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
|
||||||
|
for inp in onnx_model.graph.input
|
||||||
|
}
|
||||||
print(input_shapes)
|
print(input_shapes)
|
||||||
run_onnx = get_run_onnx(onnx_model)
|
run_onnx = get_run_onnx(onnx_model)
|
||||||
run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)
|
run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)
|
||||||
|
|
|
@ -12,8 +12,11 @@ from tinygrad.nn.state import safe_load, load_state_dict
|
||||||
# Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
# Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||||
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
|
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
|
||||||
|
|
||||||
|
|
||||||
# Pre processing image functions.
|
# Pre processing image functions.
|
||||||
def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
|
def compute_transform(
|
||||||
|
image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
|
||||||
|
):
|
||||||
shape = image.shape[:2] # current shape [height, width]
|
shape = image.shape[:2] # current shape [height, width]
|
||||||
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
|
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
|
||||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||||
|
@ -24,25 +27,39 @@ def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False,
|
||||||
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
|
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
|
||||||
dw /= 2
|
dw /= 2
|
||||||
dh /= 2
|
dh /= 2
|
||||||
image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image
|
image = (
|
||||||
|
cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||||
|
if shape[::-1] != new_unpad
|
||||||
|
else image
|
||||||
|
)
|
||||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||||
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
image = cv2.copyMakeBorder(
|
||||||
|
image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
|
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
|
||||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||||
auto = same_shapes and model_pt
|
auto = same_shapes and model_pt
|
||||||
im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im])
|
im = Tensor(
|
||||||
|
[
|
||||||
|
compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride)
|
||||||
|
for x in im
|
||||||
|
]
|
||||||
|
)
|
||||||
im = Tensor.stack(im) if im.shape[0] > 1 else im
|
im = Tensor.stack(im) if im.shape[0] > 1 else im
|
||||||
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||||
return im
|
return im
|
||||||
|
|
||||||
|
|
||||||
# Post Processing functions
|
# Post Processing functions
|
||||||
def box_area(box):
|
def box_area(box):
|
||||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||||
|
|
||||||
|
|
||||||
def box_iou(box1, box2):
|
def box_iou(box1, box2):
|
||||||
lt = np.maximum(box1[:, None, :2], box2[:, :2])
|
lt = np.maximum(box1[:, None, :2], box2[:, :2])
|
||||||
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
|
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
|
||||||
|
@ -53,6 +70,7 @@ def box_iou(box1, box2):
|
||||||
iou = inter / (area1 + area2 - inter)
|
iou = inter / (area1 + area2 - inter)
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
|
|
||||||
def compute_nms(boxes, scores, iou_threshold):
|
def compute_nms(boxes, scores, iou_threshold):
|
||||||
order, keep = scores.argsort()[::-1], []
|
order, keep = scores.argsort()[::-1], []
|
||||||
while order.size > 0:
|
while order.size > 0:
|
||||||
|
@ -65,7 +83,16 @@ def compute_nms(boxes, scores, iou_threshold):
|
||||||
order = order[inds + 1]
|
order = order[inds + 1]
|
||||||
return np.array(keep)
|
return np.array(keep)
|
||||||
|
|
||||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
|
|
||||||
|
def non_max_suppression(
|
||||||
|
prediction,
|
||||||
|
conf_thres=0.25,
|
||||||
|
iou_thres=0.45,
|
||||||
|
agnostic=False,
|
||||||
|
max_det=300,
|
||||||
|
nc=0,
|
||||||
|
max_wh=7680,
|
||||||
|
):
|
||||||
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
|
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
|
||||||
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
|
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
|
||||||
xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres
|
xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres
|
||||||
|
@ -74,12 +101,16 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
|
||||||
|
|
||||||
for xi, x in enumerate(prediction):
|
for xi, x in enumerate(prediction):
|
||||||
x = x.swapaxes(0, -1)[xc[xi]]
|
x = x.swapaxes(0, -1)[xc[xi]]
|
||||||
if not x.shape[0]: continue
|
if not x.shape[0]:
|
||||||
|
continue
|
||||||
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
|
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
|
||||||
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
|
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(
|
||||||
|
cls, axis=1, keepdims=True
|
||||||
|
)
|
||||||
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
|
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
|
||||||
x = x[conf.ravel() > conf_thres]
|
x = x[conf.ravel() > conf_thres]
|
||||||
if not x.shape[0]: continue
|
if not x.shape[0]:
|
||||||
|
continue
|
||||||
x = x[np.argsort(-x[:, 4])]
|
x = x[np.argsort(-x[:, 4])]
|
||||||
c = x[:, 5:6] * (0 if agnostic else max_wh)
|
c = x[:, 5:6] * (0 if agnostic else max_wh)
|
||||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||||
|
@ -87,12 +118,15 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=Fa
|
||||||
output[xi] = x[i]
|
output[xi] = x[i]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def postprocess(preds, img, orig_imgs):
|
def postprocess(preds, img, orig_imgs):
|
||||||
print('copying to CPU now for post processing')
|
print("copying to CPU now for post processing")
|
||||||
# if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
|
# if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
|
||||||
# TODO: make non_max_suppression in tinygrad - to make this faster
|
# TODO: make non_max_suppression in tinygrad - to make this faster
|
||||||
preds = preds.numpy() if isinstance(preds, Tensor) else preds
|
preds = preds.numpy() if isinstance(preds, Tensor) else preds
|
||||||
preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
|
preds = non_max_suppression(
|
||||||
|
prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300
|
||||||
|
)
|
||||||
all_preds = []
|
all_preds = []
|
||||||
for i, pred in enumerate(preds):
|
for i, pred in enumerate(preds):
|
||||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||||
|
@ -101,8 +135,16 @@ def postprocess(preds, img, orig_imgs):
|
||||||
all_preds.append(pred)
|
all_preds.append(pred)
|
||||||
return all_preds
|
return all_preds
|
||||||
|
|
||||||
def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
|
|
||||||
color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
|
def draw_bounding_boxes_and_save(
|
||||||
|
orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5
|
||||||
|
):
|
||||||
|
color_dict = {
|
||||||
|
label: tuple(
|
||||||
|
(((i + 1) * 50) % 256, ((i + 1) * 100) % 256, ((i + 1) * 150) % 256)
|
||||||
|
)
|
||||||
|
for i, label in enumerate(class_labels)
|
||||||
|
}
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
|
||||||
def is_bright_color(color):
|
def is_bright_color(color):
|
||||||
|
@ -110,9 +152,15 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
||||||
brightness = (r * 299 + g * 587 + b * 114) / 1000
|
brightness = (r * 299 + g * 587 + b * 114) / 1000
|
||||||
return brightness > 127
|
return brightness > 127
|
||||||
|
|
||||||
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
|
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(
|
||||||
|
zip(orig_img_paths, output_img_paths, all_predictions)
|
||||||
|
):
|
||||||
predictions = np.array(predictions)
|
predictions = np.array(predictions)
|
||||||
orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
|
orig_img = (
|
||||||
|
cv2.imread(orig_img_path)
|
||||||
|
if not isinstance(orig_img_path, np.ndarray)
|
||||||
|
else cv2.imdecode(orig_img_path, 1)
|
||||||
|
)
|
||||||
height, width, _ = orig_img.shape
|
height, width, _ = orig_img.shape
|
||||||
box_thickness = int((height + width) / 400)
|
box_thickness = int((height + width) / 400)
|
||||||
font_scale = (height + width) / 2500
|
font_scale = (height + width) / 2500
|
||||||
|
@ -129,10 +177,29 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
||||||
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
|
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
|
||||||
label = f"{class_labels[class_id]} {conf:.2f}"
|
label = f"{class_labels[class_id]} {conf:.2f}"
|
||||||
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
|
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
|
||||||
label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
|
label_y, bg_y = (
|
||||||
cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
|
(y1 - 4, y1 - text_size[1] - 4)
|
||||||
|
if y1 - text_size[1] - 4 > 0
|
||||||
|
else (y1 + text_size[1], y1)
|
||||||
|
)
|
||||||
|
cv2.rectangle(
|
||||||
|
orig_img,
|
||||||
|
(x1, bg_y),
|
||||||
|
(x1 + text_size[0], bg_y + text_size[1]),
|
||||||
|
color,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
|
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
|
||||||
cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
|
cv2.putText(
|
||||||
|
orig_img,
|
||||||
|
label,
|
||||||
|
(x1, label_y),
|
||||||
|
font,
|
||||||
|
font_scale,
|
||||||
|
font_color,
|
||||||
|
1,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
|
||||||
for class_id, pred_list in grouped_preds.items():
|
for class_id, pred_list in grouped_preds.items():
|
||||||
pred_list = np.array(pred_list)
|
pred_list = np.array(pred_list)
|
||||||
|
@ -155,7 +222,8 @@ def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictio
|
||||||
print(f"- {obj}: {count}")
|
print(f"- {obj}: {count}")
|
||||||
|
|
||||||
cv2.imwrite(output_img_path, orig_img)
|
cv2.imwrite(output_img_path, orig_img)
|
||||||
print(f'saved detections at {output_img_path}')
|
print(f"saved detections at {output_img_path}")
|
||||||
|
|
||||||
|
|
||||||
# utility functions for forward pass.
|
# utility functions for forward pass.
|
||||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||||
|
@ -168,6 +236,7 @@ def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||||
return c_xy.cat(wh, dim=1)
|
return c_xy.cat(wh, dim=1)
|
||||||
return x1y1.cat(x2y2, dim=1)
|
return x1y1.cat(x2y2, dim=1)
|
||||||
|
|
||||||
|
|
||||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||||
anchor_points, stride_tensor = [], []
|
anchor_points, stride_tensor = [], []
|
||||||
assert feats is not None
|
assert feats is not None
|
||||||
|
@ -183,25 +252,39 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||||
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
|
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
|
||||||
stride_tensor.append(Tensor.full((h * w), stride))
|
stride_tensor.append(Tensor.full((h * w), stride))
|
||||||
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
|
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
|
||||||
stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
stride_tensor = (
|
||||||
|
stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
||||||
|
)
|
||||||
return anchor_points, stride_tensor
|
return anchor_points, stride_tensor
|
||||||
|
|
||||||
|
|
||||||
# this function is from the original implementation
|
# this function is from the original implementation
|
||||||
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
||||||
if d > 1:
|
if d > 1:
|
||||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
k = (
|
||||||
|
d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
|
||||||
|
) # actual kernel-size
|
||||||
if p is None:
|
if p is None:
|
||||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
def clip_boxes(boxes, shape):
|
def clip_boxes(boxes, shape):
|
||||||
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
|
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
|
||||||
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
|
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
|
||||||
return boxes
|
return boxes
|
||||||
|
|
||||||
|
|
||||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
||||||
gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
gain = (
|
||||||
pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
|
ratio_pad
|
||||||
|
if ratio_pad
|
||||||
|
else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
||||||
|
)
|
||||||
|
pad = (
|
||||||
|
(img1_shape[1] - img0_shape[1] * gain) / 2,
|
||||||
|
(img1_shape[0] - img0_shape[0] * gain) / 2,
|
||||||
|
)
|
||||||
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
|
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
|
||||||
boxes_np[..., [0, 2]] -= pad[0]
|
boxes_np[..., [0, 2]] -= pad[0]
|
||||||
boxes_np[..., [1, 3]] -= pad[1]
|
boxes_np[..., [1, 3]] -= pad[1]
|
||||||
|
@ -209,6 +292,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
||||||
boxes_np = clip_boxes(boxes_np, img0_shape)
|
boxes_np = clip_boxes(boxes_np, img0_shape)
|
||||||
return boxes_np
|
return boxes_np
|
||||||
|
|
||||||
|
|
||||||
def xywh2xyxy(x):
|
def xywh2xyxy(x):
|
||||||
xy = x[..., :2] # center x, y
|
xy = x[..., :2] # center x, y
|
||||||
wh = x[..., 2:4] # width, height
|
wh = x[..., 2:4] # width, height
|
||||||
|
@ -217,8 +301,16 @@ def xywh2xyxy(x):
|
||||||
result = np.concatenate((xy1, xy2), axis=-1)
|
result = np.concatenate((xy1, xy2), axis=-1)
|
||||||
return Tensor(result) if isinstance(x, Tensor) else result
|
return Tensor(result) if isinstance(x, Tensor) else result
|
||||||
|
|
||||||
|
|
||||||
def get_variant_multiples(variant):
|
def get_variant_multiples(variant):
|
||||||
return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
|
return {
|
||||||
|
"n": (0.33, 0.25, 2.0),
|
||||||
|
"s": (0.33, 0.50, 2.0),
|
||||||
|
"m": (0.67, 0.75, 1.5),
|
||||||
|
"l": (1.0, 1.0, 1.0),
|
||||||
|
"x": (1, 1.25, 1.0),
|
||||||
|
}.get(variant, None)
|
||||||
|
|
||||||
|
|
||||||
def label_predictions(all_predictions):
|
def label_predictions(all_predictions):
|
||||||
class_index_count = defaultdict(int)
|
class_index_count = defaultdict(int)
|
||||||
|
@ -230,6 +322,7 @@ def label_predictions(all_predictions):
|
||||||
|
|
||||||
return dict(class_index_count)
|
return dict(class_index_count)
|
||||||
|
|
||||||
|
|
||||||
# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
|
# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
|
||||||
class Upsample:
|
class Upsample:
|
||||||
def __init__(self, scale_factor: int, mode: str = "nearest") -> None:
|
def __init__(self, scale_factor: int, mode: str = "nearest") -> None:
|
||||||
|
@ -240,41 +333,86 @@ class Upsample:
|
||||||
def __call__(self, x: Tensor) -> Tensor:
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
assert len(x.shape) > 2 and len(x.shape) <= 5
|
assert len(x.shape) > 2 and len(x.shape) <= 5
|
||||||
(b, c), _lens = x.shape[:2], len(x.shape[2:])
|
(b, c), _lens = x.shape[:2], len(x.shape[2:])
|
||||||
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens)
|
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(
|
||||||
return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
*[1, 1, 1] + [self.scale_factor] * _lens
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
tmp.reshape(list(x.shape) + [self.scale_factor] * _lens)
|
||||||
|
.permute(
|
||||||
|
[0, 1]
|
||||||
|
+ list(
|
||||||
|
chain.from_iterable([[y + 2, y + 2 + _lens] for y in range(_lens)])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Conv_Block:
|
class Conv_Block:
|
||||||
def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None):
|
def __init__(
|
||||||
self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation)
|
self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None
|
||||||
|
):
|
||||||
|
self.conv = Conv2d(
|
||||||
|
c1,
|
||||||
|
c2,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding=autopad(kernel_size, padding, dilation),
|
||||||
|
bias=False,
|
||||||
|
groups=groups,
|
||||||
|
dilation=dilation,
|
||||||
|
)
|
||||||
self.bn = BatchNorm2d(c2, eps=0.001)
|
self.bn = BatchNorm2d(c2, eps=0.001)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return self.bn(self.conv(x)).silu()
|
return self.bn(self.conv(x)).silu()
|
||||||
|
|
||||||
|
|
||||||
class Bottleneck:
|
class Bottleneck:
|
||||||
def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5):
|
def __init__(
|
||||||
|
self, c1, c2, shortcut: bool, g=1, kernels: list = (3, 3), channel_factor=0.5
|
||||||
|
):
|
||||||
c_ = int(c2 * channel_factor)
|
c_ = int(c2 * channel_factor)
|
||||||
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
|
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
|
||||||
self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g)
|
self.cv2 = Conv_Block(
|
||||||
|
c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g
|
||||||
|
)
|
||||||
self.residual = c1 == c2 and shortcut
|
self.residual = c1 == c2 and shortcut
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
|
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
|
||||||
|
|
||||||
|
|
||||||
class C2f:
|
class C2f:
|
||||||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
|
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
|
||||||
self.c = int(c2 * e)
|
self.c = int(c2 * e)
|
||||||
self.cv1 = Conv_Block(c1, 2 * self.c, 1,)
|
self.cv1 = Conv_Block(
|
||||||
|
c1,
|
||||||
|
2 * self.c,
|
||||||
|
1,
|
||||||
|
)
|
||||||
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
|
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
|
||||||
self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)]
|
self.bottleneck = [
|
||||||
|
Bottleneck(
|
||||||
|
self.c,
|
||||||
|
self.c,
|
||||||
|
shortcut,
|
||||||
|
g,
|
||||||
|
kernels=[(3, 3), (3, 3)],
|
||||||
|
channel_factor=1.0,
|
||||||
|
)
|
||||||
|
for _ in range(n)
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
y = list(self.cv1(x).chunk(2, 1))
|
y = list(self.cv1(x).chunk(2, 1))
|
||||||
y.extend(m(y[-1]) for m in self.bottleneck)
|
y.extend(m(y[-1]) for m in self.bottleneck)
|
||||||
z = y[0]
|
z = y[0]
|
||||||
for i in y[1:]: z = z.cat(i, dim=1)
|
for i in y[1:]:
|
||||||
|
z = z.cat(i, dim=1)
|
||||||
return self.cv2(z)
|
return self.cv2(z)
|
||||||
|
|
||||||
|
|
||||||
class SPPF:
|
class SPPF:
|
||||||
def __init__(self, c1, c2, k=5):
|
def __init__(self, c1, c2, k=5):
|
||||||
c_ = c1 // 2 # hidden channels
|
c_ = c1 // 2 # hidden channels
|
||||||
|
@ -282,7 +420,9 @@ class SPPF:
|
||||||
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
|
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
|
||||||
|
|
||||||
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
|
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
|
||||||
self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1)
|
self.maxpool = lambda x: x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(
|
||||||
|
kernel_size=k, stride=1
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = self.cv1(x)
|
x = self.cv1(x)
|
||||||
|
@ -291,6 +431,7 @@ class SPPF:
|
||||||
x4 = self.maxpool(x3)
|
x4 = self.maxpool(x3)
|
||||||
return self.cv2(x.cat(x2, x3, x4, dim=1))
|
return self.cv2(x.cat(x2, x3, x4, dim=1))
|
||||||
|
|
||||||
|
|
||||||
class DFL:
|
class DFL:
|
||||||
def __init__(self, c1=16):
|
def __init__(self, c1=16):
|
||||||
self.conv = Conv2d(c1, 1, 1, bias=False)
|
self.conv = Conv2d(c1, 1, 1, bias=False)
|
||||||
|
@ -300,15 +441,33 @@ class DFL:
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
b, c, a = x.shape # batch, channels, anchors
|
b, c, a = x.shape # batch, channels, anchors
|
||||||
return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a)
|
return self.conv(
|
||||||
|
x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)
|
||||||
|
).reshape(b, 4, a)
|
||||||
|
|
||||||
|
|
||||||
# backbone
|
# backbone
|
||||||
class Darknet:
|
class Darknet:
|
||||||
def __init__(self, w, r, d):
|
def __init__(self, w, r, d):
|
||||||
self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)]
|
self.b1 = [
|
||||||
self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)]
|
Conv_Block(c1=3, c2=int(64 * w), kernel_size=3, stride=2, padding=1),
|
||||||
self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)]
|
Conv_Block(int(64 * w), int(128 * w), kernel_size=3, stride=2, padding=1),
|
||||||
self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)]
|
]
|
||||||
|
self.b2 = [
|
||||||
|
C2f(c1=int(128 * w), c2=int(128 * w), n=round(3 * d), shortcut=True),
|
||||||
|
Conv_Block(int(128 * w), int(256 * w), 3, 2, 1),
|
||||||
|
C2f(int(256 * w), int(256 * w), round(6 * d), True),
|
||||||
|
]
|
||||||
|
self.b3 = [
|
||||||
|
Conv_Block(int(256 * w), int(512 * w), kernel_size=3, stride=2, padding=1),
|
||||||
|
C2f(int(512 * w), int(512 * w), round(6 * d), True),
|
||||||
|
]
|
||||||
|
self.b4 = [
|
||||||
|
Conv_Block(
|
||||||
|
int(512 * w), int(512 * w * r), kernel_size=3, stride=2, padding=1
|
||||||
|
),
|
||||||
|
C2f(int(512 * w * r), int(512 * w * r), round(3 * d), True),
|
||||||
|
]
|
||||||
self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)]
|
self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)]
|
||||||
|
|
||||||
def return_modules(self):
|
def return_modules(self):
|
||||||
|
@ -322,16 +481,28 @@ class Darknet:
|
||||||
x5 = x4.sequential(self.b5)
|
x5 = x4.sequential(self.b5)
|
||||||
return (x2, x3, x5)
|
return (x2, x3, x5)
|
||||||
|
|
||||||
|
|
||||||
# yolo fpn (neck)
|
# yolo fpn (neck)
|
||||||
class Yolov8NECK:
|
class Yolov8NECK:
|
||||||
def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple
|
def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple
|
||||||
self.up = Upsample(2, mode='nearest')
|
self.up = Upsample(2, mode="nearest")
|
||||||
self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False)
|
self.n1 = C2f(
|
||||||
|
c1=int(512 * w * (1 + r)), c2=int(512 * w), n=round(3 * d), shortcut=False
|
||||||
|
)
|
||||||
self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False)
|
self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False)
|
||||||
self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1)
|
self.n3 = Conv_Block(
|
||||||
|
c1=int(256 * w), c2=int(256 * w), kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False)
|
self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False)
|
||||||
self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1)
|
self.n5 = Conv_Block(
|
||||||
self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False)
|
c1=int(512 * w), c2=int(512 * w), kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
self.n6 = C2f(
|
||||||
|
c1=int(512 * w * (1 + r)),
|
||||||
|
c2=int(512 * w * r),
|
||||||
|
n=round(3 * d),
|
||||||
|
shortcut=False,
|
||||||
|
)
|
||||||
|
|
||||||
def return_modules(self):
|
def return_modules(self):
|
||||||
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
|
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
|
||||||
|
@ -343,6 +514,7 @@ class Yolov8NECK:
|
||||||
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
|
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
|
||||||
return [head_1, head_2, head_3]
|
return [head_1, head_2, head_3]
|
||||||
|
|
||||||
|
|
||||||
# task specific head.
|
# task specific head.
|
||||||
class DetectionHead:
|
class DetectionHead:
|
||||||
def __init__(self, nc=80, filters=()):
|
def __init__(self, nc=80, filters=()):
|
||||||
|
@ -354,25 +526,41 @@ class DetectionHead:
|
||||||
c1 = max(filters[0], self.nc)
|
c1 = max(filters[0], self.nc)
|
||||||
c2 = max((filters[0] // 4, self.ch * 4))
|
c2 = max((filters[0] // 4, self.ch * 4))
|
||||||
self.dfl = DFL(self.ch)
|
self.dfl = DFL(self.ch)
|
||||||
self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters]
|
self.cv3 = [
|
||||||
self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters]
|
[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)]
|
||||||
|
for x in filters
|
||||||
|
]
|
||||||
|
self.cv2 = [
|
||||||
|
[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)]
|
||||||
|
for x in filters
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
for i in range(self.nl):
|
for i in range(self.nl):
|
||||||
x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1))
|
x[i] = x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1)
|
||||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
self.anchors, self.strides = (
|
||||||
|
x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)
|
||||||
|
)
|
||||||
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
|
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
|
||||||
x_cat = y[0].cat(y[1], y[2], dim=2)
|
x_cat = y[0].cat(y[1], y[2], dim=2)
|
||||||
box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :]
|
box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :]
|
||||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
dbox = (
|
||||||
|
dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1)
|
||||||
|
* self.strides
|
||||||
|
)
|
||||||
z = dbox.cat(cls.sigmoid(), dim=1)
|
z = dbox.cat(cls.sigmoid(), dim=1)
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
class YOLOv8:
|
class YOLOv8:
|
||||||
def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple
|
def __init__(
|
||||||
|
self, w, r, d, num_classes
|
||||||
|
): # width_multiple, ratio_multiple, depth_multiple
|
||||||
self.net = Darknet(w, r, d)
|
self.net = Darknet(w, r, d)
|
||||||
self.fpn = Yolov8NECK(w, r, d)
|
self.fpn = Yolov8NECK(w, r, d)
|
||||||
self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r)))
|
self.head = DetectionHead(
|
||||||
|
num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r))
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
x = self.net(x)
|
x = self.net(x)
|
||||||
|
@ -383,27 +571,44 @@ class YOLOv8:
|
||||||
backbone_modules = [*range(10)]
|
backbone_modules = [*range(10)]
|
||||||
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
|
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
|
||||||
yolov8_head_weights = [(22, self.head)]
|
yolov8_head_weights = [(22, self.head)]
|
||||||
return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
|
return [
|
||||||
|
*zip(backbone_modules, self.net.return_modules()),
|
||||||
|
*zip(yolov8neck_modules, self.fpn.return_modules()),
|
||||||
|
*yolov8_head_weights,
|
||||||
|
]
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
|
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
|
||||||
if len(sys.argv) < 2:
|
if len(sys.argv) < 2:
|
||||||
print("Error: Image URL or path not provided.")
|
print("Error: Image URL or path not provided.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
img_path = sys.argv[1]
|
img_path = sys.argv[1]
|
||||||
yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
|
yolo_variant = (
|
||||||
print(f'running inference for YOLO version {yolo_variant}')
|
sys.argv[2]
|
||||||
|
if len(sys.argv) >= 3
|
||||||
|
else (
|
||||||
|
print(
|
||||||
|
"No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']"
|
||||||
|
)
|
||||||
|
or "n"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"running inference for YOLO version {yolo_variant}")
|
||||||
|
|
||||||
output_folder_path = Path('./outputs_yolov8')
|
output_folder_path = Path("./outputs_yolov8")
|
||||||
output_folder_path.mkdir(parents=True, exist_ok=True)
|
output_folder_path.mkdir(parents=True, exist_ok=True)
|
||||||
# absolute image path or URL
|
# absolute image path or URL
|
||||||
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
|
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
|
||||||
image = [cv2.imdecode(image_location[0], 1)]
|
image = [cv2.imdecode(image_location[0], 1)]
|
||||||
out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()]
|
out_paths = [
|
||||||
|
(
|
||||||
|
output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}"
|
||||||
|
).as_posix()
|
||||||
|
]
|
||||||
if not isinstance(image[0], np.ndarray):
|
if not isinstance(image[0], np.ndarray):
|
||||||
print('Error in image loading. Check your image file.')
|
print("Error in image loading. Check your image file.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
pre_processed_image = preprocess(image)
|
pre_processed_image = preprocess(image)
|
||||||
|
|
||||||
|
@ -411,19 +616,36 @@ if __name__ == '__main__':
|
||||||
depth, width, ratio = get_variant_multiples(yolo_variant)
|
depth, width, ratio = get_variant_multiples(yolo_variant)
|
||||||
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||||
|
|
||||||
state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
|
state_dict = safe_load(
|
||||||
|
fetch(
|
||||||
|
f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors"
|
||||||
|
)
|
||||||
|
)
|
||||||
load_state_dict(yolo_infer, state_dict)
|
load_state_dict(yolo_infer, state_dict)
|
||||||
|
|
||||||
st = time.time()
|
st = time.time()
|
||||||
predictions = yolo_infer(pre_processed_image)
|
predictions = yolo_infer(pre_processed_image)
|
||||||
print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
|
print(f"did inference in {int(round(((time.time() - st) * 1000)))}ms")
|
||||||
|
|
||||||
post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
|
post_predictions = postprocess(
|
||||||
|
preds=predictions, img=pre_processed_image, orig_imgs=image
|
||||||
|
)
|
||||||
|
|
||||||
# v8 and v3 have same 80 class names for Object Detection
|
# v8 and v3 have same 80 class names for Object Detection
|
||||||
class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
|
class_labels = (
|
||||||
|
fetch(
|
||||||
|
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||||
|
)
|
||||||
|
.read_text()
|
||||||
|
.split("\n")
|
||||||
|
)
|
||||||
|
|
||||||
draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
|
draw_bounding_boxes_and_save(
|
||||||
|
orig_img_paths=image_location,
|
||||||
|
output_img_paths=out_paths,
|
||||||
|
all_predictions=post_predictions,
|
||||||
|
class_labels=class_labels,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO for later:
|
# TODO for later:
|
||||||
# 1. Fix SPPF minor difference due to maxpool
|
# 1. Fix SPPF minor difference due to maxpool
|
||||||
|
|
|
@ -6,9 +6,9 @@ from coremltools.models.neural_network import datatypes, NeuralNetworkBuilder
|
||||||
# KxK GEMM with bias
|
# KxK GEMM with bias
|
||||||
K = 64
|
K = 64
|
||||||
|
|
||||||
input_features = [('image', datatypes.Array(K))]
|
input_features = [("image", datatypes.Array(K))]
|
||||||
input_features2 = [('image2', datatypes.Array(K))]
|
input_features2 = [("image2", datatypes.Array(K))]
|
||||||
output_features = [('probs', datatypes.Array(K))]
|
output_features = [("probs", datatypes.Array(K))]
|
||||||
|
|
||||||
weights = np.zeros((K, K)) + 3
|
weights = np.zeros((K, K)) + 3
|
||||||
bias = np.ones(K)
|
bias = np.ones(K)
|
||||||
|
@ -17,7 +17,9 @@ builder = NeuralNetworkBuilder(input_features+input_features2, output_features)
|
||||||
|
|
||||||
# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
|
# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
|
||||||
# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
|
# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
|
||||||
builder.add_elementwise(name='element', input_names=['image', 'image2'], output_name='probs', mode='ADD')
|
builder.add_elementwise(
|
||||||
|
name="element", input_names=["image", "image2"], output_name="probs", mode="ADD"
|
||||||
|
)
|
||||||
# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
|
# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
|
||||||
# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
|
# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
|
||||||
|
|
||||||
|
@ -25,6 +27,11 @@ builder.add_elementwise(name='element', input_names=['image', 'image2'], output_
|
||||||
mlmodel = ct.models.MLModel(builder.spec)
|
mlmodel = ct.models.MLModel(builder.spec)
|
||||||
|
|
||||||
# trigger the ANE!
|
# trigger the ANE!
|
||||||
out = mlmodel.predict({"image": np.zeros(K, dtype=np.float32)+1, "image2": np.zeros(K, dtype=np.float32)+2})
|
out = mlmodel.predict(
|
||||||
|
{
|
||||||
|
"image": np.zeros(K, dtype=np.float32) + 1,
|
||||||
|
"image2": np.zeros(K, dtype=np.float32) + 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
print(out)
|
print(out)
|
||||||
mlmodel.save('test.mlmodel')
|
mlmodel.save("test.mlmodel")
|
||||||
|
|
|
@ -6,7 +6,7 @@ import pylab as plt
|
||||||
from networkx.drawing.nx_pydot import read_dot
|
from networkx.drawing.nx_pydot import read_dot
|
||||||
|
|
||||||
ret = os.system("./a.out " + sys.argv[1] + " debug")
|
ret = os.system("./a.out " + sys.argv[1] + " debug")
|
||||||
assert(ret == 0)
|
assert ret == 0
|
||||||
|
|
||||||
df = "debug/model.hwx.zinir_graph_after_reg_spill.dot"
|
df = "debug/model.hwx.zinir_graph_after_reg_spill.dot"
|
||||||
|
|
||||||
|
|
|
@ -3,17 +3,21 @@ import sys
|
||||||
from hexdump import hexdump
|
from hexdump import hexdump
|
||||||
from macholib import MachO
|
from macholib import MachO
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
|
|
||||||
|
|
||||||
def get_macho(fn):
|
def get_macho(fn):
|
||||||
# mod to make the header okay
|
# mod to make the header okay
|
||||||
# MH_CIGAM_64 is good
|
# MH_CIGAM_64 is good
|
||||||
dat = open(fn, "rb").read()
|
dat = open(fn, "rb").read()
|
||||||
dat = b"\xcf\xfa\xed\xfe" + dat[4:]
|
dat = b"\xcf\xfa\xed\xfe" + dat[4:]
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
with NamedTemporaryFile(delete=False) as f:
|
with NamedTemporaryFile(delete=False) as f:
|
||||||
f.write(dat)
|
f.write(dat)
|
||||||
f.close()
|
f.close()
|
||||||
return MachO.MachO(f.name)
|
return MachO.MachO(f.name)
|
||||||
|
|
||||||
|
|
||||||
a = get_macho("model.hwx.golden")
|
a = get_macho("model.hwx.golden")
|
||||||
|
|
||||||
# load commands
|
# load commands
|
||||||
|
@ -23,12 +27,19 @@ for c in a.headers[0].commands:
|
||||||
hexdump(c[2])
|
hexdump(c[2])
|
||||||
pass
|
pass
|
||||||
if c[0].cmd == 6:
|
if c[0].cmd == 6:
|
||||||
print("name:", c[2].decode('utf-8'))
|
print("name:", c[2].decode("utf-8"))
|
||||||
if c[0].cmd == 8:
|
if c[0].cmd == 8:
|
||||||
print(c[2].decode('utf-8'))
|
print(c[2].decode("utf-8"))
|
||||||
if c[0].cmd == 25:
|
if c[0].cmd == 25:
|
||||||
for section in c[2]:
|
for section in c[2]:
|
||||||
print(section.segname.strip(b'\0'), section.sectname.strip(b'\0'), hex(section.addr), hex(section.size), "@", hex(c[1].fileoff))
|
print(
|
||||||
|
section.segname.strip(b"\0"),
|
||||||
|
section.sectname.strip(b"\0"),
|
||||||
|
hex(section.addr),
|
||||||
|
hex(section.size),
|
||||||
|
"@",
|
||||||
|
hex(c[1].fileoff),
|
||||||
|
)
|
||||||
# print(dir(section))
|
# print(dir(section))
|
||||||
if c[1].filesize > 0:
|
if c[1].filesize > 0:
|
||||||
if len(section.section_data) < 0x100:
|
if len(section.section_data) < 0x100:
|
||||||
|
@ -38,6 +49,7 @@ for c in a.headers[0].commands:
|
||||||
|
|
||||||
# this parser is wrong (fixed with 64-bit one)
|
# this parser is wrong (fixed with 64-bit one)
|
||||||
from macholib import SymbolTable
|
from macholib import SymbolTable
|
||||||
|
|
||||||
sym = SymbolTable.SymbolTable(a)
|
sym = SymbolTable.SymbolTable(a)
|
||||||
|
|
||||||
syms = {}
|
syms = {}
|
||||||
|
@ -52,6 +64,7 @@ for k,v in syms.items():
|
||||||
|
|
||||||
# **** document what we know ***
|
# **** document what we know ***
|
||||||
from ane import ANE_Struct, ANE
|
from ane import ANE_Struct, ANE
|
||||||
|
|
||||||
ane = ANE()
|
ane = ANE()
|
||||||
|
|
||||||
aneb = set()
|
aneb = set()
|
||||||
|
@ -65,6 +78,8 @@ for l in range(0x34, 0xF4):
|
||||||
aneb.add(l)
|
aneb.add(l)
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
def compare(x, y):
|
def compare(x, y):
|
||||||
ss = []
|
ss = []
|
||||||
ln = []
|
ln = []
|
||||||
|
@ -73,7 +88,7 @@ def compare(x, y):
|
||||||
ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10
|
ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10
|
||||||
|
|
||||||
highlight = False
|
highlight = False
|
||||||
next_highlight = 0x2b
|
next_highlight = 0x2B
|
||||||
for i in range(ll + 1):
|
for i in range(ll + 1):
|
||||||
if i == next_highlight:
|
if i == next_highlight:
|
||||||
highlight = True
|
highlight = True
|
||||||
|
@ -83,35 +98,37 @@ def compare(x, y):
|
||||||
next_highlight = None
|
next_highlight = None
|
||||||
else:
|
else:
|
||||||
highlight = False
|
highlight = False
|
||||||
a = "%02X" % x[i] if i < len(x) else "--", \
|
a = "%02X" % x[i] if i < len(x) else "--", "%02X" % y[i] if i < len(y) else "--"
|
||||||
"%02X" % y[i] if i < len(y) else "--"
|
|
||||||
def fj(x):
|
def fj(x):
|
||||||
ss = []
|
ss = []
|
||||||
for i in range(0, 0x10, 4):
|
for i in range(0, 0x10, 4):
|
||||||
ss.append(' '.join(x[i:i+4]))
|
ss.append(" ".join(x[i : i + 4]))
|
||||||
return ' '.join(ss)
|
return " ".join(ss)
|
||||||
|
|
||||||
if i != 0 and i % 0x10 == 0:
|
if i != 0 and i % 0x10 == 0:
|
||||||
ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n")
|
ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n")
|
||||||
ln = []
|
ln = []
|
||||||
ln2 = []
|
ln2 = []
|
||||||
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
|
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
|
||||||
ln.append(colored(a[0], 'green'))
|
ln.append(colored(a[0], "green"))
|
||||||
ln2.append(colored(a[1], 'red'))
|
ln2.append(colored(a[1], "red"))
|
||||||
else:
|
else:
|
||||||
if highlight:
|
if highlight:
|
||||||
ln.append(colored(a[0], 'yellow'))
|
ln.append(colored(a[0], "yellow"))
|
||||||
ln2.append(colored(a[1], 'yellow'))
|
ln2.append(colored(a[1], "yellow"))
|
||||||
else:
|
else:
|
||||||
if i in aneb:
|
if i in aneb:
|
||||||
ln.append(colored(a[0], 'white'))
|
ln.append(colored(a[0], "white"))
|
||||||
ln2.append(colored(a[1], 'white'))
|
ln2.append(colored(a[1], "white"))
|
||||||
else:
|
else:
|
||||||
ln.append(a[0])
|
ln.append(a[0])
|
||||||
ln2.append(a[1])
|
ln2.append(a[1])
|
||||||
return ''.join(ss)
|
return "".join(ss)
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
aneregs = dict(json.load(open("aneregs.json")))
|
aneregs = dict(json.load(open("aneregs.json")))
|
||||||
g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1])
|
g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1])
|
||||||
f1 = g.headers[0].commands[1][2][0].section_data
|
f1 = g.headers[0].commands[1][2][0].section_data
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
from ane import ANE
|
from ane import ANE
|
||||||
|
|
||||||
ane = ANE()
|
ane = ANE()
|
||||||
|
|
||||||
lens = {}
|
lens = {}
|
||||||
|
@ -30,7 +31,7 @@ for i in range(0x300):
|
||||||
pos.append((k, (i, j, lens[k])))
|
pos.append((k, (i, j, lens[k])))
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
jpos = json.dumps(pos, indent=2)
|
jpos = json.dumps(pos, indent=2)
|
||||||
with open("aneregs.json", "w") as f:
|
with open("aneregs.json", "w") as f:
|
||||||
f.write(jpos)
|
f.write(jpos)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import ctypes
|
||||||
from subprocess import check_output
|
from subprocess import check_output
|
||||||
from hexdump import hexdump
|
from hexdump import hexdump
|
||||||
|
|
||||||
|
|
||||||
def get_pid(name):
|
def get_pid(name):
|
||||||
try:
|
try:
|
||||||
output = check_output(["pgrep", name])
|
output = check_output(["pgrep", name])
|
||||||
|
@ -9,8 +10,10 @@ def get_pid(name):
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
from ctypes.util import find_library
|
from ctypes.util import find_library
|
||||||
libc = ctypes.CDLL(find_library('c'))
|
|
||||||
|
libc = ctypes.CDLL(find_library("c"))
|
||||||
|
|
||||||
amfid_pid = get_pid("amfid")
|
amfid_pid = get_pid("amfid")
|
||||||
|
|
||||||
|
@ -21,6 +24,7 @@ print(amfid_pid, ret, task, mytask)
|
||||||
|
|
||||||
# myport = libc.mach_task_self()
|
# myport = libc.mach_task_self()
|
||||||
|
|
||||||
|
|
||||||
class vm_region_submap_short_info_data_64(ctypes.Structure):
|
class vm_region_submap_short_info_data_64(ctypes.Structure):
|
||||||
_pack_ = 1
|
_pack_ = 1
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
|
@ -38,6 +42,8 @@ class vm_region_submap_short_info_data_64(ctypes.Structure):
|
||||||
("object_id", ctypes.c_uint32),
|
("object_id", ctypes.c_uint32),
|
||||||
("user_wired_count", ctypes.c_uint32),
|
("user_wired_count", ctypes.c_uint32),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4
|
submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4
|
||||||
|
|
||||||
address = ctypes.c_ulong(0)
|
address = ctypes.c_ulong(0)
|
||||||
|
@ -48,21 +54,31 @@ depth = 0
|
||||||
|
|
||||||
c_depth = ctypes.c_uint32(depth)
|
c_depth = ctypes.c_uint32(depth)
|
||||||
for i in range(1):
|
for i in range(1):
|
||||||
ret = libc.mach_vm_region_recurse(task,
|
ret = libc.mach_vm_region_recurse(
|
||||||
ctypes.pointer(address), ctypes.pointer(mapsize),
|
task,
|
||||||
ctypes.pointer(c_depth), ctypes.pointer(sub_info),
|
ctypes.pointer(address),
|
||||||
ctypes.pointer(count))
|
ctypes.pointer(mapsize),
|
||||||
|
ctypes.pointer(c_depth),
|
||||||
|
ctypes.pointer(sub_info),
|
||||||
|
ctypes.pointer(count),
|
||||||
|
)
|
||||||
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
|
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
|
||||||
# address.value += mapsize.value
|
# address.value += mapsize.value
|
||||||
# exit(0)
|
# exit(0)
|
||||||
|
|
||||||
patch_address = address.value + 0x8e38
|
patch_address = address.value + 0x8E38
|
||||||
patch = b"\x00\x00\x80\xd2"
|
patch = b"\x00\x00\x80\xd2"
|
||||||
|
|
||||||
pdata = ctypes.c_void_p(0)
|
pdata = ctypes.c_void_p(0)
|
||||||
data_cnt = ctypes.c_uint32(0)
|
data_cnt = ctypes.c_uint32(0)
|
||||||
|
|
||||||
ret = libc.mach_vm_read(task, ctypes.c_ulong(patch_address), 4, ctypes.pointer(pdata), ctypes.pointer(data_cnt))
|
ret = libc.mach_vm_read(
|
||||||
|
task,
|
||||||
|
ctypes.c_ulong(patch_address),
|
||||||
|
4,
|
||||||
|
ctypes.pointer(pdata),
|
||||||
|
ctypes.pointer(data_cnt),
|
||||||
|
)
|
||||||
buf = ctypes.string_at(pdata.value, data_cnt.value)
|
buf = ctypes.string_at(pdata.value, data_cnt.value)
|
||||||
hexdump(buf)
|
hexdump(buf)
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,15 @@ import collections
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import faulthandler
|
import faulthandler
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
|
|
||||||
basedir = Path(__file__).resolve().parent
|
basedir = Path(__file__).resolve().parent
|
||||||
|
|
||||||
libane = None
|
libane = None
|
||||||
aneregs = None
|
aneregs = None
|
||||||
|
|
||||||
|
|
||||||
def init_libane():
|
def init_libane():
|
||||||
global libane, aneregs
|
global libane, aneregs
|
||||||
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
|
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
|
||||||
|
@ -32,71 +35,56 @@ def init_libane():
|
||||||
with open(basedir / "aneregs.json") as f:
|
with open(basedir / "aneregs.json") as f:
|
||||||
aneregs = json.load(f)
|
aneregs = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
ANE_Struct = [
|
ANE_Struct = [
|
||||||
# aneTD.Header
|
# aneTD.Header
|
||||||
("u32", 0x1C, "NextCommandOffset"),
|
("u32", 0x1C, "NextCommandOffset"),
|
||||||
|
|
||||||
# KernelDMASrc @ section @ 0x2C len 0xF4
|
# KernelDMASrc @ section @ 0x2C len 0xF4
|
||||||
# reloc 0x2c-0x34?? = weights
|
# reloc 0x2c-0x34?? = weights
|
||||||
# u32[16] 0x34-0x74 = 0x80 | 1 if used
|
# u32[16] 0x34-0x74 = 0x80 | 1 if used
|
||||||
# u32[16] 0x74-0xB4 = <channel data offset>
|
# u32[16] 0x74-0xB4 = <channel data offset>
|
||||||
# u32[16] 0xB4-0xF4 = <channel data length>
|
# u32[16] 0xB4-0xF4 = <channel data length>
|
||||||
|
|
||||||
# Common @ section @ 0x128 len 0x3C (conv)
|
# Common @ section @ 0x128 len 0x3C (conv)
|
||||||
("u16", 0x128, "InputWidth"),
|
("u16", 0x128, "InputWidth"),
|
||||||
("u16", 0x12A, "InputHeight"),
|
("u16", 0x12A, "InputHeight"),
|
||||||
("u16", 0x12C, "InputDepth"),
|
("u16", 0x12C, "InputDepth"),
|
||||||
|
|
||||||
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
|
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
|
||||||
# UInt8 = 0, Int8 = 1, Float16 = 2
|
# UInt8 = 0, Int8 = 1, Float16 = 2
|
||||||
|
|
||||||
("u32", 0x134, "InputChannels"),
|
("u32", 0x134, "InputChannels"),
|
||||||
("u32", 0x138, "OutputChannels"),
|
("u32", 0x138, "OutputChannels"),
|
||||||
|
|
||||||
("u16", 0x13C, "OutputWidth"),
|
("u16", 0x13C, "OutputWidth"),
|
||||||
("u16", 0x13E, "OutputHeight"),
|
("u16", 0x13E, "OutputHeight"),
|
||||||
("u16", 0x140, "OutputDepth"),
|
("u16", 0x140, "OutputDepth"),
|
||||||
|
|
||||||
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
|
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
|
||||||
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
|
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
|
||||||
|
|
||||||
("u16", 0x14C, "BatchSize"),
|
("u16", 0x14C, "BatchSize"),
|
||||||
|
|
||||||
# TileDMASrc @ section @ 0x16C len 0x6C (input)
|
# TileDMASrc @ section @ 0x16C len 0x6C (input)
|
||||||
# reloc 0x16c-0x174 = image
|
# reloc 0x16c-0x174 = image
|
||||||
("u32", 0x178, "InputRowStride"),
|
("u32", 0x178, "InputRowStride"),
|
||||||
("u32", 0x17C, "InputPlaneStride"),
|
("u32", 0x17C, "InputPlaneStride"),
|
||||||
("u32", 0x180, "InputDepthStride"),
|
("u32", 0x180, "InputDepthStride"),
|
||||||
("u32", 0x184, "InputBatchStride"),
|
("u32", 0x184, "InputBatchStride"),
|
||||||
|
|
||||||
("u8", 0x1A7, "InputInterleave"),
|
("u8", 0x1A7, "InputInterleave"),
|
||||||
|
|
||||||
# L2 @ section @ 0x1E0 len 0x44
|
# L2 @ section @ 0x1E0 len 0x44
|
||||||
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
|
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
|
||||||
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
|
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
|
||||||
# [0x21c, 0x220, 0x224] = engines for outconv?
|
# [0x21c, 0x220, 0x224] = engines for outconv?
|
||||||
|
|
||||||
# NE @ section @ 0x22c len 0xC (scaling)
|
# NE @ section @ 0x22c len 0xC (scaling)
|
||||||
("u16", 0x230, "BiasScalar"),
|
("u16", 0x230, "BiasScalar"),
|
||||||
("u16", 0x232, "ScaleScalar"),
|
("u16", 0x232, "ScaleScalar"),
|
||||||
|
|
||||||
# section @ 0x240 len 0x10
|
# section @ 0x240 len 0x10
|
||||||
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
|
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
|
||||||
("u32", 0x250, "PostScale"),
|
("u32", 0x250, "PostScale"),
|
||||||
|
|
||||||
# TileDMADst @ section @ 0x258 len 0x18
|
# TileDMADst @ section @ 0x258 len 0x18
|
||||||
|
|
||||||
# HandleTileDmaDstConfig
|
# HandleTileDmaDstConfig
|
||||||
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
|
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
|
||||||
# (GetCacheHintRegisterValue & 0xf) << 6;
|
# (GetCacheHintRegisterValue & 0xf) << 6;
|
||||||
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
|
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
|
||||||
|
|
||||||
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
|
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
|
||||||
("u32", 0x260, "OutputRowStride"),
|
("u32", 0x260, "OutputRowStride"),
|
||||||
("u32", 0x264, "OutputPlaneStride"),
|
("u32", 0x264, "OutputPlaneStride"),
|
||||||
("u32", 0x268, "OutputDepthStride"),
|
("u32", 0x268, "OutputDepthStride"),
|
||||||
("u32", 0x26C, "OutputBatchStride"),
|
("u32", 0x26C, "OutputBatchStride"),
|
||||||
|
|
||||||
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
|
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
|
||||||
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
|
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
|
||||||
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
|
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
|
||||||
|
@ -108,24 +96,26 @@ for typ, num, nam in ANE_Struct:
|
||||||
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
|
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
|
||||||
ANE_Struct_Dict[nam] = (styp, num)
|
ANE_Struct_Dict[nam] = (styp, num)
|
||||||
|
|
||||||
|
|
||||||
class ANETensor:
|
class ANETensor:
|
||||||
def __init__(self, *shape):
|
def __init__(self, *shape):
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
self.dtype = np.float16
|
self.dtype = np.float16
|
||||||
self.sz = int(np.prod(shape))
|
self.sz = int(np.prod(shape))
|
||||||
assert(self.sz <= 0x4000)
|
assert self.sz <= 0x4000
|
||||||
self.tt = libane.ANE_TensorCreate(self.sz, 1)
|
self.tt = libane.ANE_TensorCreate(self.sz, 1)
|
||||||
assert(self.tt is not None)
|
assert self.tt is not None
|
||||||
|
|
||||||
def data(self):
|
def data(self):
|
||||||
data = libane.ANE_TensorData(self.tt)
|
data = libane.ANE_TensorData(self.tt)
|
||||||
assert(data is not None)
|
assert data is not None
|
||||||
# print(hex(addressof(data.contents)))
|
# print(hex(addressof(data.contents)))
|
||||||
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
|
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
|
||||||
ret = np.frombuffer(buf, dtype=self.dtype)
|
ret = np.frombuffer(buf, dtype=self.dtype)
|
||||||
# print(ret.data)
|
# print(ret.data)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class ANE:
|
class ANE:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
init_libane()
|
init_libane()
|
||||||
|
@ -133,11 +123,13 @@ class ANE:
|
||||||
|
|
||||||
def compile(self, dat):
|
def compile(self, dat):
|
||||||
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
|
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
|
||||||
assert(ret is not None)
|
assert ret is not None
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def run(self, prog, tin, tout, tweights=None):
|
def run(self, prog, tin, tout, tweights=None):
|
||||||
libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0)
|
libane.ANE_Run(
|
||||||
|
prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0
|
||||||
|
)
|
||||||
|
|
||||||
def tensor(self, shape):
|
def tensor(self, shape):
|
||||||
return ANETensor(shape)
|
return ANETensor(shape)
|
||||||
|
@ -165,9 +157,9 @@ class ANE:
|
||||||
return dat
|
return dat
|
||||||
|
|
||||||
def debug(self, dat, mems=0):
|
def debug(self, dat, mems=0):
|
||||||
add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c]
|
add = [0x30, 0x1D4, 0x220, 0x29C, 0x2F0, 0x30C, 0x32C]
|
||||||
lens = [244, 60, 108, 68, 12, 16, 24]
|
lens = [244, 60, 108, 68, 12, 16, 24]
|
||||||
ptr = 0x2b
|
ptr = 0x2B
|
||||||
ddat = dat[0:0x28]
|
ddat = dat[0:0x28]
|
||||||
for a, pm in zip(add, lens):
|
for a, pm in zip(add, lens):
|
||||||
# assert pm == dat[ptr]
|
# assert pm == dat[ptr]
|
||||||
|
@ -176,7 +168,12 @@ class ANE:
|
||||||
ptr += pm + 8
|
ptr += pm + 8
|
||||||
ddat += b"\x00" * 0x100
|
ddat += b"\x00" * 0x100
|
||||||
ret = collections.OrderedDict()
|
ret = collections.OrderedDict()
|
||||||
for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"):
|
for ln in (
|
||||||
|
libane.ANE_RegDebug(0, create_string_buffer(ddat), mems)
|
||||||
|
.decode("utf-8")
|
||||||
|
.strip()
|
||||||
|
.split("\n")
|
||||||
|
):
|
||||||
lnn = ln.split(" = ")
|
lnn = ln.split(" = ")
|
||||||
if len(lnn) == 2:
|
if len(lnn) == 2:
|
||||||
ret[lnn[0]] = int(lnn[1])
|
ret[lnn[0]] = int(lnn[1])
|
||||||
|
@ -194,6 +191,7 @@ class ANE:
|
||||||
dat[base + a : base + a + len(x)] = x
|
dat[base + a : base + a + len(x)] = x
|
||||||
return dat
|
return dat
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ane = ANE()
|
ane = ANE()
|
||||||
|
|
||||||
|
@ -212,11 +210,10 @@ if __name__ == "__main__":
|
||||||
md = dat[0x4000:0x4300]
|
md = dat[0x4000:0x4300]
|
||||||
dd = ane.unpack(md)
|
dd = ane.unpack(md)
|
||||||
mdf = ane.pack(dd, md)
|
mdf = ane.pack(dd, md)
|
||||||
assert(md == mdf)
|
assert md == mdf
|
||||||
|
|
||||||
comp = ane.compile(dat)
|
comp = ane.compile(dat)
|
||||||
ret = ane.run(comp, tin, tout)
|
ret = ane.run(comp, tin, tout)
|
||||||
print("** after **")
|
print("** after **")
|
||||||
print(tind)
|
print(tind)
|
||||||
print(toutd)
|
print(toutd)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import time
|
import time
|
||||||
from ane import ANE, ANETensor
|
from ane import ANE, ANETensor
|
||||||
|
|
||||||
|
|
||||||
def benchmark(ane):
|
def benchmark(ane):
|
||||||
tin = ANETensor(512 * 0x20)
|
tin = ANETensor(512 * 0x20)
|
||||||
tout = ANETensor(512 * 0x20)
|
tout = ANETensor(512 * 0x20)
|
||||||
|
@ -14,7 +15,7 @@ def benchmark(ane):
|
||||||
for i in range(1000):
|
for i in range(1000):
|
||||||
ret = ane.run(comp, tin, tout)
|
ret = ane.run(comp, tin, tout)
|
||||||
et = time.time()
|
et = time.time()
|
||||||
ts = (et-st)
|
ts = et - st
|
||||||
ops = 1000 * 512 * 512 * 2
|
ops = 1000 * 512 * 512 * 2
|
||||||
|
|
||||||
print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts))
|
print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts))
|
||||||
|
@ -72,7 +73,7 @@ if __name__ == "__main__":
|
||||||
dd = ane.unpack(dat[0x4000:0x4300])
|
dd = ane.unpack(dat[0x4000:0x4300])
|
||||||
# use the 3rd arg as the weights
|
# use the 3rd arg as the weights
|
||||||
dd["aneTD.Header[9].KBase0"] = 6
|
dd["aneTD.Header[9].KBase0"] = 6
|
||||||
dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00
|
dd["aneRegs.NE.PostScale.PostScale"] = 0x3C00
|
||||||
# dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
|
# dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
|
||||||
# dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
|
# dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
|
||||||
# dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
|
# dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from .tensor import Device, Function, register
|
from .tensor import Device, Function, register
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def compile_wrapper(ane, dat):
|
def compile_wrapper(ane, dat):
|
||||||
return ane.compile(dat)
|
return ane.compile(dat)
|
||||||
|
|
||||||
|
|
||||||
def roundup(x, v):
|
def roundup(x, v):
|
||||||
return x + (v - x) % v
|
return x + (v - x) % v
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def compile_relu(ane, sz):
|
def compile_relu(ane, sz):
|
||||||
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
|
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
|
||||||
|
@ -17,16 +20,25 @@ def compile_relu(ane, sz):
|
||||||
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
|
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
|
||||||
# 0x1f4, 0x1f8?
|
# 0x1f4, 0x1f8?
|
||||||
# 0x214 = L2.ResultBase.Addr
|
# 0x214 = L2.ResultBase.Addr
|
||||||
dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride)
|
dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride)
|
||||||
stride = roundup(sz * 2, 0x40)
|
stride = roundup(sz * 2, 0x40)
|
||||||
dat = ane.filln(dat, {
|
dat = ane.filln(
|
||||||
|
dat,
|
||||||
|
{
|
||||||
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
|
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
|
||||||
"InputWidth": sz, "OutputWidth": sz,
|
"InputWidth": sz,
|
||||||
"InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride,
|
"OutputWidth": sz,
|
||||||
"OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride,
|
"InputRowStride": stride,
|
||||||
})
|
"InputPlaneStride": stride,
|
||||||
|
"InputDepthStride": stride,
|
||||||
|
"OutputRowStride": stride,
|
||||||
|
"OutputPlaneStride": stride,
|
||||||
|
"OutputDepthStride": stride,
|
||||||
|
},
|
||||||
|
)
|
||||||
return compile_wrapper(ane, bytes(dat))
|
return compile_wrapper(ane, bytes(dat))
|
||||||
|
|
||||||
|
|
||||||
class ReLU(Function):
|
class ReLU(Function):
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
ret = ctx.ane.tensor(input.shape)
|
ret = ctx.ane.tensor(input.shape)
|
||||||
|
@ -36,4 +48,5 @@ class ReLU(Function):
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
register('relu', ReLU, device=Device.ANE)
|
|
||||||
|
register("relu", ReLU, device=Device.ANE)
|
||||||
|
|
|
@ -31,13 +31,14 @@ for x in out.values(): x.realize()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from openvino.runtime import Core
|
from openvino.runtime import Core
|
||||||
|
|
||||||
core = Core()
|
core = Core()
|
||||||
devices = core.available_devices
|
devices = core.available_devices
|
||||||
for device in devices:
|
for device in devices:
|
||||||
device_name = core.get_property(device, "FULL_DEVICE_NAME")
|
device_name = core.get_property(device, "FULL_DEVICE_NAME")
|
||||||
print(f"{device}: {device_name}")
|
print(f"{device}: {device_name}")
|
||||||
model = core.read_model(onnx_path)
|
model = core.read_model(onnx_path)
|
||||||
compiled_model = core.compile_model(model, device_name='GPU.0')
|
compiled_model = core.compile_model(model, device_name="GPU.0")
|
||||||
print(compiled_model)
|
print(compiled_model)
|
||||||
ireq = compiled_model.create_infer_request()
|
ireq = compiled_model.create_infer_request()
|
||||||
for model_input in compiled_model.inputs:
|
for model_input in compiled_model.inputs:
|
||||||
|
@ -51,7 +52,7 @@ print("did one")
|
||||||
|
|
||||||
REPS = 20
|
REPS = 20
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
for i in range(REPS): ireq.infer()
|
for i in range(REPS):
|
||||||
|
ireq.infer()
|
||||||
et = time.perf_counter() - st
|
et = time.perf_counter() - st
|
||||||
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")
|
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,12 @@ from tqdm import trange, tqdm
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
tests = {}
|
tests = {}
|
||||||
|
|
||||||
|
|
||||||
def register_test(fxn):
|
def register_test(fxn):
|
||||||
tests[fxn.__name__] = fxn
|
tests[fxn.__name__] = fxn
|
||||||
|
|
||||||
|
|
||||||
def warp_size2(nthread):
|
def warp_size2(nthread):
|
||||||
prg = """__kernel void warp_size2(
|
prg = """__kernel void warp_size2(
|
||||||
__global float* src,
|
__global float* src,
|
||||||
|
@ -27,16 +30,36 @@ def warp_size2(nthread):
|
||||||
src_buf = CLBuffer(1, dtypes.float32)
|
src_buf = CLBuffer(1, dtypes.float32)
|
||||||
dst_buf = CLBuffer(1, dtypes.int32)
|
dst_buf = CLBuffer(1, dtypes.int32)
|
||||||
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
|
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
|
||||||
return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9
|
return (
|
||||||
|
min(
|
||||||
|
[
|
||||||
|
cl(
|
||||||
|
[nthread, 1024, 1],
|
||||||
|
[nthread, 1, 1],
|
||||||
|
src_buf,
|
||||||
|
dst_buf,
|
||||||
|
10,
|
||||||
|
3,
|
||||||
|
wait=True,
|
||||||
|
)
|
||||||
|
for _ in range(5)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
* 1e9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_warp_size():
|
def test_warp_size():
|
||||||
return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)]
|
return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)]
|
||||||
|
|
||||||
|
|
||||||
def reg_count(nthread, ngrp, nreg):
|
def reg_count(nthread, ngrp, nreg):
|
||||||
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
|
reg_declr = "".join(
|
||||||
reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
[f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]
|
||||||
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
)
|
||||||
|
reg_comp = "".join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||||
|
reg_reduce = "".join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||||
prg = f"""__kernel void reg_count(
|
prg = f"""__kernel void reg_count(
|
||||||
__global float* out_buf,
|
__global float* out_buf,
|
||||||
__private const int niter
|
__private const int niter
|
||||||
|
@ -51,12 +74,25 @@ def reg_count(nthread, ngrp, nreg):
|
||||||
}}"""
|
}}"""
|
||||||
out_buf = CLBuffer(1, dtypes.float32)
|
out_buf = CLBuffer(1, dtypes.float32)
|
||||||
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
||||||
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9
|
return (
|
||||||
|
min(
|
||||||
|
[
|
||||||
|
cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True)
|
||||||
|
for _ in range(10)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
* 1e9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_reg_count(nthread=1, ngrp=1):
|
def test_reg_count(nthread=1, ngrp=1):
|
||||||
base = reg_count(nthread, ngrp, 1)
|
base = reg_count(nthread, ngrp, 1)
|
||||||
return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)]
|
return [
|
||||||
|
(nreg, (reg_count(nthread, ngrp, nreg) - base) / nreg)
|
||||||
|
for nreg in trange(4, 513, 4)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
||||||
ndata //= NCOMP * 4 # ptr size
|
ndata //= NCOMP * 4 # ptr size
|
||||||
|
@ -72,22 +108,40 @@ def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
||||||
*dst = idx;
|
*dst = idx;
|
||||||
}}"""
|
}}"""
|
||||||
idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32)
|
idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32)
|
||||||
for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata
|
for i in range(ndata):
|
||||||
|
idx_buf[i * NCOMP] = (i + stride) % ndata
|
||||||
in_buf = CLBuffer.fromCPU(idx_buf)
|
in_buf = CLBuffer.fromCPU(idx_buf)
|
||||||
out_buf = CLBuffer(1, dtypes.int32)
|
out_buf = CLBuffer(1, dtypes.int32)
|
||||||
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
||||||
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9
|
return (
|
||||||
|
min(
|
||||||
|
[
|
||||||
|
cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True) / steps
|
||||||
|
for _ in range(5)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
* 1e9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_memory_latency():
|
def test_memory_latency():
|
||||||
# requires cacheline < 16
|
# requires cacheline < 16
|
||||||
szs = [int(1.3**x) for x in range(20, 70)]
|
szs = [int(1.3**x) for x in range(20, 70)]
|
||||||
return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)]
|
return [
|
||||||
|
(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128 * 1024))
|
||||||
|
for ndata in tqdm(szs)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_cacheline_size():
|
def test_cacheline_size():
|
||||||
# TODO: this buffer must be at least 2x the L1 cache for this test to work
|
# TODO: this buffer must be at least 2x the L1 cache for this test to work
|
||||||
return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)]
|
return [
|
||||||
|
(stride, buf_cache_hierarchy_pchase(4 * 65536, stride, steps=65536))
|
||||||
|
for stride in trange(1, 64)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def cl_read(sz, niter=1):
|
def cl_read(sz, niter=1):
|
||||||
prg = f"""__kernel void copy(
|
prg = f"""__kernel void copy(
|
||||||
|
@ -101,7 +155,16 @@ def cl_read(sz, niter=1):
|
||||||
out_buf = CLBuffer(1, dtypes.float32)
|
out_buf = CLBuffer(1, dtypes.float32)
|
||||||
cl = CLProgram("copy", prg)
|
cl = CLProgram("copy", prg)
|
||||||
# NOTE: if nay of the niters form a local group, this is wrong
|
# NOTE: if nay of the niters form a local group, this is wrong
|
||||||
return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9
|
return (
|
||||||
|
min(
|
||||||
|
[
|
||||||
|
cl([sz // 16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True)
|
||||||
|
for _ in range(10)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
* 1e9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_read_bandwidth():
|
def test_read_bandwidth():
|
||||||
|
@ -129,12 +192,17 @@ def gflops(niter=4, nroll=4, ngroups=4096):
|
||||||
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
|
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
|
||||||
FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32
|
FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32
|
||||||
# NOTE: if nay of the niters form a local group, this is wrong
|
# NOTE: if nay of the niters form a local group, this is wrong
|
||||||
return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9)
|
return FLOPS / (
|
||||||
|
min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])
|
||||||
|
* 1e9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_test
|
@register_test
|
||||||
def test_gflops():
|
def test_gflops():
|
||||||
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
|
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cache = {}
|
cache = {}
|
||||||
# cache = pickle.load(open("/tmp/cache.pkl", "rb"))
|
# cache = pickle.load(open("/tmp/cache.pkl", "rb"))
|
||||||
|
@ -144,8 +212,10 @@ if __name__ == "__main__":
|
||||||
print(f"running {k}")
|
print(f"running {k}")
|
||||||
plt.subplot(2, (len(tests) + 1) // 2, i + 1)
|
plt.subplot(2, (len(tests) + 1) // 2, i + 1)
|
||||||
plt.title(k)
|
plt.title(k)
|
||||||
if k == "test_memory_latency": plt.xscale('log')
|
if k == "test_memory_latency":
|
||||||
if k not in cache: cache[k] = test()
|
plt.xscale("log")
|
||||||
|
if k not in cache:
|
||||||
|
cache[k] = test()
|
||||||
plt.plot(*zip(*cache[k]))
|
plt.plot(*zip(*cache[k]))
|
||||||
# pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
|
# pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
|
||||||
|
|
||||||
|
|
|
@ -1,32 +1,69 @@
|
||||||
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
|
from typing import (
|
||||||
|
Tuple,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
DefaultDict,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
||||||
from tinygrad.ops import BinaryOps, UnaryOps
|
from tinygrad.ops import BinaryOps, UnaryOps
|
||||||
from tinygrad.helpers import DType, dtypes, DEBUG
|
from tinygrad.helpers import DType, dtypes, DEBUG
|
||||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
from tinygrad.shape.symbolic import (
|
||||||
|
Variable,
|
||||||
|
NumNode,
|
||||||
|
MulNode,
|
||||||
|
DivNode,
|
||||||
|
ModNode,
|
||||||
|
LtNode,
|
||||||
|
SumNode,
|
||||||
|
AndNode,
|
||||||
|
)
|
||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
_type_to_letter = {
|
||||||
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
dtypes.float32: "f",
|
||||||
|
dtypes.bool: "p",
|
||||||
|
dtypes.int32: "i",
|
||||||
|
dtypes.int64: "a",
|
||||||
|
dtypes.uint32: "u",
|
||||||
|
dtypes.uint64: "b",
|
||||||
|
dtypes.float.vec(4): "x",
|
||||||
|
dtypes.uint8: "uc",
|
||||||
|
dtypes.float16: "h",
|
||||||
|
dtypes.int8: "c",
|
||||||
|
dtypes.uint16: "us",
|
||||||
|
dtypes.float64: "d",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Register(NamedTuple):
|
class Register(NamedTuple):
|
||||||
nm: str
|
nm: str
|
||||||
dtype: DType
|
dtype: DType
|
||||||
scalar: bool
|
scalar: bool
|
||||||
off: Optional[int] = None
|
off: Optional[int] = None
|
||||||
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||||
|
|
||||||
def subregs(self):
|
def subregs(self):
|
||||||
if self.dtype == dtypes.float.vec(4):
|
if self.dtype == dtypes.float.vec(4):
|
||||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class AssemblyInstruction(NamedTuple):
|
class AssemblyInstruction(NamedTuple):
|
||||||
op: UOps
|
op: UOps
|
||||||
out: Optional[Register]
|
out: Optional[Register]
|
||||||
vin: List[Union[Register, int, float]]
|
vin: List[Union[Register, int, float]]
|
||||||
arg: Any = None
|
arg: Any = None
|
||||||
|
|
||||||
|
|
||||||
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
||||||
class AssemblyLanguage:
|
class AssemblyLanguage:
|
||||||
supports_load3: bool = False
|
supports_load3: bool = False
|
||||||
|
@ -37,9 +74,15 @@ class AssemblyLanguage:
|
||||||
tor: Dict[Any, Register] = {}
|
tor: Dict[Any, Register] = {}
|
||||||
ins: List[AssemblyInstruction] = []
|
ins: List[AssemblyInstruction] = []
|
||||||
|
|
||||||
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
def type_to_letter(self, x):
|
||||||
|
return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||||
|
|
||||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
||||||
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
self.tor[tok] = ret = Register(
|
||||||
|
f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}",
|
||||||
|
dtype,
|
||||||
|
scalar,
|
||||||
|
)
|
||||||
if dtype == dtypes.float.vec(4):
|
if dtype == dtypes.float.vec(4):
|
||||||
for off in range(4):
|
for off in range(4):
|
||||||
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||||
|
@ -48,30 +91,72 @@ class AssemblyLanguage:
|
||||||
|
|
||||||
def render_numnode(self, b) -> Register:
|
def render_numnode(self, b) -> Register:
|
||||||
key = ("num", b)
|
key = ("num", b)
|
||||||
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
if key not in self.tor:
|
||||||
|
self.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b
|
||||||
|
)
|
||||||
|
)
|
||||||
return self.tor[key]
|
return self.tor[key]
|
||||||
|
|
||||||
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
def render_alu(
|
||||||
|
self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32
|
||||||
|
) -> Register:
|
||||||
key = (op, a, b)
|
key = (op, a, b)
|
||||||
if key not in self.tor:
|
if key not in self.tor:
|
||||||
# if not isinstance(b, Register): b = render_numnode(b)
|
# if not isinstance(b, Register): b = render_numnode(b)
|
||||||
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
self.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU,
|
||||||
|
self.newreg(
|
||||||
|
key,
|
||||||
|
dtype=dtype,
|
||||||
|
scalar=a.scalar and (not isinstance(b, Register) or b.scalar),
|
||||||
|
),
|
||||||
|
[a, b],
|
||||||
|
op,
|
||||||
|
)
|
||||||
|
)
|
||||||
return self.tor[key]
|
return self.tor[key]
|
||||||
|
|
||||||
def render_cast(self, a: Register, new_dtype: DType) -> Register:
|
def render_cast(self, a: Register, new_dtype: DType) -> Register:
|
||||||
if a.dtype == new_dtype: return a
|
if a.dtype == new_dtype:
|
||||||
|
return a
|
||||||
key = (a, new_dtype)
|
key = (a, new_dtype)
|
||||||
if key not in self.tor:
|
if key not in self.tor:
|
||||||
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
|
self.ins.append(
|
||||||
|
AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])
|
||||||
|
)
|
||||||
return self.tor[key]
|
return self.tor[key]
|
||||||
|
|
||||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
render_ops: Any = {
|
||||||
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
Variable: lambda self, ops, ctx: ctx.tor[self],
|
||||||
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||||
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
MulNode: lambda self, ops, ctx: ctx.render_alu(
|
||||||
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
BinaryOps.MUL, self.a.render(ops, ctx), self.b
|
||||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
),
|
||||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
DivNode: lambda self, ops, ctx: ctx.render_alu(
|
||||||
|
BinaryOps.DIV, self.a.render(ops, ctx), self.b
|
||||||
|
),
|
||||||
|
ModNode: lambda self, ops, ctx: ctx.render_alu(
|
||||||
|
BinaryOps.MOD, self.a.render(ops, ctx), self.b
|
||||||
|
),
|
||||||
|
LtNode: lambda self, ops, ctx: ctx.render_alu(
|
||||||
|
BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool
|
||||||
|
),
|
||||||
|
SumNode: lambda self, ops, ctx: functools.reduce(
|
||||||
|
lambda a, b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops, ctx)),
|
||||||
|
self.nodes[1:],
|
||||||
|
self.nodes[0].render(ops, ctx),
|
||||||
|
),
|
||||||
|
AndNode: lambda self, ops, ctx: functools.reduce(
|
||||||
|
lambda a, b: ctx.render_alu(
|
||||||
|
BinaryOps.MUL, a, b.render(ops, ctx), dtype=dtypes.bool
|
||||||
|
),
|
||||||
|
self.nodes[1:],
|
||||||
|
self.nodes[0].render(ops, ctx),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def addr_w_offset(self, args):
|
def addr_w_offset(self, args):
|
||||||
assert isinstance(args, MemOp)
|
assert isinstance(args, MemOp)
|
||||||
|
@ -79,110 +164,264 @@ class AssemblyLanguage:
|
||||||
off = 0 # TODO: should this be None?
|
off = 0 # TODO: should this be None?
|
||||||
if isinstance(idx, SumNode):
|
if isinstance(idx, SumNode):
|
||||||
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
||||||
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
if (
|
||||||
|
nums and nums[0] < 4096 and (idx - nums[0]).min >= 0
|
||||||
|
): # TODO: different for each GPU?
|
||||||
idx -= nums[0]
|
idx -= nums[0]
|
||||||
off = cast(int, nums[0])
|
off = cast(int, nums[0])
|
||||||
reg = idx.render(self.render_ops, self)
|
reg = idx.render(self.render_ops, self)
|
||||||
if self.supports_load3:
|
if self.supports_load3:
|
||||||
if reg.scalar:
|
if reg.scalar:
|
||||||
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
new_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype)
|
||||||
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
self.ins.append(
|
||||||
|
AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)
|
||||||
|
)
|
||||||
reg = new_reg
|
reg = new_reg
|
||||||
return self.tor[args.name], reg, off
|
return self.tor[args.name], reg, off
|
||||||
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
|
reg = self.render_alu(
|
||||||
|
BinaryOps.ADD,
|
||||||
|
self.render_cast(reg, dtypes.uint64),
|
||||||
|
self.tor[args.name],
|
||||||
|
dtype=dtypes.uint64,
|
||||||
|
)
|
||||||
return reg, None, off
|
return reg, None, off
|
||||||
|
|
||||||
|
|
||||||
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
|
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
|
||||||
# TODO: Do not use clear()
|
# TODO: Do not use clear()
|
||||||
lang.ins.clear()
|
lang.ins.clear()
|
||||||
lang.tor.clear()
|
lang.tor.clear()
|
||||||
lang.cnts.clear()
|
lang.cnts.clear()
|
||||||
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
buf_to_dtype = {
|
||||||
|
args[0]: args[1] for uop, _, _, args, _ in uops if uop == UOps.DEFINE_GLOBAL
|
||||||
|
}
|
||||||
global_size, local_size = [], []
|
global_size, local_size = [], []
|
||||||
skipload_branch = 0
|
skipload_branch = 0
|
||||||
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
lang.ins += [
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf
|
||||||
|
)
|
||||||
|
for buf in buf_to_dtype
|
||||||
|
]
|
||||||
for u in uops:
|
for u in uops:
|
||||||
uop, dtype, vin, args, _ = u
|
uop, dtype, vin, args, _ = u
|
||||||
if uop == UOps.DEFINE_LOCAL:
|
if uop == UOps.DEFINE_LOCAL:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU,
|
||||||
|
lang.newreg(args[0], dtype=dtypes.uint64),
|
||||||
|
[args[0]],
|
||||||
|
UnaryOps.NOOP,
|
||||||
|
)
|
||||||
|
)
|
||||||
elif uop == UOps.LOOP:
|
elif uop == UOps.LOOP:
|
||||||
if args[1] == "global":
|
if args[1] == "global":
|
||||||
for i, var in enumerate(args[0]):
|
for i, var in enumerate(args[0]):
|
||||||
global_size.append(var.max + 1)
|
global_size.append(var.max + 1)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.SPECIAL,
|
||||||
|
lang.newreg(var, dtype=dtypes.int32),
|
||||||
|
[],
|
||||||
|
f"gid{len(args[0])-1-i}",
|
||||||
|
)
|
||||||
|
)
|
||||||
elif args[1] == "local":
|
elif args[1] == "local":
|
||||||
for i, var in enumerate(args[0]):
|
for i, var in enumerate(args[0]):
|
||||||
local_size.append(var.max + 1)
|
local_size.append(var.max + 1)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.SPECIAL,
|
||||||
|
lang.newreg(var, dtype=dtypes.int32),
|
||||||
|
[],
|
||||||
|
f"lid{len(args[0])-1-i}",
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for var in args[0]:
|
for var in args[0]:
|
||||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
if not isinstance(
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
var, NumNode
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
): # TODO: why is this coming through?
|
||||||
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.LOAD,
|
||||||
|
lang.newreg(var, dtype=dtypes.int32, scalar=True),
|
||||||
|
[],
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.LABEL, None, [], "$loop_" + var.expr
|
||||||
|
)
|
||||||
|
)
|
||||||
elif uop == UOps.ENDLOOP:
|
elif uop == UOps.ENDLOOP:
|
||||||
if args[1] not in ["global", "local", "global+local"]:
|
if args[1] not in ["global", "local", "global+local"]:
|
||||||
for var in reversed(args[0]):
|
for var in reversed(args[0]):
|
||||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
if not isinstance(
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
var, NumNode
|
||||||
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
): # TODO: why is this coming through?
|
||||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU,
|
||||||
|
lang.tor[var],
|
||||||
|
[lang.tor[var], 1],
|
||||||
|
BinaryOps.ADD,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pred = lang.render_alu(
|
||||||
|
BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool
|
||||||
|
)
|
||||||
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.COND_BRANCH,
|
||||||
|
None,
|
||||||
|
[pred],
|
||||||
|
("$loop_" + var.expr, True),
|
||||||
|
)
|
||||||
|
)
|
||||||
elif args[1] == "global+local":
|
elif args[1] == "global+local":
|
||||||
for i, var in enumerate(reversed(args[0])):
|
for i, var in enumerate(reversed(args[0])):
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
lang.ins.append(
|
||||||
elif args[1] == 'local':
|
AssemblyInstruction(
|
||||||
|
UOps.ENDLOOP,
|
||||||
|
None,
|
||||||
|
[lang.tor[var]],
|
||||||
|
(var.max + 1, f"gid{i}"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif args[1] == "local":
|
||||||
for i, var in enumerate(reversed(args[0])):
|
for i, var in enumerate(reversed(args[0])):
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ENDLOOP,
|
||||||
|
None,
|
||||||
|
[lang.tor[var]],
|
||||||
|
(var.max + 1, f"lid{i}"),
|
||||||
|
)
|
||||||
|
)
|
||||||
elif uop == UOps.CAST:
|
elif uop == UOps.CAST:
|
||||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||||
out = lang.newreg(u, dtype)
|
out = lang.newreg(u, dtype)
|
||||||
for i, sr in enumerate(out.subregs()):
|
for i, sr in enumerate(out.subregs()):
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)
|
||||||
|
)
|
||||||
elif uop == UOps.ALU:
|
elif uop == UOps.ALU:
|
||||||
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
||||||
# this is the only thing that can violate SSA
|
# this is the only thing that can violate SSA
|
||||||
if args in [BinaryOps.CMPLT]:
|
if args in [BinaryOps.CMPLT]:
|
||||||
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
|
pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args
|
||||||
|
)
|
||||||
|
)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||||
elif args == BinaryOps.DIV and lang.no_div:
|
elif args == BinaryOps.DIV and lang.no_div:
|
||||||
tmp = lang.newreg((u, "rcp"))
|
tmp = lang.newreg((u, "rcp"))
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
lang.ins.append(
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
AssemblyInstruction(
|
||||||
|
UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP
|
||||||
|
)
|
||||||
|
)
|
||||||
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL
|
||||||
|
)
|
||||||
|
)
|
||||||
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||||
tmp = lang.newreg((u, "2pi"))
|
tmp = lang.newreg((u, "2pi"))
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU,
|
||||||
|
tmp,
|
||||||
|
[lang.tor[vin[0]], 1 / (math.pi * 2)],
|
||||||
|
BinaryOps.MUL,
|
||||||
|
)
|
||||||
|
)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||||
else:
|
else:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)
|
||||||
|
)
|
||||||
elif uop == UOps.DEFINE_ACC:
|
elif uop == UOps.DEFINE_ACC:
|
||||||
reg = lang.newreg(u, dtype=dtype)
|
reg = lang.newreg(u, dtype=dtype)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
||||||
elif uop == UOps.SPECIAL:
|
elif uop == UOps.SPECIAL:
|
||||||
lang.tor[u] = lang.tor[args]
|
lang.tor[u] = lang.tor[args]
|
||||||
elif uop == UOps.CONST:
|
elif uop == UOps.CONST:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)
|
||||||
|
)
|
||||||
elif uop == UOps.LOAD:
|
elif uop == UOps.LOAD:
|
||||||
idx, treg, off = lang.addr_w_offset(args)
|
idx, treg, off = lang.addr_w_offset(args)
|
||||||
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
|
reg = lang.newreg(
|
||||||
|
u,
|
||||||
|
dtype=dtype,
|
||||||
|
scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)),
|
||||||
|
)
|
||||||
if args.valid.min == 0:
|
if args.valid.min == 0:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||||
if args.valid.max == 1:
|
if args.valid.max == 1:
|
||||||
pred = args.valid.render(lang.render_ops, lang)
|
pred = args.valid.render(lang.render_ops, lang)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.COND_BRANCH,
|
||||||
|
None,
|
||||||
|
[pred],
|
||||||
|
(f"$skipload_{skipload_branch}", False),
|
||||||
|
)
|
||||||
|
)
|
||||||
if args.valid.max == 1:
|
if args.valid.max == 1:
|
||||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.LOAD,
|
||||||
|
reg,
|
||||||
|
[idx] + ([treg] if treg is not None else []),
|
||||||
|
(
|
||||||
|
off,
|
||||||
|
"global" if not args.local else "shared",
|
||||||
|
args.memory_dtype
|
||||||
|
if args.memory_dtype != dtypes.float
|
||||||
|
else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
if args.valid.min == 0 and args.valid.max == 1:
|
if args.valid.min == 0 and args.valid.max == 1:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.LABEL, None, [], f"$skipload_{skipload_branch}"
|
||||||
|
)
|
||||||
|
)
|
||||||
skipload_branch += 1
|
skipload_branch += 1
|
||||||
elif uop == UOps.STORE:
|
elif uop == UOps.STORE:
|
||||||
if args is None:
|
if args is None:
|
||||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
idx, treg, off = lang.addr_w_offset(args)
|
idx, treg, off = lang.addr_w_offset(args)
|
||||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
lang.ins.append(
|
||||||
|
AssemblyInstruction(
|
||||||
|
UOps.STORE,
|
||||||
|
None,
|
||||||
|
[idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []),
|
||||||
|
(
|
||||||
|
off,
|
||||||
|
"global" if not args.local else "shared",
|
||||||
|
args.memory_dtype
|
||||||
|
if args.memory_dtype != dtypes.float
|
||||||
|
else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if DEBUG >= 4:
|
if DEBUG >= 4:
|
||||||
for tins in lang.ins: print(tins)
|
for tins in lang.ins:
|
||||||
|
print(tins)
|
||||||
return global_size, local_size
|
return global_size, local_size
|
||||||
|
|
|
@ -6,28 +6,60 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
||||||
from tinygrad.helpers import dtypes, CI
|
from tinygrad.helpers import dtypes, CI
|
||||||
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
||||||
|
|
||||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
|
||||||
|
def float_to_hex(x):
|
||||||
|
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||||
|
|
||||||
|
|
||||||
def compute_offsets(total):
|
def compute_offsets(total):
|
||||||
quotient, remainder = divmod(total, 4096)
|
quotient, remainder = divmod(total, 4096)
|
||||||
return [4096] * quotient + [remainder] if remainder else [4096] * quotient
|
return [4096] * quotient + [remainder] if remainder else [4096] * quotient
|
||||||
|
|
||||||
#NOTE: Darwin needs names to start with a "_"
|
|
||||||
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
|
|
||||||
|
|
||||||
class ARM64Language(AssemblyLanguage): pass
|
# NOTE: Darwin needs names to start with a "_"
|
||||||
|
def get_name(name):
|
||||||
|
return ("_" if system() == "Darwin" else "") + name
|
||||||
|
|
||||||
|
|
||||||
|
class ARM64Language(AssemblyLanguage):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def specialize_to_arm64(fn_nm, asm):
|
def specialize_to_arm64(fn_nm, asm):
|
||||||
var_size = 16
|
var_size = 16
|
||||||
prev_uop: Optional[UOps] = None
|
prev_uop: Optional[UOps] = None
|
||||||
ins = []
|
ins = []
|
||||||
x_regs = ['x' + str(i) for i in reversed(range(12))]
|
x_regs = ["x" + str(i) for i in reversed(range(12))]
|
||||||
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
|
s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
|
||||||
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
type_to_reg = {
|
||||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
dtypes.double: "d",
|
||||||
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
|
dtypes.half: "h",
|
||||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
dtypes.float32: "s",
|
||||||
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
|
dtypes.bool: "w",
|
||||||
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
|
dtypes.int8: "w",
|
||||||
|
dtypes.int32: "w",
|
||||||
|
dtypes.int64: "x",
|
||||||
|
dtypes.uint8: "w",
|
||||||
|
dtypes.uint32: "w",
|
||||||
|
dtypes.uint64: "x",
|
||||||
|
}
|
||||||
|
alu = {
|
||||||
|
BinaryOps.ADD: "add",
|
||||||
|
BinaryOps.SUB: "sub",
|
||||||
|
BinaryOps.MUL: "mul",
|
||||||
|
BinaryOps.DIV: "div",
|
||||||
|
BinaryOps.MAX: "max",
|
||||||
|
BinaryOps.MOD: "",
|
||||||
|
BinaryOps.CMPLT: "subs",
|
||||||
|
UnaryOps.NOOP: "mov",
|
||||||
|
UnaryOps.NEG: "neg",
|
||||||
|
UnaryOps.SIN: "bl " + get_name("sinf"),
|
||||||
|
UnaryOps.LOG2: "bl " + get_name("log2f"),
|
||||||
|
UnaryOps.EXP2: "bl " + get_name("exp2f"),
|
||||||
|
UnaryOps.SQRT: "bl " + get_name("sqrtf"),
|
||||||
|
TernaryOps.MULACC: "madd",
|
||||||
|
TernaryOps.WHERE: "fcsel",
|
||||||
|
}
|
||||||
|
|
||||||
def mov_imm(value, reg):
|
def mov_imm(value, reg):
|
||||||
# Manually move value into reg if value can't fit
|
# Manually move value into reg if value can't fit
|
||||||
|
@ -35,7 +67,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||||
ins.append(f"movz w15, #{value & 0xffff}")
|
ins.append(f"movz w15, #{value & 0xffff}")
|
||||||
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||||
ins.append(f"sxtw {reg}, w15")
|
ins.append(f"sxtw {reg}, w15")
|
||||||
elif reg[0] == 's':
|
elif reg[0] == "s":
|
||||||
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
||||||
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
||||||
ins.append("str x15, [sp, 16]")
|
ins.append("str x15, [sp, 16]")
|
||||||
|
@ -46,42 +78,51 @@ def specialize_to_arm64(fn_nm, asm):
|
||||||
# Get variables intervals
|
# Get variables intervals
|
||||||
live_range: Dict[str, List[int]] = {}
|
live_range: Dict[str, List[int]] = {}
|
||||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||||
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
|
for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]:
|
||||||
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
live_range[var.nm] = (
|
||||||
|
[i, i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||||
|
)
|
||||||
|
|
||||||
mem_vars: Dict[str, int] = {}
|
mem_vars: Dict[str, int] = {}
|
||||||
rtor: Dict[str, str] = {}
|
rtor: Dict[str, str] = {}
|
||||||
|
|
||||||
def allocate_regs(mvars):
|
def allocate_regs(mvars):
|
||||||
nonlocal var_size
|
nonlocal var_size
|
||||||
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
|
for v in [
|
||||||
|
v
|
||||||
|
for v in mvars
|
||||||
|
if v is not None and v.__class__ is not int and v.nm not in rtor
|
||||||
|
]:
|
||||||
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||||
# NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
# NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
||||||
if not available_regs:
|
if not available_regs:
|
||||||
# ARM needs the stack 16-byte aligned
|
# ARM needs the stack 16-byte aligned
|
||||||
var_size += 16
|
var_size += 16
|
||||||
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
|
available_regs.append("s0" if dtypes.is_float(out[1]) else "x12")
|
||||||
mem_vars[v.nm] = var_size
|
mem_vars[v.nm] = var_size
|
||||||
rtor[v.nm] = available_regs.pop()
|
rtor[v.nm] = available_regs.pop()
|
||||||
|
|
||||||
temp_floats = ['s0', 's1', 's2']
|
temp_floats = ["s0", "s1", "s2"]
|
||||||
temp_ints = ['x12', 'x13', 'x16']
|
temp_ints = ["x12", "x13", "x16"]
|
||||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||||
# Clear regs out of interval
|
# Clear regs out of interval
|
||||||
for var, reg in list(rtor.items()):
|
for var, reg in list(rtor.items()):
|
||||||
available_regs = s_regs if reg[0] == 's' else x_regs
|
available_regs = s_regs if reg[0] == "s" else x_regs
|
||||||
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
|
if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]:
|
||||||
available_regs.append(rtor.pop(var))
|
available_regs.append(rtor.pop(var))
|
||||||
# Assign a registers to the variables using live ranges.
|
# Assign a registers to the variables using live ranges.
|
||||||
allocate_regs([out] + vin)
|
allocate_regs([out] + vin)
|
||||||
# Assign temp regs to vin and load them before direct use
|
# Assign temp regs to vin and load them before direct use
|
||||||
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
|
for i, v in enumerate(
|
||||||
|
[v for v in vin if v.__class__ is not int and v.nm in mem_vars]
|
||||||
|
):
|
||||||
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
||||||
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
||||||
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||||
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||||
|
|
||||||
if uop == UOps.SPECIAL:
|
if uop == UOps.SPECIAL:
|
||||||
if arg.startswith('data'):
|
if arg.startswith("data"):
|
||||||
# data 8 to n into the stack
|
# data 8 to n into the stack
|
||||||
if int(arg[4:]) >= 8:
|
if int(arg[4:]) >= 8:
|
||||||
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
|
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
|
||||||
|
@ -91,28 +132,40 @@ def specialize_to_arm64(fn_nm, asm):
|
||||||
ins.append(f"loop_{arg}:")
|
ins.append(f"loop_{arg}:")
|
||||||
elif uop == UOps.CAST:
|
elif uop == UOps.CAST:
|
||||||
if arg == BinaryOps.CMPLT:
|
if arg == BinaryOps.CMPLT:
|
||||||
if rtor[out.nm][0] == 's':
|
if rtor[out.nm][0] == "s":
|
||||||
mov_imm(0.0, 's0')
|
mov_imm(0.0, "s0")
|
||||||
mov_imm(1.0, 's1')
|
mov_imm(1.0, "s1")
|
||||||
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
||||||
if rtor[out.nm][0] == 'x':
|
if rtor[out.nm][0] == "x":
|
||||||
mov_imm(0, 'x14')
|
mov_imm(0, "x14")
|
||||||
mov_imm(1, 'x15')
|
mov_imm(1, "x15")
|
||||||
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
||||||
else:
|
else:
|
||||||
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||||
elif uop == UOps.ALU:
|
elif uop == UOps.ALU:
|
||||||
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
|
if len(vin) == 2 and vin[1].__class__ is int:
|
||||||
|
mov_imm(vin[1], "x15")
|
||||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||||
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
ins.append(
|
||||||
|
f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||||
|
)
|
||||||
elif arg == TernaryOps.WHERE:
|
elif arg == TernaryOps.WHERE:
|
||||||
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
|
ins.append(
|
||||||
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
|
f"fcmp {rtor[vin[0].nm]}, #0.0"
|
||||||
|
if rtor[vin[0].nm][0] == "s"
|
||||||
|
else f"cmp {rtor[vin[0].nm]}, #0"
|
||||||
|
)
|
||||||
|
ins.append(
|
||||||
|
f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne"
|
||||||
|
)
|
||||||
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||||
# NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
# NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
||||||
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
if CI:
|
||||||
|
ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||||
else:
|
else:
|
||||||
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
|
save_regs = [
|
||||||
|
k for k in rtor.keys() if k != out.nm and k not in mem_vars
|
||||||
|
]
|
||||||
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
||||||
# Save the registers before they are cleared by func call
|
# Save the registers before they are cleared by func call
|
||||||
for i, k in enumerate(save_regs, 1):
|
for i, k in enumerate(save_regs, 1):
|
||||||
|
@ -128,27 +181,49 @@ def specialize_to_arm64(fn_nm, asm):
|
||||||
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
||||||
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||||
elif arg == BinaryOps.CMPLT:
|
elif arg == BinaryOps.CMPLT:
|
||||||
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
|
ins.append(
|
||||||
|
f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||||
|
if not dtypes.is_float(vin[0][1])
|
||||||
|
else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}"
|
||||||
|
)
|
||||||
elif arg == BinaryOps.MOD:
|
elif arg == BinaryOps.MOD:
|
||||||
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
|
rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm]
|
||||||
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
|
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
|
||||||
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
||||||
else:
|
else:
|
||||||
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
ins.append(
|
||||||
|
f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||||
|
)
|
||||||
elif uop == UOps.LOAD:
|
elif uop == UOps.LOAD:
|
||||||
if arg.__class__ in (int, float):
|
if arg.__class__ in (int, float):
|
||||||
mov_imm(arg, rtor[out.nm])
|
mov_imm(arg, rtor[out.nm])
|
||||||
else:
|
else:
|
||||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||||
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
|
reg_in = (
|
||||||
|
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||||
|
if arg[2] is not None
|
||||||
|
else rtor[out.nm]
|
||||||
|
)
|
||||||
mov_imm(arg[0], "x15")
|
mov_imm(arg[0], "x15")
|
||||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
ins.append(
|
||||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]"
|
||||||
|
)
|
||||||
|
if arg[2] is not None:
|
||||||
|
ins.append(
|
||||||
|
f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}"
|
||||||
|
)
|
||||||
elif uop == UOps.STORE:
|
elif uop == UOps.STORE:
|
||||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
reg_out = (
|
||||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||||
|
if arg[2] is not None
|
||||||
|
else rtor[vin[1].nm]
|
||||||
|
)
|
||||||
|
if arg[2] is not None:
|
||||||
|
ins.append(
|
||||||
|
f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}"
|
||||||
|
)
|
||||||
ins.append(f"mov x15, #{arg[0]}")
|
ins.append(f"mov x15, #{arg[0]}")
|
||||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||||
elif uop == UOps.COND_BRANCH:
|
elif uop == UOps.COND_BRANCH:
|
||||||
|
@ -168,9 +243,31 @@ def specialize_to_arm64(fn_nm, asm):
|
||||||
if out is not None and out.nm in mem_vars:
|
if out is not None and out.nm in mem_vars:
|
||||||
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
||||||
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
||||||
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
|
return "\n".join(
|
||||||
|
[
|
||||||
|
f"//varsize {var_size}",
|
||||||
|
".arch armv8-a",
|
||||||
|
".text",
|
||||||
|
f".global {get_name(fn_nm)}",
|
||||||
|
".p2align 2",
|
||||||
|
f"{get_name(fn_nm)}:",
|
||||||
|
"mov x17, sp",
|
||||||
|
]
|
||||||
|
+ [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||||
|
+ ins
|
||||||
|
+ [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||||
|
+ ["ret", "\n"]
|
||||||
|
)
|
||||||
|
|
||||||
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
|
|
||||||
|
def uops_to_arm64_asm(
|
||||||
|
fn_nm: str, uops: List[UOp]
|
||||||
|
) -> Tuple[str, List[int], List[int], bool]:
|
||||||
lang = ARM64Language()
|
lang = ARM64Language()
|
||||||
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
|
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
|
||||||
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
|
return (
|
||||||
|
specialize_to_arm64(fn_nm, lang.ins),
|
||||||
|
global_size[::-1],
|
||||||
|
local_size[::-1],
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
|
@ -6,50 +6,113 @@ from tinygrad.helpers import dtypes
|
||||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
||||||
from tinygrad.runtime.ops_cuda import arch
|
from tinygrad.runtime.ops_cuda import arch
|
||||||
|
|
||||||
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
|
dtype_to_nvtype = {
|
||||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
dtypes.float32: "f32",
|
||||||
|
dtypes.float16: "f16",
|
||||||
|
dtypes.int64: "s64",
|
||||||
|
dtypes.int32: "s32",
|
||||||
|
dtypes.int8: "s8",
|
||||||
|
dtypes.bool: "pred",
|
||||||
|
dtypes.uint64: "u64",
|
||||||
|
dtypes.uint32: "u32",
|
||||||
|
dtypes.uint16: "u16",
|
||||||
|
dtypes.uint8: "u8",
|
||||||
|
"bits16": "b16",
|
||||||
|
dtypes.float64: "f64",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def float_to_hex(x):
|
||||||
|
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||||
|
|
||||||
|
|
||||||
|
def ptx_needs_cast(dest_dtype, src_dtype):
|
||||||
|
return (
|
||||||
|
dtypes.is_float(dest_dtype)
|
||||||
|
and dtypes.is_int(src_dtype)
|
||||||
|
or dtypes.is_int(dest_dtype)
|
||||||
|
and dtypes.is_float(src_dtype)
|
||||||
|
or (
|
||||||
|
dtypes.is_float(src_dtype)
|
||||||
|
and dtypes.is_float(dest_dtype)
|
||||||
|
and dest_dtype.itemsize != src_dtype.itemsize
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
|
|
||||||
|
|
||||||
def render_cast(ins, inp, out):
|
def render_cast(ins, inp, out):
|
||||||
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
|
if inp.dtype == dtypes.bool and (
|
||||||
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
|
dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)
|
||||||
|
):
|
||||||
|
ins.append(
|
||||||
|
f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};"
|
||||||
|
)
|
||||||
elif out.dtype == dtypes.bool:
|
elif out.dtype == dtypes.bool:
|
||||||
if inp.dtype == dtypes.bool:
|
if inp.dtype == dtypes.bool:
|
||||||
ins.append(f"mov.pred {out}, {inp};")
|
ins.append(f"mov.pred {out}, {inp};")
|
||||||
else:
|
else:
|
||||||
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
|
ins.append(
|
||||||
|
f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
|
round_mod = (
|
||||||
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
|
".rzi"
|
||||||
|
if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype)
|
||||||
|
else ".rz"
|
||||||
|
if dtypes.is_float(out.dtype)
|
||||||
|
and (
|
||||||
|
dtypes.is_int(inp.dtype)
|
||||||
|
or dtypes.is_float(inp.dtype)
|
||||||
|
and inp.dtype.itemsize > out.dtype.itemsize
|
||||||
|
)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
ins.append(
|
||||||
|
f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
|
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
|
||||||
|
|
||||||
|
|
||||||
class PTXLanguage(AssemblyLanguage):
|
class PTXLanguage(AssemblyLanguage):
|
||||||
supports_constant_folding: bool = True
|
supports_constant_folding: bool = True
|
||||||
|
|
||||||
|
|
||||||
def specialize_to_ptx(lang, function_name):
|
def specialize_to_ptx(lang, function_name):
|
||||||
param_cnt = 0
|
param_cnt = 0
|
||||||
ins = []
|
ins = []
|
||||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
alu = {
|
||||||
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
|
BinaryOps.ADD: "add",
|
||||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
BinaryOps.SUB: "sub",
|
||||||
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
|
BinaryOps.MUL: "mul",
|
||||||
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
|
BinaryOps.DIV: "div",
|
||||||
|
BinaryOps.MAX: "max",
|
||||||
|
BinaryOps.MOD: "rem",
|
||||||
|
BinaryOps.CMPLT: "setp.lt",
|
||||||
|
UnaryOps.SQRT: "sqrt.approx",
|
||||||
|
UnaryOps.NOOP: "mov",
|
||||||
|
UnaryOps.NEG: "neg",
|
||||||
|
UnaryOps.SIN: "sin.approx",
|
||||||
|
UnaryOps.LOG2: "lg2.approx",
|
||||||
|
UnaryOps.EXP2: "ex2.approx.ftz",
|
||||||
|
TernaryOps.MULACC: "fma.rn",
|
||||||
|
TernaryOps.WHERE: "selp",
|
||||||
|
}
|
||||||
for uop, out, vin, arg in lang.ins:
|
for uop, out, vin, arg in lang.ins:
|
||||||
if uop == UOps.ENDLOOP:
|
if uop == UOps.ENDLOOP:
|
||||||
ins.append("bar.sync 0;")
|
ins.append("bar.sync 0;")
|
||||||
elif uop == UOps.DEFINE_LOCAL:
|
elif uop == UOps.DEFINE_LOCAL:
|
||||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||||
elif uop == UOps.SPECIAL:
|
elif uop == UOps.SPECIAL:
|
||||||
if arg.startswith('data'):
|
if arg.startswith("data"):
|
||||||
param_cnt += 1
|
param_cnt += 1
|
||||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||||
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
|
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
|
||||||
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
||||||
elif arg.startswith('gid'):
|
elif arg.startswith("gid"):
|
||||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||||
elif arg.startswith('lid'):
|
elif arg.startswith("lid"):
|
||||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||||
elif uop == UOps.ALU:
|
elif uop == UOps.ALU:
|
||||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||||
|
@ -60,31 +123,64 @@ def specialize_to_ptx(lang, function_name):
|
||||||
if vin[0].dtype == dtypes.bool:
|
if vin[0].dtype == dtypes.bool:
|
||||||
reg = vin[0]
|
reg = vin[0]
|
||||||
else:
|
else:
|
||||||
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
|
reg = lang.newreg((vin[0], "bool"), dtypes.bool)
|
||||||
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
|
ins.append(
|
||||||
|
f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};"
|
||||||
|
)
|
||||||
vin = vin[1:] + [reg]
|
vin = vin[1:] + [reg]
|
||||||
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
|
ins.append(
|
||||||
|
f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};"
|
||||||
|
)
|
||||||
elif uop == UOps.LOAD:
|
elif uop == UOps.LOAD:
|
||||||
if arg.__class__ in (int, float):
|
if arg.__class__ in (int, float):
|
||||||
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
|
ins.append(
|
||||||
|
f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};"
|
||||||
|
)
|
||||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||||
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
|
dt = (
|
||||||
|
("u16", dtypes.uint16)
|
||||||
|
if arg[2] == dtypes.bool == out.dtype
|
||||||
|
else ("u8", dtypes.uint8)
|
||||||
|
if arg[2] == dtypes.bool
|
||||||
|
else ("b16", dtypes.float16)
|
||||||
|
if arg[2] == dtypes.half
|
||||||
|
else (dtype_to_nvtype[arg[2]], arg[2])
|
||||||
|
)
|
||||||
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
||||||
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
ins.append(
|
||||||
|
f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||||
|
)
|
||||||
render_cast(ins, reg, out)
|
render_cast(ins, reg, out)
|
||||||
else:
|
else:
|
||||||
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
ins.append(
|
||||||
|
f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||||
|
)
|
||||||
elif uop == UOps.STORE:
|
elif uop == UOps.STORE:
|
||||||
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
|
if (
|
||||||
|
ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype)
|
||||||
|
or arg[2] == dtypes.bool
|
||||||
|
):
|
||||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||||
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
|
prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool)
|
||||||
render_cast(ins, vin[1], prereg)
|
render_cast(ins, vin[1], prereg)
|
||||||
else: prereg = vin[1]
|
|
||||||
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
|
|
||||||
render_cast(ins, prereg, reg)
|
|
||||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
|
|
||||||
else:
|
else:
|
||||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
|
prereg = vin[1]
|
||||||
|
reg = lang.newreg(
|
||||||
|
(prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
|
||||||
|
dtype=dtypes.uint16
|
||||||
|
if arg[2] == dtypes.bool
|
||||||
|
else dtypes.float
|
||||||
|
if arg[2] is None
|
||||||
|
else arg[2],
|
||||||
|
)
|
||||||
|
render_cast(ins, prereg, reg)
|
||||||
|
ins.append(
|
||||||
|
f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ins.append(
|
||||||
|
f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};"
|
||||||
|
)
|
||||||
elif uop == UOps.CAST:
|
elif uop == UOps.CAST:
|
||||||
render_cast(ins, vin[0], out)
|
render_cast(ins, vin[0], out)
|
||||||
elif uop == UOps.LABEL:
|
elif uop == UOps.LABEL:
|
||||||
|
@ -92,14 +188,29 @@ def specialize_to_ptx(lang, function_name):
|
||||||
elif uop == UOps.COND_BRANCH:
|
elif uop == UOps.COND_BRANCH:
|
||||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||||
|
|
||||||
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
|
ins_prefix = [
|
||||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
|
".version 7.8",
|
||||||
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
|
".target " + arch(),
|
||||||
|
".address_size 64",
|
||||||
|
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{",
|
||||||
|
]
|
||||||
|
for arg in [
|
||||||
|
(dtype, lang.type_to_letter(dtype), c) for dtype, c in lang.cnts.items()
|
||||||
|
]:
|
||||||
|
ins_prefix.append(
|
||||||
|
f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",
|
||||||
|
)
|
||||||
ins = ins_prefix + ins
|
ins = ins_prefix + ins
|
||||||
ins += ["ret;", "}"]
|
ins += ["ret;", "}"]
|
||||||
return '\n'.join(ins)
|
return "\n".join(ins)
|
||||||
|
|
||||||
|
|
||||||
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
|
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
|
||||||
lang = PTXLanguage()
|
lang = PTXLanguage()
|
||||||
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
||||||
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
|
return (
|
||||||
|
specialize_to_ptx(lang, function_name),
|
||||||
|
global_size[::-1],
|
||||||
|
local_size[::-1],
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
|
||||||
|
|
||||||
# ugh, is this really needed?
|
# ugh, is this really needed?
|
||||||
from extra.helpers import enable_early_exec
|
from extra.helpers import enable_early_exec
|
||||||
|
|
||||||
early_exec = enable_early_exec()
|
early_exec = enable_early_exec()
|
||||||
|
|
||||||
boilerplate_start = """
|
boilerplate_start = """
|
||||||
|
@ -24,6 +25,7 @@ code_start = """.end_amdhsa_kernel
|
||||||
code:
|
code:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
|
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
|
||||||
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
|
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
|
||||||
# RDNA3 is actually a SIMD machine!
|
# RDNA3 is actually a SIMD machine!
|
||||||
|
@ -36,107 +38,202 @@ class RDNACodegen(AssemblyCodegen):
|
||||||
|
|
||||||
def specialize(self, asm) -> Tuple[str, str]:
|
def specialize(self, asm) -> Tuple[str, str]:
|
||||||
args = []
|
args = []
|
||||||
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
|
for i, b in enumerate(self.bufs):
|
||||||
|
args.append(
|
||||||
|
{
|
||||||
|
".address_space": "global",
|
||||||
|
".name": f"buf_{i}",
|
||||||
|
".offset": i * 8,
|
||||||
|
".size": 8,
|
||||||
|
".type_name": b.dtype.name + "*",
|
||||||
|
".value_kind": "global_buffer",
|
||||||
|
}
|
||||||
|
)
|
||||||
ins = []
|
ins = []
|
||||||
|
|
||||||
v_cnt = 3 # v[0:2] is local_xyz
|
v_cnt = 3 # v[0:2] is local_xyz
|
||||||
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
||||||
|
|
||||||
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
|
dtype_to_rdnatype = {
|
||||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
|
dtypes.float32: "f32",
|
||||||
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
|
dtypes.int64: "i64",
|
||||||
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
|
dtypes.int32: "i32",
|
||||||
BinaryOps.CMPLT: "cmp_lt"}
|
dtypes.uint64: "u64",
|
||||||
|
dtypes.bool: "i32",
|
||||||
|
}
|
||||||
|
alu = {
|
||||||
|
BinaryOps.ADD: "add",
|
||||||
|
BinaryOps.SUB: "sub",
|
||||||
|
BinaryOps.MUL: "mul",
|
||||||
|
TernaryOps.MULACC: "fma",
|
||||||
|
BinaryOps.MAX: "max",
|
||||||
|
UnaryOps.RECIP: "rcp",
|
||||||
|
UnaryOps.NOOP: "mov",
|
||||||
|
UnaryOps.SIN: "sin",
|
||||||
|
UnaryOps.LOG2: "log",
|
||||||
|
UnaryOps.EXP2: "exp",
|
||||||
|
BinaryOps.CMPLT: "cmp_lt",
|
||||||
|
}
|
||||||
|
|
||||||
pend_regs: Set[Register] = set()
|
pend_regs: Set[Register] = set()
|
||||||
rtor: Dict[Register, str] = {}
|
rtor: Dict[Register, str] = {}
|
||||||
|
|
||||||
def reg_in(x):
|
def reg_in(x):
|
||||||
nonlocal pend_regs
|
nonlocal pend_regs
|
||||||
# print("reg_in", x, rtor[x], pend_regs)
|
# print("reg_in", x, rtor[x], pend_regs)
|
||||||
if x in pend_regs:
|
if x in pend_regs:
|
||||||
# print("clear")
|
# print("clear")
|
||||||
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
|
ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)")
|
||||||
pend_regs.clear()
|
pend_regs.clear()
|
||||||
return rtor[x]
|
return rtor[x]
|
||||||
|
|
||||||
def reg_out(x):
|
def reg_out(x):
|
||||||
return rtor[x]
|
return rtor[x]
|
||||||
|
|
||||||
for uop, out, vin, arg in asm:
|
for uop, out, vin, arg in asm:
|
||||||
if uop == UOps.DEFINE_REGISTER:
|
if uop == UOps.DEFINE_REGISTER:
|
||||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
if arg[0][0] in [
|
||||||
|
dtypes.uint32,
|
||||||
|
dtypes.uint64,
|
||||||
|
dtypes.int64,
|
||||||
|
dtypes.int32,
|
||||||
|
dtypes.float32,
|
||||||
|
dtypes.float.vec(4),
|
||||||
|
]:
|
||||||
for i in range(arg[2]):
|
for i in range(arg[2]):
|
||||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||||
align = int(arg[0][0].itemsize / 4)
|
align = int(arg[0][0].itemsize / 4)
|
||||||
if arg[0][1]:
|
if arg[0][1]:
|
||||||
s_cnt += s_cnt % align
|
s_cnt += s_cnt % align
|
||||||
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
|
reg_name = (
|
||||||
|
f"s[{s_cnt}:{s_cnt + align - 1}]"
|
||||||
|
if align > 1
|
||||||
|
else f"s{s_cnt}"
|
||||||
|
)
|
||||||
s_cnt += align
|
s_cnt += align
|
||||||
else:
|
else:
|
||||||
v_cnt += v_cnt % align
|
v_cnt += v_cnt % align
|
||||||
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
|
reg_name = (
|
||||||
|
f"v[{v_cnt}:{v_cnt + align - 1}]"
|
||||||
|
if align > 1
|
||||||
|
else f"v{v_cnt}"
|
||||||
|
)
|
||||||
v_cnt += align
|
v_cnt += align
|
||||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||||
|
|
||||||
if arg[0][0] == dtypes.float.vec(4):
|
if arg[0][0] == dtypes.float.vec(4):
|
||||||
for off in range(4):
|
for off in range(4):
|
||||||
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
reg_name = (
|
||||||
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
f"s{s_cnt-align+off}"
|
||||||
|
if arg[0][1]
|
||||||
|
else f"v{v_cnt-align+off}"
|
||||||
|
)
|
||||||
|
rtor[
|
||||||
|
Register(
|
||||||
|
f"%{arg[1]}{i}", dtypes.float, False, off=off
|
||||||
|
)
|
||||||
|
] = reg_name
|
||||||
elif arg[0][0] == dtypes.bool:
|
elif arg[0][0] == dtypes.bool:
|
||||||
for i in range(arg[2]):
|
for i in range(arg[2]):
|
||||||
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
|
reg_name = (
|
||||||
|
"scc" if arg[0][1] else "vcc_lo"
|
||||||
|
) # `_lo` suffix since we're running wavefront_size=32
|
||||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
raise NotImplementedError(
|
||||||
|
"DEFINE_REGISTER not implemented for arg: ", arg
|
||||||
|
)
|
||||||
elif uop == UOps.SPECIAL:
|
elif uop == UOps.SPECIAL:
|
||||||
if arg.startswith('buf'):
|
if arg.startswith("buf"):
|
||||||
i = int(arg[3:])
|
i = int(arg[3:])
|
||||||
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
|
ins.append(f"s_load_b64 {reg_out(out)}, s[0:1], {i*8}")
|
||||||
pend_regs.add(out)
|
pend_regs.add(out)
|
||||||
for r in out.subregs(): pend_regs.add(r)
|
for r in out.subregs():
|
||||||
elif arg.startswith('gid'):
|
pend_regs.add(r)
|
||||||
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
|
elif arg.startswith("gid"):
|
||||||
|
ins.append(f"v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}")
|
||||||
# the docs lied, this is actually y
|
# the docs lied, this is actually y
|
||||||
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
if int(arg[3]) == 2:
|
||||||
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
|
ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
||||||
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
if int(arg[3]) == 1:
|
||||||
|
ins.append("v_bfe_u32 v1, v0, 10, 10")
|
||||||
|
elif int(arg[3]) == 0:
|
||||||
|
ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
||||||
# get local size
|
# get local size
|
||||||
offset = len(args) * 8
|
offset = len(args) * 8
|
||||||
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
|
args.append(
|
||||||
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
|
{
|
||||||
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
|
".offset": offset,
|
||||||
|
".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}",
|
||||||
|
".size": 8,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ins.append(f"s_load_b32 s{2+int(arg[3])}, s[0:1], {offset}")
|
||||||
|
ins.append("s_waitcnt vmcnt(0) lgkmcnt(0)")
|
||||||
pend_regs.clear()
|
pend_regs.clear()
|
||||||
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
|
ins.append(
|
||||||
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
|
f"v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}"
|
||||||
|
)
|
||||||
|
ins.append(
|
||||||
|
f"v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}"
|
||||||
|
)
|
||||||
elif uop == UOps.CONST:
|
elif uop == UOps.CONST:
|
||||||
if arg == float('inf'): arg = "0x7f800000"
|
if arg == float("inf"):
|
||||||
elif arg == float('-inf'): arg = "0xff800000"
|
arg = "0x7f800000"
|
||||||
|
elif arg == float("-inf"):
|
||||||
|
arg = "0xff800000"
|
||||||
if out.dtype == dtypes.float.vec(4):
|
if out.dtype == dtypes.float.vec(4):
|
||||||
for off in range(4):
|
for off in range(4):
|
||||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
ins.append(
|
||||||
|
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
ins.append(
|
||||||
|
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}"
|
||||||
|
)
|
||||||
elif uop == UOps.ALU:
|
elif uop == UOps.ALU:
|
||||||
if arg in [BinaryOps.CMPLT]:
|
if arg in [BinaryOps.CMPLT]:
|
||||||
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
ins.append(
|
||||||
|
f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
alu_arg = alu[arg]
|
alu_arg = alu[arg]
|
||||||
if arg == TernaryOps.MULACC and out == vin[2]:
|
if arg == TernaryOps.MULACC and out == vin[2]:
|
||||||
alu_arg = "fmac"
|
alu_arg = "fmac"
|
||||||
vin = vin[0:2]
|
vin = vin[0:2]
|
||||||
if out.dtype == dtypes.float.vec(4):
|
if out.dtype == dtypes.float.vec(4):
|
||||||
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
|
for rr in zip(
|
||||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
*[
|
||||||
|
x.subregs()
|
||||||
|
if x.dtype == dtypes.float.vec(4)
|
||||||
|
else [x, x, x, x]
|
||||||
|
for x in [out] + vin
|
||||||
|
]
|
||||||
|
):
|
||||||
|
ins.append(
|
||||||
|
f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
ins.append(
|
||||||
|
f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||||
|
)
|
||||||
elif uop == UOps.LOAD:
|
elif uop == UOps.LOAD:
|
||||||
if out.scalar:
|
if out.scalar:
|
||||||
# swap arg order
|
# swap arg order
|
||||||
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
ins.append(
|
||||||
|
f"s_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
ins.append(
|
||||||
|
f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||||
|
)
|
||||||
pend_regs.add(out)
|
pend_regs.add(out)
|
||||||
for r in out.subregs(): pend_regs.add(r)
|
for r in out.subregs():
|
||||||
|
pend_regs.add(r)
|
||||||
elif uop == UOps.STORE:
|
elif uop == UOps.STORE:
|
||||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
ins.append(
|
||||||
|
f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||||
|
)
|
||||||
elif uop == UOps.LABEL:
|
elif uop == UOps.LABEL:
|
||||||
ins.append(f"{arg}:")
|
ins.append(f"{arg}:")
|
||||||
elif uop == UOps.COND_BRANCH:
|
elif uop == UOps.COND_BRANCH:
|
||||||
|
@ -144,29 +241,40 @@ class RDNACodegen(AssemblyCodegen):
|
||||||
elif uop == UOps.CAST:
|
elif uop == UOps.CAST:
|
||||||
if vin[0].dtype == dtypes.bool:
|
if vin[0].dtype == dtypes.bool:
|
||||||
if out.dtype == dtypes.float32:
|
if out.dtype == dtypes.float32:
|
||||||
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
ins.append(
|
||||||
|
f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(uop)
|
raise NotImplementedError(uop)
|
||||||
|
|
||||||
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
|
ins += ["s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", "s_code_end"]
|
||||||
|
|
||||||
# dual alu group
|
# dual alu group
|
||||||
seen = set()
|
seen = set()
|
||||||
new_ins = []
|
new_ins = []
|
||||||
for i, tins in enumerate(ins):
|
for i, tins in enumerate(ins):
|
||||||
if tins in seen: continue
|
if tins in seen:
|
||||||
|
continue
|
||||||
if tins.startswith("v_fmac_f32"):
|
if tins.startswith("v_fmac_f32"):
|
||||||
for gins in reversed(ins[i + 1 :]):
|
for gins in reversed(ins[i + 1 :]):
|
||||||
if gins in seen: continue
|
if gins in seen:
|
||||||
|
continue
|
||||||
if gins.startswith("v_fmac_f32"):
|
if gins.startswith("v_fmac_f32"):
|
||||||
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
|
r0 = [int(x[1:].strip(",")) for x in tins.split(" ")[1:]]
|
||||||
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
|
r1 = [int(x[1:].strip(",")) for x in gins.split(" ")[1:]]
|
||||||
if r0[0]%2 == r1[0]%2: continue
|
if r0[0] % 2 == r1[0] % 2:
|
||||||
if r0[1]%2 == r1[1]%2: continue
|
continue
|
||||||
if r0[2]%2 == r1[2]%2: continue
|
if r0[1] % 2 == r1[1] % 2:
|
||||||
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
|
continue
|
||||||
|
if r0[2] % 2 == r1[2] % 2:
|
||||||
|
continue
|
||||||
|
new_ins.append(
|
||||||
|
tins.replace("v_", "v_dual_")
|
||||||
|
+ " :: "
|
||||||
|
+ gins.replace("v_", "v_dual_")
|
||||||
|
)
|
||||||
seen.add(tins)
|
seen.add(tins)
|
||||||
seen.add(gins)
|
seen.add(gins)
|
||||||
break
|
break
|
||||||
|
@ -174,30 +282,102 @@ class RDNACodegen(AssemblyCodegen):
|
||||||
new_ins.append(tins)
|
new_ins.append(tins)
|
||||||
ins = new_ins
|
ins = new_ins
|
||||||
|
|
||||||
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
|
return "code", self.assemble(args, ins, v_cnt, s_cnt)
|
||||||
|
|
||||||
def assemble(self, args, ins, v_cnt, s_cnt):
|
def assemble(self, args, ins, v_cnt, s_cnt):
|
||||||
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
|
kernel_desc = {
|
||||||
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
|
".amdhsa_group_segment_fixed_size": 0,
|
||||||
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
|
".amdhsa_private_segment_fixed_size": 0,
|
||||||
'.amdhsa_next_free_sgpr': s_cnt,
|
".amdhsa_kernarg_size": 0,
|
||||||
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
|
".amdhsa_next_free_vgpr": v_cnt, # this matters!
|
||||||
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
|
".amdhsa_reserve_vcc": 0,
|
||||||
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
|
".amdhsa_reserve_xnack_mask": 0,
|
||||||
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
|
".amdhsa_next_free_sgpr": s_cnt,
|
||||||
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
|
".amdhsa_float_round_mode_32": 0,
|
||||||
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
|
".amdhsa_float_round_mode_16_64": 0,
|
||||||
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
|
".amdhsa_float_denorm_mode_32": 3,
|
||||||
|
".amdhsa_float_denorm_mode_16_64": 3,
|
||||||
|
".amdhsa_dx10_clamp": 1,
|
||||||
|
".amdhsa_ieee_mode": 1,
|
||||||
|
".amdhsa_fp16_overflow": 0,
|
||||||
|
".amdhsa_workgroup_processor_mode": 1,
|
||||||
|
".amdhsa_memory_ordered": 1,
|
||||||
|
".amdhsa_forward_progress": 0,
|
||||||
|
".amdhsa_enable_private_segment": 0,
|
||||||
|
".amdhsa_system_sgpr_workgroup_id_x": 1,
|
||||||
|
".amdhsa_system_sgpr_workgroup_id_y": 1,
|
||||||
|
".amdhsa_system_sgpr_workgroup_id_z": 1,
|
||||||
|
".amdhsa_system_sgpr_workgroup_info": 0,
|
||||||
|
".amdhsa_system_vgpr_workitem_id": 2, # is amdhsa_system_vgpr_workitem_id real?
|
||||||
|
".amdhsa_exception_fp_ieee_invalid_op": 0,
|
||||||
|
".amdhsa_exception_fp_denorm_src": 0,
|
||||||
|
".amdhsa_exception_fp_ieee_div_zero": 0,
|
||||||
|
".amdhsa_exception_fp_ieee_overflow": 0,
|
||||||
|
".amdhsa_exception_fp_ieee_underflow": 0,
|
||||||
|
".amdhsa_exception_fp_ieee_inexact": 0,
|
||||||
|
".amdhsa_exception_int_div_zero": 0,
|
||||||
|
".amdhsa_user_sgpr_dispatch_ptr": 0,
|
||||||
|
".amdhsa_user_sgpr_queue_ptr": 0,
|
||||||
|
".amdhsa_user_sgpr_kernarg_segment_ptr": 1,
|
||||||
|
".amdhsa_user_sgpr_dispatch_id": 0,
|
||||||
|
".amdhsa_user_sgpr_private_segment_size": 0,
|
||||||
|
".amdhsa_wavefront_size32": 1,
|
||||||
|
".amdhsa_uses_dynamic_stack": 0,
|
||||||
|
}
|
||||||
|
|
||||||
metadata = {'amdhsa.kernels': [{'.args': args,
|
metadata = {
|
||||||
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
|
"amdhsa.kernels": [
|
||||||
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
|
{
|
||||||
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
|
".args": args,
|
||||||
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
|
".group_segment_fixed_size": 0,
|
||||||
'.wavefront_size': 32}],
|
".kernarg_segment_align": 8,
|
||||||
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
|
".kernarg_segment_size": args[-1][".offset"] + args[-1][".size"],
|
||||||
|
".language": "OpenCL C",
|
||||||
|
".language_version": [1, 2],
|
||||||
|
".max_flat_workgroup_size": 256,
|
||||||
|
".name": "code",
|
||||||
|
".private_segment_fixed_size": 0,
|
||||||
|
".sgpr_count": s_cnt,
|
||||||
|
".sgpr_spill_count": 0,
|
||||||
|
".symbol": "code.kd",
|
||||||
|
".uses_dynamic_stack": False,
|
||||||
|
".vgpr_count": v_cnt,
|
||||||
|
".vgpr_spill_count": 0,
|
||||||
|
".wavefront_size": 32,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"amdhsa.target": "amdgcn-amd-amdhsa--gfx1100",
|
||||||
|
"amdhsa.version": [1, 2],
|
||||||
|
}
|
||||||
|
|
||||||
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
|
code = (
|
||||||
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
boilerplate_start
|
||||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
|
+ "\n"
|
||||||
|
+ "\n".join("%s %d" % x for x in kernel_desc.items())
|
||||||
|
+ "\n"
|
||||||
|
+ code_start
|
||||||
|
+ "\n".join(ins)
|
||||||
|
+ "\n.amdgpu_metadata\n"
|
||||||
|
+ yaml.dump(metadata)
|
||||||
|
+ ".end_amdgpu_metadata"
|
||||||
|
)
|
||||||
|
obj = early_exec(
|
||||||
|
(
|
||||||
|
[
|
||||||
|
ROCM_LLVM_PATH / "llvm-mc",
|
||||||
|
"--arch=amdgcn",
|
||||||
|
"--mcpu=gfx1100",
|
||||||
|
"--triple=amdgcn-amd-amdhsa",
|
||||||
|
"--filetype=obj",
|
||||||
|
"-",
|
||||||
|
],
|
||||||
|
code.encode("utf-8"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asm = early_exec(
|
||||||
|
(
|
||||||
|
[ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"],
|
||||||
|
obj,
|
||||||
|
)
|
||||||
|
)
|
||||||
return asm
|
return asm
|
||||||
|
|
|
@ -4,7 +4,9 @@ from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
|
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
|
||||||
prg = CUDAProgram("test", """
|
prg = CUDAProgram(
|
||||||
|
"test",
|
||||||
|
"""
|
||||||
.version 7.8
|
.version 7.8
|
||||||
.target sm_86
|
.target sm_86
|
||||||
.address_size 64
|
.address_size 64
|
||||||
|
@ -17,7 +19,8 @@ if __name__ == "__main__":
|
||||||
mov.u32 %r1, 0x40000000; // 2.0 in float
|
mov.u32 %r1, 0x40000000; // 2.0 in float
|
||||||
st.global.u32 [%rd2], %r1;
|
st.global.u32 [%rd2], %r1;
|
||||||
ret;
|
ret;
|
||||||
}""", binary=True)
|
}""",
|
||||||
|
binary=True,
|
||||||
|
)
|
||||||
prg([1], [1], test)
|
prg([1], [1], test)
|
||||||
print(test.toCPU())
|
print(test.toCPU())
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import pathlib
|
||||||
from hexdump import hexdump
|
from hexdump import hexdump
|
||||||
from tinygrad.helpers import colored
|
from tinygrad.helpers import colored
|
||||||
from extra.helpers import enable_early_exec
|
from extra.helpers import enable_early_exec
|
||||||
|
|
||||||
early_exec = enable_early_exec()
|
early_exec = enable_early_exec()
|
||||||
|
|
||||||
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
|
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
|
||||||
|
@ -37,29 +38,49 @@ for j in range(1):
|
||||||
c = (y * KX + x) * 8
|
c = (y * KX + x) * 8
|
||||||
a = (KY * KX * 8) + y * 8
|
a = (KY * KX * 8) + y * 8
|
||||||
b = (KY * KX * 8) + (KY * 8) + x * 8
|
b = (KY * KX * 8) + (KY * 8) + x * 8
|
||||||
gen.append(f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]")
|
gen.append(
|
||||||
|
f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]"
|
||||||
|
)
|
||||||
FLOPS += 16 * 8 * 2
|
FLOPS += 16 * 8 * 2
|
||||||
else:
|
else:
|
||||||
for i in range(0, MAX_REG, 6):
|
for i in range(0, MAX_REG, 6):
|
||||||
if DUAL_ALU:
|
if DUAL_ALU:
|
||||||
if F32:
|
if F32:
|
||||||
gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
gen.append(
|
||||||
|
f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}"
|
||||||
|
)
|
||||||
FLOPS += 4
|
FLOPS += 4
|
||||||
else:
|
else:
|
||||||
gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
|
gen.append(
|
||||||
|
f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}"
|
||||||
|
)
|
||||||
FLOPS += 8
|
FLOPS += 8
|
||||||
else:
|
else:
|
||||||
assert F32
|
assert F32
|
||||||
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
|
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
|
||||||
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||||
code = code.replace("// FLOPS", '\n'.join(gen))
|
code = code.replace("// FLOPS", "\n".join(gen))
|
||||||
print(code)
|
print(code)
|
||||||
|
|
||||||
|
|
||||||
# fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa'
|
# fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa'
|
||||||
|
|
||||||
object = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
object = early_exec(
|
||||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object))
|
(
|
||||||
|
[
|
||||||
|
ROCM_LLVM_PATH / "llvm-mc",
|
||||||
|
"--arch=amdgcn",
|
||||||
|
"--mcpu=gfx1100",
|
||||||
|
"--triple=amdgcn-amd-amdhsa",
|
||||||
|
"--filetype=obj",
|
||||||
|
"-",
|
||||||
|
],
|
||||||
|
code.encode("utf-8"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
asm = early_exec(
|
||||||
|
([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object)
|
||||||
|
)
|
||||||
|
|
||||||
with open("/tmp/cc2.o", "wb") as f:
|
with open("/tmp/cc2.o", "wb") as f:
|
||||||
f.write(object)
|
f.write(object)
|
||||||
|
|
|
@ -2,12 +2,14 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
cwd = Path.cwd()
|
cwd = Path.cwd()
|
||||||
sys.path.append(cwd.as_posix())
|
sys.path.append(cwd.as_posix())
|
||||||
sys.path.append((cwd / 'test').as_posix())
|
sys.path.append((cwd / "test").as_posix())
|
||||||
from extra.datasets import fetch_mnist
|
from extra.datasets import fetch_mnist
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
|
||||||
def augment_img(X, rotate=10, px=3):
|
def augment_img(X, rotate=10, px=3):
|
||||||
Xaug = np.zeros_like(X)
|
Xaug = np.zeros_like(X)
|
||||||
for i in trange(len(X)):
|
for i in trange(len(X)):
|
||||||
|
@ -20,8 +22,10 @@ def augment_img(X, rotate=10, px=3):
|
||||||
Xaug[i] = im
|
Xaug[i] = im
|
||||||
return Xaug
|
return Xaug
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||||
|
@ -29,14 +33,18 @@ if __name__ == "__main__":
|
||||||
fig, a = plt.subplots(2, len(X))
|
fig, a = plt.subplots(2, len(X))
|
||||||
Xaug = augment_img(X)
|
Xaug = augment_img(X)
|
||||||
for i in range(len(X)):
|
for i in range(len(X)):
|
||||||
a[0][i].imshow(X[i], cmap='gray')
|
a[0][i].imshow(X[i], cmap="gray")
|
||||||
a[1][i].imshow(Xaug[i],cmap='gray')
|
a[1][i].imshow(Xaug[i], cmap="gray")
|
||||||
a[0][i].axis('off')
|
a[0][i].axis("off")
|
||||||
a[1][i].axis('off')
|
a[1][i].axis("off")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# create some nice gifs for doc?!
|
# create some nice gifs for doc?!
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
im = Image.fromarray(X_train[7353 + i])
|
im = Image.fromarray(X_train[7353 + i])
|
||||||
im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
|
im_aug = [
|
||||||
im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
|
Image.fromarray(x) for x in augment_img(np.array([X_train[7353 + i]] * 100))
|
||||||
|
]
|
||||||
|
im.save(
|
||||||
|
f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0
|
||||||
|
)
|
||||||
|
|
|
@ -3,30 +3,53 @@ import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes, fetch
|
from tinygrad.helpers import dtypes, fetch
|
||||||
|
|
||||||
|
|
||||||
def fetch_mnist(tensors=False):
|
def fetch_mnist(tensors=False):
|
||||||
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
||||||
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
|
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
|
||||||
X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
X_train = (
|
||||||
|
parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:]
|
||||||
|
.reshape((-1, 28 * 28))
|
||||||
|
.astype(np.float32)
|
||||||
|
)
|
||||||
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
|
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
|
||||||
X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
X_test = (
|
||||||
|
parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:]
|
||||||
|
.reshape((-1, 28 * 28))
|
||||||
|
.astype(np.float32)
|
||||||
|
)
|
||||||
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
|
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
|
||||||
if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
|
if tensors:
|
||||||
else: return X_train, Y_train, X_test, Y_test
|
return (
|
||||||
|
Tensor(X_train).reshape(-1, 1, 28, 28),
|
||||||
|
Tensor(Y_train),
|
||||||
|
Tensor(X_test).reshape(-1, 1, 28, 28),
|
||||||
|
Tensor(Y_test),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return X_train, Y_train, X_test, Y_test
|
||||||
|
|
||||||
|
|
||||||
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
|
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
|
||||||
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
|
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
|
||||||
|
|
||||||
|
|
||||||
def fetch_cifar():
|
def fetch_cifar():
|
||||||
X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
|
X_train = Tensor.empty(
|
||||||
Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
|
50000, 3 * 32 * 32, device=f"disk:/tmp/cifar_train_x", dtype=dtypes.uint8
|
||||||
X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
|
)
|
||||||
Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)
|
Y_train = Tensor.empty(50000, device=f"disk:/tmp/cifar_train_y", dtype=dtypes.int64)
|
||||||
|
X_test = Tensor.empty(
|
||||||
|
10000, 3 * 32 * 32, device=f"disk:/tmp/cifar_test_x", dtype=dtypes.uint8
|
||||||
|
)
|
||||||
|
Y_test = Tensor.empty(10000, device=f"disk:/tmp/cifar_test_y", dtype=dtypes.int64)
|
||||||
|
|
||||||
if not os.path.isfile("/tmp/cifar_extracted"):
|
if not os.path.isfile("/tmp/cifar_extracted"):
|
||||||
|
|
||||||
def _load_disk_tensor(X, Y, db_list):
|
def _load_disk_tensor(X, Y, db_list):
|
||||||
idx = 0
|
idx = 0
|
||||||
for db in db_list:
|
for db in db_list:
|
||||||
x, y = db[b'data'], np.array(db[b'labels'])
|
x, y = db[b"data"], np.array(db[b"labels"])
|
||||||
assert x.shape[0] == y.shape[0]
|
assert x.shape[0] == y.shape[0]
|
||||||
X[idx : idx + x.shape[0]].assign(x)
|
X[idx : idx + x.shape[0]].assign(x)
|
||||||
Y[idx : idx + x.shape[0]].assign(y)
|
Y[idx : idx + x.shape[0]].assign(y)
|
||||||
|
@ -34,10 +57,28 @@ def fetch_cifar():
|
||||||
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
||||||
|
|
||||||
print("downloading and extracting CIFAR...")
|
print("downloading and extracting CIFAR...")
|
||||||
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
|
fn = fetch("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
|
||||||
tt = tarfile.open(fn, mode='r:gz')
|
tt = tarfile.open(fn, mode="r:gz")
|
||||||
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
|
_load_disk_tensor(
|
||||||
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
|
X_train,
|
||||||
|
Y_train,
|
||||||
|
[
|
||||||
|
pickle.load(
|
||||||
|
tt.extractfile(f"cifar-10-batches-py/data_batch_{i}"),
|
||||||
|
encoding="bytes",
|
||||||
|
)
|
||||||
|
for i in range(1, 6)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
_load_disk_tensor(
|
||||||
|
X_test,
|
||||||
|
Y_test,
|
||||||
|
[
|
||||||
|
pickle.load(
|
||||||
|
tt.extractfile("cifar-10-batches-py/test_batch"), encoding="bytes"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
open("/tmp/cifar_extracted", "wb").close()
|
open("/tmp/cifar_extracted", "wb").close()
|
||||||
|
|
||||||
return X_train, Y_train, X_test, Y_test
|
return X_train, Y_train, X_test, Y_test
|
||||||
|
|
|
@ -15,32 +15,36 @@ frPyObjects = _mask.frPyObjects
|
||||||
BASEDIR = pathlib.Path(__file__).parent / "COCO"
|
BASEDIR = pathlib.Path(__file__).parent / "COCO"
|
||||||
BASEDIR.mkdir(exist_ok=True)
|
BASEDIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
|
|
||||||
|
def create_dict(key_row, val_row, rows):
|
||||||
|
return {row[key_row]: row[val_row] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
if not pathlib.Path(BASEDIR/'val2017').is_dir():
|
if not pathlib.Path(BASEDIR / "val2017").is_dir():
|
||||||
fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
|
fn = fetch("http://images.cocodataset.org/zips/val2017.zip")
|
||||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||||
zip_ref.extractall(BASEDIR)
|
zip_ref.extractall(BASEDIR)
|
||||||
fn.unlink()
|
fn.unlink()
|
||||||
|
|
||||||
|
|
||||||
if not pathlib.Path(BASEDIR/'annotations').is_dir():
|
if not pathlib.Path(BASEDIR / "annotations").is_dir():
|
||||||
fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
|
fn = fetch("http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
|
||||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||||
zip_ref.extractall(BASEDIR)
|
zip_ref.extractall(BASEDIR)
|
||||||
fn.unlink()
|
fn.unlink()
|
||||||
|
|
||||||
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
|
with open(BASEDIR / "annotations/instances_val2017.json", "r") as f:
|
||||||
annotations_raw = json.loads(f.read())
|
annotations_raw = json.loads(f.read())
|
||||||
images = annotations_raw['images']
|
images = annotations_raw["images"]
|
||||||
categories = annotations_raw['categories']
|
categories = annotations_raw["categories"]
|
||||||
annotations = annotations_raw['annotations']
|
annotations = annotations_raw["annotations"]
|
||||||
file_name_to_id = create_dict('file_name', 'id', images)
|
file_name_to_id = create_dict("file_name", "id", images)
|
||||||
id_to_width = create_dict('id', 'width', images)
|
id_to_width = create_dict("id", "width", images)
|
||||||
id_to_height = create_dict('id', 'height', images)
|
id_to_height = create_dict("id", "height", images)
|
||||||
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
|
json_category_id_to_contiguous_id = {v["id"]: i + 1 for i, v in enumerate(categories)}
|
||||||
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
|
contiguous_category_id_to_json_id = {
|
||||||
|
v: k for k, v in json_category_id_to_contiguous_id.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def encode(bimask):
|
def encode(bimask):
|
||||||
|
@ -48,7 +52,8 @@ def encode(bimask):
|
||||||
return _mask.encode(bimask)
|
return _mask.encode(bimask)
|
||||||
elif len(bimask.shape) == 2:
|
elif len(bimask.shape) == 2:
|
||||||
h, w = bimask.shape
|
h, w = bimask.shape
|
||||||
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
|
return _mask.encode(bimask.reshape((h, w, 1), order="F"))[0]
|
||||||
|
|
||||||
|
|
||||||
def decode(rleObjs):
|
def decode(rleObjs):
|
||||||
if type(rleObjs) == list:
|
if type(rleObjs) == list:
|
||||||
|
@ -56,12 +61,14 @@ def decode(rleObjs):
|
||||||
else:
|
else:
|
||||||
return _mask.decode([rleObjs])[:, :, 0]
|
return _mask.decode([rleObjs])[:, :, 0]
|
||||||
|
|
||||||
|
|
||||||
def area(rleObjs):
|
def area(rleObjs):
|
||||||
if type(rleObjs) == list:
|
if type(rleObjs) == list:
|
||||||
return _mask.area(rleObjs)
|
return _mask.area(rleObjs)
|
||||||
else:
|
else:
|
||||||
return _mask.area([rleObjs])[0]
|
return _mask.area([rleObjs])[0]
|
||||||
|
|
||||||
|
|
||||||
def toBbox(rleObjs):
|
def toBbox(rleObjs):
|
||||||
if type(rleObjs) == list:
|
if type(rleObjs) == list:
|
||||||
return _mask.toBbox(rleObjs)
|
return _mask.toBbox(rleObjs)
|
||||||
|
@ -102,8 +109,10 @@ def convert_prediction_to_coco_bbox(file_name, prediction):
|
||||||
print(file_name, e)
|
print(file_name, e)
|
||||||
return coco_results
|
return coco_results
|
||||||
|
|
||||||
|
|
||||||
masker = Masker(threshold=0.5, padding=1)
|
masker = Masker(threshold=0.5, padding=1)
|
||||||
|
|
||||||
|
|
||||||
def convert_prediction_to_coco_mask(file_name, prediction):
|
def convert_prediction_to_coco_mask(file_name, prediction):
|
||||||
coco_results = []
|
coco_results = []
|
||||||
try:
|
try:
|
||||||
|
@ -122,8 +131,7 @@ def convert_prediction_to_coco_mask(file_name, prediction):
|
||||||
masks = masker([masks], [prediction])[0].numpy()
|
masks = masker([masks], [prediction])[0].numpy()
|
||||||
|
|
||||||
rles = [
|
rles = [
|
||||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
|
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] for mask in masks
|
||||||
for mask in masks
|
|
||||||
]
|
]
|
||||||
for rle in rles:
|
for rle in rles:
|
||||||
rle["counts"] = rle["counts"].decode("utf-8")
|
rle["counts"] = rle["counts"].decode("utf-8")
|
||||||
|
@ -146,20 +154,22 @@ def convert_prediction_to_coco_mask(file_name, prediction):
|
||||||
return coco_results
|
return coco_results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
|
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
|
||||||
path = pathlib.Path(json_result_file)
|
path = pathlib.Path(json_result_file)
|
||||||
if rm and path.exists(): path.unlink()
|
if rm and path.exists():
|
||||||
|
path.unlink()
|
||||||
with open(path, "a") as f:
|
with open(path, "a") as f:
|
||||||
for s in coco_results:
|
for s in coco_results:
|
||||||
f.write(json.dumps(s))
|
f.write(json.dumps(s))
|
||||||
f.write('\n')
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
def remove_dup(l):
|
def remove_dup(l):
|
||||||
seen = set()
|
seen = set()
|
||||||
seen_add = seen.add
|
seen_add = seen.add
|
||||||
return [x for x in l if not (x in seen or seen_add(x))]
|
return [x for x in l if not (x in seen or seen_add(x))]
|
||||||
|
|
||||||
|
|
||||||
class NpEncoder(json.JSONEncoder):
|
class NpEncoder(json.JSONEncoder):
|
||||||
def default(self, obj):
|
def default(self, obj):
|
||||||
if isinstance(obj, np.integer):
|
if isinstance(obj, np.integer):
|
||||||
|
@ -177,23 +187,28 @@ def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
|
||||||
for line in f:
|
for line in f:
|
||||||
coco_results.append(json.loads(line))
|
coco_results.append(json.loads(line))
|
||||||
|
|
||||||
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
|
coco_gt = COCO(str(BASEDIR / "annotations/instances_val2017.json"))
|
||||||
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
||||||
unique_list = [json.loads(s) for s in set_of_json]
|
unique_list = [json.loads(s) for s in set_of_json]
|
||||||
|
|
||||||
with open(f'{json_result_file}.flattend', "w") as f:
|
with open(f"{json_result_file}.flattend", "w") as f:
|
||||||
json.dump(unique_list, f)
|
json.dump(unique_list, f)
|
||||||
|
|
||||||
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
|
coco_dt = coco_gt.loadRes(str(f"{json_result_file}.flattend"))
|
||||||
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
||||||
coco_eval.evaluate()
|
coco_eval.evaluate()
|
||||||
coco_eval.accumulate()
|
coco_eval.accumulate()
|
||||||
coco_eval.summarize()
|
coco_eval.summarize()
|
||||||
return coco_eval
|
return coco_eval
|
||||||
|
|
||||||
|
|
||||||
def iterate(files, bs=1):
|
def iterate(files, bs=1):
|
||||||
batch = []
|
batch = []
|
||||||
for file in files:
|
for file in files:
|
||||||
batch.append(file)
|
batch.append(file)
|
||||||
if len(batch) >= bs: yield batch; batch = []
|
if len(batch) >= bs:
|
||||||
if len(batch) > 0: yield batch; batch = []
|
yield batch
|
||||||
|
batch = []
|
||||||
|
if len(batch) > 0:
|
||||||
|
yield batch
|
||||||
|
batch = []
|
||||||
|
|
|
@ -9,36 +9,45 @@ BASEDIR = pathlib.Path(__file__).parent / "imagenet"
|
||||||
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
||||||
cir = {v[0]: int(k) for k, v in ci.items()}
|
cir = {v[0]: int(k) for k, v in ci.items()}
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def get_train_files():
|
def get_train_files():
|
||||||
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
|
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
|
||||||
return [(BASEDIR / "train" / x) for x in train_files]
|
return [(BASEDIR / "train" / x) for x in train_files]
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def get_val_files():
|
def get_val_files():
|
||||||
val_files = glob.glob(str(BASEDIR / "val/*/*"))
|
val_files = glob.glob(str(BASEDIR / "val/*/*"))
|
||||||
return val_files
|
return val_files
|
||||||
|
|
||||||
|
|
||||||
# rrc = transforms.RandomResizedCrop(224)
|
# rrc = transforms.RandomResizedCrop(224)
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
|
|
||||||
def image_load(fn):
|
def image_load(fn):
|
||||||
img = Image.open(fn).convert('RGB')
|
img = Image.open(fn).convert("RGB")
|
||||||
img = F.resize(img, 256, Image.BILINEAR)
|
img = F.resize(img, 256, Image.BILINEAR)
|
||||||
img = F.center_crop(img, 224)
|
img = F.center_crop(img, 224)
|
||||||
ret = np.array(img)
|
ret = np.array(img)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def iterate(bs=32, val=True, shuffle=True):
|
def iterate(bs=32, val=True, shuffle=True):
|
||||||
files = get_val_files() if val else get_train_files()
|
files = get_val_files() if val else get_train_files()
|
||||||
order = list(range(0, len(files)))
|
order = list(range(0, len(files)))
|
||||||
if shuffle: random.shuffle(order)
|
if shuffle:
|
||||||
|
random.shuffle(order)
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
p = Pool(16)
|
p = Pool(16)
|
||||||
for i in range(0, len(files), bs):
|
for i in range(0, len(files), bs):
|
||||||
X = p.map(image_load, [files[i] for i in order[i : i + bs]])
|
X = p.map(image_load, [files[i] for i in order[i : i + bs]])
|
||||||
Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]]
|
Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]]
|
||||||
yield (np.array(X), np.array(Y))
|
yield (np.array(X), np.array(Y))
|
||||||
|
|
||||||
|
|
||||||
def fetch_batch(bs, val=False):
|
def fetch_batch(bs, val=False):
|
||||||
files = get_val_files() if val else get_train_files()
|
files = get_val_files() if val else get_train_files()
|
||||||
samp = np.random.randint(0, len(files), size=(bs))
|
samp = np.random.randint(0, len(files), size=(bs))
|
||||||
|
@ -47,7 +56,7 @@ def fetch_batch(bs, val=False):
|
||||||
Y = [cir[x.split("/")[0]] for x in files]
|
Y = [cir[x.split("/")[0]] for x in files]
|
||||||
return np.array(X), np.array(Y)
|
return np.array(X), np.array(Y)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X, Y = fetch_batch(64)
|
X, Y = fetch_batch(64)
|
||||||
print(X.shape, Y)
|
print(X.shape, Y)
|
||||||
|
|
||||||
|
|
|
@ -4,17 +4,26 @@ from pathlib import Path
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import tarfile, os
|
import tarfile, os
|
||||||
|
|
||||||
|
|
||||||
def imagenet_extract(file, path, small=False):
|
def imagenet_extract(file, path, small=False):
|
||||||
with tarfile.open(name=file) as tar:
|
with tarfile.open(name=file) as tar:
|
||||||
if small: # Show progressbar only for big files
|
if small: # Show progressbar only for big files
|
||||||
for member in tar.getmembers(): tar.extract(path=path, member=member)
|
for member in tar.getmembers():
|
||||||
|
tar.extract(path=path, member=member)
|
||||||
else:
|
else:
|
||||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
|
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())):
|
||||||
|
tar.extract(path=path, member=member)
|
||||||
tar.close()
|
tar.close()
|
||||||
|
|
||||||
|
|
||||||
def imagenet_prepare_val():
|
def imagenet_prepare_val():
|
||||||
# Read in the labels file
|
# Read in the labels file
|
||||||
with open(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt", 'r') as f:
|
with open(
|
||||||
|
Path(__file__).parent
|
||||||
|
/ "imagenet"
|
||||||
|
/ "imagenet_2012_validation_synset_labels.txt",
|
||||||
|
"r",
|
||||||
|
) as f:
|
||||||
labels = f.read().splitlines()
|
labels = f.read().splitlines()
|
||||||
f.close()
|
f.close()
|
||||||
# Get a list of images
|
# Get a list of images
|
||||||
|
@ -23,8 +32,16 @@ def imagenet_prepare_val():
|
||||||
# Create folders and move files into those
|
# Create folders and move files into those
|
||||||
for co, dir in enumerate(labels):
|
for co, dir in enumerate(labels):
|
||||||
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
|
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
|
||||||
os.replace(Path(__file__).parent / "imagenet" / "val" / images[co], Path(__file__).parent / "imagenet" / "val" / dir / images[co])
|
os.replace(
|
||||||
os.remove(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt")
|
Path(__file__).parent / "imagenet" / "val" / images[co],
|
||||||
|
Path(__file__).parent / "imagenet" / "val" / dir / images[co],
|
||||||
|
)
|
||||||
|
os.remove(
|
||||||
|
Path(__file__).parent
|
||||||
|
/ "imagenet"
|
||||||
|
/ "imagenet_2012_validation_synset_labels.txt"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def imagenet_prepare_train():
|
def imagenet_prepare_train():
|
||||||
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
|
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
|
||||||
|
@ -32,20 +49,47 @@ def imagenet_prepare_train():
|
||||||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||||
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
|
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
|
||||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||||
os.makedirs(Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True)
|
os.makedirs(
|
||||||
imagenet_extract(Path(__file__).parent / "imagenet" / "train" / tarf, Path(__file__).parent/ "imagenet" / "train" / images[co], small=True)
|
Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True
|
||||||
|
)
|
||||||
|
imagenet_extract(
|
||||||
|
Path(__file__).parent / "imagenet" / "train" / tarf,
|
||||||
|
Path(__file__).parent / "imagenet" / "train" / images[co],
|
||||||
|
small=True,
|
||||||
|
)
|
||||||
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
|
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
|
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
|
||||||
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
|
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
|
||||||
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
|
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
|
||||||
fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json")
|
fetch(
|
||||||
fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt")
|
"https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json",
|
||||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB
|
Path(__file__).parent / "imagenet" / "imagenet_class_index.json",
|
||||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val")
|
)
|
||||||
|
fetch(
|
||||||
|
"https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt",
|
||||||
|
Path(__file__).parent
|
||||||
|
/ "imagenet"
|
||||||
|
/ "imagenet_2012_validation_synset_labels.txt",
|
||||||
|
)
|
||||||
|
fetch(
|
||||||
|
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar",
|
||||||
|
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||||
|
) # 7GB
|
||||||
|
imagenet_extract(
|
||||||
|
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||||
|
Path(__file__).parent / "imagenet" / "val",
|
||||||
|
)
|
||||||
imagenet_prepare_val()
|
imagenet_prepare_val()
|
||||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
if os.getenv("IMGNET_TRAIN", None) is not None:
|
||||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB!
|
fetch(
|
||||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train")
|
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar",
|
||||||
|
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||||
|
) # 138GB!
|
||||||
|
imagenet_extract(
|
||||||
|
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||||
|
Path(__file__).parent / "imagenet" / "train",
|
||||||
|
)
|
||||||
imagenet_prepare_train()
|
imagenet_prepare_train()
|
||||||
|
|
|
@ -23,41 +23,70 @@ mv kits extra/datasets
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def get_val_files():
|
def get_val_files():
|
||||||
data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
|
data = fetch(
|
||||||
return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
|
"https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt"
|
||||||
|
).read_text()
|
||||||
|
return sorted(
|
||||||
|
[x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_pair(file_path):
|
def load_pair(file_path):
|
||||||
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
|
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(
|
||||||
|
file_path / "segmentation.nii.gz"
|
||||||
|
)
|
||||||
image_spacings = image.header["pixdim"][1:4].tolist()
|
image_spacings = image.header["pixdim"][1:4].tolist()
|
||||||
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
|
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
|
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
|
||||||
return image, label, image_spacings
|
return image, label, image_spacings
|
||||||
|
|
||||||
|
|
||||||
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
|
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
|
||||||
if image_spacings != target_spacing:
|
if image_spacings != target_spacing:
|
||||||
spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
|
spc_arr, targ_arr, shp_arr = (
|
||||||
|
np.array(image_spacings),
|
||||||
|
np.array(target_spacing),
|
||||||
|
np.array(image.shape[1:]),
|
||||||
|
)
|
||||||
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
|
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
|
||||||
image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
|
image = F.interpolate(
|
||||||
label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
|
torch.from_numpy(np.expand_dims(image, axis=0)),
|
||||||
|
size=new_shape,
|
||||||
|
mode="trilinear",
|
||||||
|
align_corners=True,
|
||||||
|
)
|
||||||
|
label = F.interpolate(
|
||||||
|
torch.from_numpy(np.expand_dims(label, axis=0)),
|
||||||
|
size=new_shape,
|
||||||
|
mode="nearest",
|
||||||
|
)
|
||||||
image = np.squeeze(image.numpy(), axis=0)
|
image = np.squeeze(image.numpy(), axis=0)
|
||||||
label = np.squeeze(label.numpy(), axis=0)
|
label = np.squeeze(label.numpy(), axis=0)
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
|
|
||||||
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
|
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
|
||||||
image = np.clip(image, min_clip, max_clip)
|
image = np.clip(image, min_clip, max_clip)
|
||||||
image = (image - mean) / std
|
image = (image - mean) / std
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
|
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
|
||||||
current_shape = image.shape[1:]
|
current_shape = image.shape[1:]
|
||||||
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
|
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
|
||||||
paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
|
paddings = [(0, 0)] + [
|
||||||
|
(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)
|
||||||
|
]
|
||||||
image = np.pad(image, paddings, mode="edge")
|
image = np.pad(image, paddings, mode="edge")
|
||||||
label = np.pad(label, paddings, mode="edge")
|
label = np.pad(label, paddings, mode="edge")
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
|
|
||||||
def preprocess(file_path):
|
def preprocess(file_path):
|
||||||
image, label, image_spacings = load_pair(file_path)
|
image, label, image_spacings = load_pair(file_path)
|
||||||
image, label = resample3d(image, label, image_spacings)
|
image, label = resample3d(image, label, image_spacings)
|
||||||
|
@ -65,16 +94,20 @@ def preprocess(file_path):
|
||||||
image, label = pad_to_min_shape(image, label)
|
image, label = pad_to_min_shape(image, label)
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
|
|
||||||
def iterate(val=True, shuffle=False):
|
def iterate(val=True, shuffle=False):
|
||||||
if not val: raise NotImplementedError
|
if not val:
|
||||||
|
raise NotImplementedError
|
||||||
files = get_val_files()
|
files = get_val_files()
|
||||||
order = list(range(0, len(files)))
|
order = list(range(0, len(files)))
|
||||||
if shuffle: random.shuffle(order)
|
if shuffle:
|
||||||
|
random.shuffle(order)
|
||||||
for file in files:
|
for file in files:
|
||||||
X, Y = preprocess(file)
|
X, Y = preprocess(file)
|
||||||
X = np.expand_dims(X, axis=0)
|
X = np.expand_dims(X, axis=0)
|
||||||
yield (X, Y)
|
yield (X, Y)
|
||||||
|
|
||||||
|
|
||||||
def gaussian_kernel(n, std):
|
def gaussian_kernel(n, std):
|
||||||
gaussian_1d = signal.gaussian(n, std)
|
gaussian_1d = signal.gaussian(n, std)
|
||||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||||
|
@ -84,14 +117,44 @@ def gaussian_kernel(n, std):
|
||||||
gaussian_3d /= gaussian_3d.max()
|
gaussian_3d /= gaussian_3d.max()
|
||||||
return gaussian_3d
|
return gaussian_3d
|
||||||
|
|
||||||
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
|
|
||||||
bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
|
|
||||||
bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
|
|
||||||
paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
|
|
||||||
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
|
|
||||||
|
|
||||||
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
|
def pad_input(
|
||||||
|
volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3
|
||||||
|
):
|
||||||
|
bounds = [
|
||||||
|
(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)
|
||||||
|
]
|
||||||
|
bounds = [
|
||||||
|
bounds[i]
|
||||||
|
if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i]
|
||||||
|
else bounds[i] + strides[i]
|
||||||
|
for i in range(dim)
|
||||||
|
]
|
||||||
|
paddings = [
|
||||||
|
bounds[2] // 2,
|
||||||
|
bounds[2] - bounds[2] // 2,
|
||||||
|
bounds[1] // 2,
|
||||||
|
bounds[1] - bounds[1] // 2,
|
||||||
|
bounds[0] // 2,
|
||||||
|
bounds[0] - bounds[0] // 2,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
]
|
||||||
|
return (
|
||||||
|
F.pad(
|
||||||
|
torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val
|
||||||
|
).numpy(),
|
||||||
|
paddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sliding_window_inference(
|
||||||
|
model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5
|
||||||
|
):
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
|
|
||||||
mdl_run = TinyJit(lambda x: model(x).realize())
|
mdl_run = TinyJit(lambda x: model(x).realize())
|
||||||
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
|
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
|
||||||
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
|
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
|
||||||
|
@ -119,13 +182,40 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
|
||||||
for i in range(0, strides[0] * size[0], strides[0]):
|
for i in range(0, strides[0] * size[0], strides[0]):
|
||||||
for j in range(0, strides[1] * size[1], strides[1]):
|
for j in range(0, strides[1] * size[1], strides[1]):
|
||||||
for k in range(0, strides[2] * size[2], strides[2]):
|
for k in range(0, strides[2] * size[2], strides[2]):
|
||||||
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
|
out = mdl_run(
|
||||||
result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
|
Tensor(
|
||||||
norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
|
inputs[
|
||||||
|
...,
|
||||||
|
i : roi_shape[0] + i,
|
||||||
|
j : roi_shape[1] + j,
|
||||||
|
k : roi_shape[2] + k,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
).numpy()
|
||||||
|
result[
|
||||||
|
...,
|
||||||
|
i : roi_shape[0] + i,
|
||||||
|
j : roi_shape[1] + j,
|
||||||
|
k : roi_shape[2] + k,
|
||||||
|
] += (
|
||||||
|
out * norm_patch
|
||||||
|
)
|
||||||
|
norm_map[
|
||||||
|
...,
|
||||||
|
i : roi_shape[0] + i,
|
||||||
|
j : roi_shape[1] + j,
|
||||||
|
k : roi_shape[2] + k,
|
||||||
|
] += norm_patch
|
||||||
result /= norm_map
|
result /= norm_map
|
||||||
result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
|
result = result[
|
||||||
|
...,
|
||||||
|
paddings[4] : image_shape[0] + paddings[4],
|
||||||
|
paddings[2] : image_shape[1] + paddings[2],
|
||||||
|
paddings[0] : image_shape[2] + paddings[0],
|
||||||
|
]
|
||||||
return result, labels
|
return result, labels
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
for X, Y in iterate():
|
for X, Y in iterate():
|
||||||
print(X.shape, Y.shape)
|
print(X.shape, Y.shape)
|
||||||
|
|
|
@ -19,17 +19,30 @@ BASEDIR = pathlib.Path(__file__).parent / "librispeech"
|
||||||
with open(BASEDIR / "dev-clean-wav.json") as f:
|
with open(BASEDIR / "dev-clean-wav.json") as f:
|
||||||
ci = json.load(f)
|
ci = json.load(f)
|
||||||
|
|
||||||
FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
|
FILTER_BANK = np.expand_dims(
|
||||||
|
librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0
|
||||||
|
)
|
||||||
WINDOW = librosa.filters.get_window("hann", 320)
|
WINDOW = librosa.filters.get_window("hann", 320)
|
||||||
|
|
||||||
|
|
||||||
def feature_extract(x, x_lens):
|
def feature_extract(x, x_lens):
|
||||||
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
|
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
|
||||||
|
|
||||||
# pre-emphasis
|
# pre-emphasis
|
||||||
x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1)
|
x = np.concatenate(
|
||||||
|
(np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1
|
||||||
|
)
|
||||||
|
|
||||||
# stft
|
# stft
|
||||||
x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")
|
x = librosa.stft(
|
||||||
|
x,
|
||||||
|
n_fft=512,
|
||||||
|
window=WINDOW,
|
||||||
|
hop_length=160,
|
||||||
|
win_length=320,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
)
|
||||||
x = np.stack((x.real, x.imag), axis=-1)
|
x = np.stack((x.real, x.imag), axis=-1)
|
||||||
|
|
||||||
# power spectrum
|
# power spectrum
|
||||||
|
@ -56,18 +69,24 @@ def feature_extract(x, x_lens):
|
||||||
features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1)
|
features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1)
|
||||||
features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1)
|
features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1)
|
||||||
features_std += 1e-5
|
features_std += 1e-5
|
||||||
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2)
|
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(
|
||||||
|
features_std, 2
|
||||||
|
)
|
||||||
|
|
||||||
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
|
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
def load_wav(file):
|
def load_wav(file):
|
||||||
sample = soundfile.read(file)[0].astype(np.float32)
|
sample = soundfile.read(file)[0].astype(np.float32)
|
||||||
return sample, sample.shape[0]
|
return sample, sample.shape[0]
|
||||||
|
|
||||||
|
|
||||||
def iterate(bs=1, start=0):
|
def iterate(bs=1, start=0):
|
||||||
print(f"there are {len(ci)} samples in the dataset")
|
print(f"there are {len(ci)} samples in the dataset")
|
||||||
for i in range(start, len(ci), bs):
|
for i in range(start, len(ci), bs):
|
||||||
samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]])
|
samples, sample_lens = zip(
|
||||||
|
*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]]
|
||||||
|
)
|
||||||
samples = list(samples)
|
samples = list(samples)
|
||||||
# pad to same length
|
# pad to same length
|
||||||
max_len = max(sample_lens)
|
max_len = max(sample_lens)
|
||||||
|
@ -75,7 +94,10 @@ def iterate(bs=1, start=0):
|
||||||
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
|
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
|
||||||
samples, sample_lens = np.array(samples), np.array(sample_lens)
|
samples, sample_lens = np.array(samples), np.array(sample_lens)
|
||||||
|
|
||||||
yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]])
|
yield feature_extract(samples, sample_lens), np.array(
|
||||||
|
[v["transcript"] for v in ci[i : i + bs]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X, Y = next(iterate())
|
X, Y = next(iterate())
|
||||||
|
|
|
@ -12,133 +12,441 @@ import concurrent.futures
|
||||||
|
|
||||||
BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf"
|
BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf"
|
||||||
BUCKET_NAME = "open-images-dataset"
|
BUCKET_NAME = "open-images-dataset"
|
||||||
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
BBOX_ANNOTATIONS_URL = (
|
||||||
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
"https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||||
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
|
)
|
||||||
'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
|
MAP_CLASSES_URL = (
|
||||||
'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
|
"https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
||||||
'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
|
)
|
||||||
'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
|
MLPERF_CLASSES = [
|
||||||
'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
|
"Airplane",
|
||||||
'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
|
"Antelope",
|
||||||
'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
|
"Apple",
|
||||||
'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
|
"Backpack",
|
||||||
'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
|
"Balloon",
|
||||||
'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
|
"Banana",
|
||||||
'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
|
"Barrel",
|
||||||
'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
|
"Baseball bat",
|
||||||
'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
|
"Baseball glove",
|
||||||
'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
|
"Bee",
|
||||||
'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
|
"Beer",
|
||||||
'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
|
"Bench",
|
||||||
'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
|
"Bicycle",
|
||||||
'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
|
"Bicycle helmet",
|
||||||
'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
|
"Bicycle wheel",
|
||||||
'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
|
"Billboard",
|
||||||
'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
|
"Book",
|
||||||
'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
|
"Bookcase",
|
||||||
'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
|
"Boot",
|
||||||
'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
|
"Bottle",
|
||||||
'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
|
"Bowl",
|
||||||
'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
|
"Bowling equipment",
|
||||||
'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
|
"Box",
|
||||||
'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
|
"Boy",
|
||||||
'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
|
"Brassiere",
|
||||||
'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
|
"Bread",
|
||||||
'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
|
"Broccoli",
|
||||||
'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
|
"Bronze sculpture",
|
||||||
'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
|
"Bull",
|
||||||
'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
|
"Bus",
|
||||||
'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
|
"Bust",
|
||||||
'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
|
"Butterfly",
|
||||||
'Zebra', 'Zucchini',
|
"Cabinetry",
|
||||||
|
"Cake",
|
||||||
|
"Camel",
|
||||||
|
"Camera",
|
||||||
|
"Candle",
|
||||||
|
"Candy",
|
||||||
|
"Cannon",
|
||||||
|
"Canoe",
|
||||||
|
"Carrot",
|
||||||
|
"Cart",
|
||||||
|
"Castle",
|
||||||
|
"Cat",
|
||||||
|
"Cattle",
|
||||||
|
"Cello",
|
||||||
|
"Chair",
|
||||||
|
"Cheese",
|
||||||
|
"Chest of drawers",
|
||||||
|
"Chicken",
|
||||||
|
"Christmas tree",
|
||||||
|
"Coat",
|
||||||
|
"Cocktail",
|
||||||
|
"Coffee",
|
||||||
|
"Coffee cup",
|
||||||
|
"Coffee table",
|
||||||
|
"Coin",
|
||||||
|
"Common sunflower",
|
||||||
|
"Computer keyboard",
|
||||||
|
"Computer monitor",
|
||||||
|
"Convenience store",
|
||||||
|
"Cookie",
|
||||||
|
"Countertop",
|
||||||
|
"Cowboy hat",
|
||||||
|
"Crab",
|
||||||
|
"Crocodile",
|
||||||
|
"Cucumber",
|
||||||
|
"Cupboard",
|
||||||
|
"Curtain",
|
||||||
|
"Deer",
|
||||||
|
"Desk",
|
||||||
|
"Dinosaur",
|
||||||
|
"Dog",
|
||||||
|
"Doll",
|
||||||
|
"Dolphin",
|
||||||
|
"Door",
|
||||||
|
"Dragonfly",
|
||||||
|
"Drawer",
|
||||||
|
"Dress",
|
||||||
|
"Drum",
|
||||||
|
"Duck",
|
||||||
|
"Eagle",
|
||||||
|
"Earrings",
|
||||||
|
"Egg (Food)",
|
||||||
|
"Elephant",
|
||||||
|
"Falcon",
|
||||||
|
"Fedora",
|
||||||
|
"Flag",
|
||||||
|
"Flowerpot",
|
||||||
|
"Football",
|
||||||
|
"Football helmet",
|
||||||
|
"Fork",
|
||||||
|
"Fountain",
|
||||||
|
"French fries",
|
||||||
|
"French horn",
|
||||||
|
"Frog",
|
||||||
|
"Giraffe",
|
||||||
|
"Girl",
|
||||||
|
"Glasses",
|
||||||
|
"Goat",
|
||||||
|
"Goggles",
|
||||||
|
"Goldfish",
|
||||||
|
"Gondola",
|
||||||
|
"Goose",
|
||||||
|
"Grape",
|
||||||
|
"Grapefruit",
|
||||||
|
"Guitar",
|
||||||
|
"Hamburger",
|
||||||
|
"Handbag",
|
||||||
|
"Harbor seal",
|
||||||
|
"Headphones",
|
||||||
|
"Helicopter",
|
||||||
|
"High heels",
|
||||||
|
"Hiking equipment",
|
||||||
|
"Horse",
|
||||||
|
"House",
|
||||||
|
"Houseplant",
|
||||||
|
"Human arm",
|
||||||
|
"Human beard",
|
||||||
|
"Human body",
|
||||||
|
"Human ear",
|
||||||
|
"Human eye",
|
||||||
|
"Human face",
|
||||||
|
"Human foot",
|
||||||
|
"Human hair",
|
||||||
|
"Human hand",
|
||||||
|
"Human head",
|
||||||
|
"Human leg",
|
||||||
|
"Human mouth",
|
||||||
|
"Human nose",
|
||||||
|
"Ice cream",
|
||||||
|
"Jacket",
|
||||||
|
"Jeans",
|
||||||
|
"Jellyfish",
|
||||||
|
"Juice",
|
||||||
|
"Kitchen & dining room table",
|
||||||
|
"Kite",
|
||||||
|
"Lamp",
|
||||||
|
"Lantern",
|
||||||
|
"Laptop",
|
||||||
|
"Lavender (Plant)",
|
||||||
|
"Lemon",
|
||||||
|
"Light bulb",
|
||||||
|
"Lighthouse",
|
||||||
|
"Lily",
|
||||||
|
"Lion",
|
||||||
|
"Lipstick",
|
||||||
|
"Lizard",
|
||||||
|
"Man",
|
||||||
|
"Maple",
|
||||||
|
"Microphone",
|
||||||
|
"Mirror",
|
||||||
|
"Mixing bowl",
|
||||||
|
"Mobile phone",
|
||||||
|
"Monkey",
|
||||||
|
"Motorcycle",
|
||||||
|
"Muffin",
|
||||||
|
"Mug",
|
||||||
|
"Mule",
|
||||||
|
"Mushroom",
|
||||||
|
"Musical keyboard",
|
||||||
|
"Necklace",
|
||||||
|
"Nightstand",
|
||||||
|
"Office building",
|
||||||
|
"Orange",
|
||||||
|
"Owl",
|
||||||
|
"Oyster",
|
||||||
|
"Paddle",
|
||||||
|
"Palm tree",
|
||||||
|
"Parachute",
|
||||||
|
"Parrot",
|
||||||
|
"Pen",
|
||||||
|
"Penguin",
|
||||||
|
"Personal flotation device",
|
||||||
|
"Piano",
|
||||||
|
"Picture frame",
|
||||||
|
"Pig",
|
||||||
|
"Pillow",
|
||||||
|
"Pizza",
|
||||||
|
"Plate",
|
||||||
|
"Platter",
|
||||||
|
"Porch",
|
||||||
|
"Poster",
|
||||||
|
"Pumpkin",
|
||||||
|
"Rabbit",
|
||||||
|
"Rifle",
|
||||||
|
"Roller skates",
|
||||||
|
"Rose",
|
||||||
|
"Salad",
|
||||||
|
"Sandal",
|
||||||
|
"Saucer",
|
||||||
|
"Saxophone",
|
||||||
|
"Scarf",
|
||||||
|
"Sea lion",
|
||||||
|
"Sea turtle",
|
||||||
|
"Sheep",
|
||||||
|
"Shelf",
|
||||||
|
"Shirt",
|
||||||
|
"Shorts",
|
||||||
|
"Shrimp",
|
||||||
|
"Sink",
|
||||||
|
"Skateboard",
|
||||||
|
"Ski",
|
||||||
|
"Skull",
|
||||||
|
"Skyscraper",
|
||||||
|
"Snake",
|
||||||
|
"Sock",
|
||||||
|
"Sofa bed",
|
||||||
|
"Sparrow",
|
||||||
|
"Spider",
|
||||||
|
"Spoon",
|
||||||
|
"Sports uniform",
|
||||||
|
"Squirrel",
|
||||||
|
"Stairs",
|
||||||
|
"Stool",
|
||||||
|
"Strawberry",
|
||||||
|
"Street light",
|
||||||
|
"Studio couch",
|
||||||
|
"Suit",
|
||||||
|
"Sun hat",
|
||||||
|
"Sunglasses",
|
||||||
|
"Surfboard",
|
||||||
|
"Sushi",
|
||||||
|
"Swan",
|
||||||
|
"Swimming pool",
|
||||||
|
"Swimwear",
|
||||||
|
"Tank",
|
||||||
|
"Tap",
|
||||||
|
"Taxi",
|
||||||
|
"Tea",
|
||||||
|
"Teddy bear",
|
||||||
|
"Television",
|
||||||
|
"Tent",
|
||||||
|
"Tie",
|
||||||
|
"Tiger",
|
||||||
|
"Tin can",
|
||||||
|
"Tire",
|
||||||
|
"Toilet",
|
||||||
|
"Tomato",
|
||||||
|
"Tortoise",
|
||||||
|
"Tower",
|
||||||
|
"Traffic light",
|
||||||
|
"Train",
|
||||||
|
"Tripod",
|
||||||
|
"Truck",
|
||||||
|
"Trumpet",
|
||||||
|
"Umbrella",
|
||||||
|
"Van",
|
||||||
|
"Vase",
|
||||||
|
"Vehicle registration plate",
|
||||||
|
"Violin",
|
||||||
|
"Wall clock",
|
||||||
|
"Waste container",
|
||||||
|
"Watch",
|
||||||
|
"Whale",
|
||||||
|
"Wheel",
|
||||||
|
"Wheelchair",
|
||||||
|
"Whiteboard",
|
||||||
|
"Window",
|
||||||
|
"Wine",
|
||||||
|
"Wine glass",
|
||||||
|
"Woman",
|
||||||
|
"Zebra",
|
||||||
|
"Zucchini",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def openimages():
|
def openimages():
|
||||||
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
|
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
|
||||||
if not ann_file.is_file():
|
if not ann_file.is_file():
|
||||||
fetch_openimages(ann_file)
|
fetch_openimages(ann_file)
|
||||||
return ann_file
|
return ann_file
|
||||||
|
|
||||||
|
|
||||||
# this slows down the conversion a lot!
|
# this slows down the conversion a lot!
|
||||||
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
|
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
|
||||||
def extract_dims(path): return Image.open(path).size[::-1]
|
def extract_dims(path):
|
||||||
|
return Image.open(path).size[::-1]
|
||||||
|
|
||||||
def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
|
|
||||||
|
def export_to_coco(
|
||||||
|
class_map,
|
||||||
|
annotations,
|
||||||
|
image_list,
|
||||||
|
dataset_path,
|
||||||
|
output_path,
|
||||||
|
classes=MLPERF_CLASSES,
|
||||||
|
):
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
|
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
|
||||||
categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
|
categories_map = pd.DataFrame(
|
||||||
class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
|
[(i, c) for i, c in enumerate(classes)],
|
||||||
|
columns=["category_id", "category_name"],
|
||||||
|
)
|
||||||
|
class_map = class_map.merge(
|
||||||
|
categories_map, left_on="DisplayName", right_on="category_name", how="inner"
|
||||||
|
)
|
||||||
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
|
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
|
||||||
annotations = annotations.merge(class_map, on="LabelName", how="inner")
|
annotations = annotations.merge(class_map, on="LabelName", how="inner")
|
||||||
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
|
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
|
||||||
annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")
|
annotations[["height", "width"]] = annotations.apply(
|
||||||
|
lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"),
|
||||||
|
axis=1,
|
||||||
|
result_type="expand",
|
||||||
|
)
|
||||||
|
|
||||||
# Images
|
# Images
|
||||||
imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
|
imgs = [
|
||||||
for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows())
|
{
|
||||||
|
"id": int(id + 1),
|
||||||
|
"file_name": f"{image_id}.jpg",
|
||||||
|
"height": row["height"],
|
||||||
|
"width": row["width"],
|
||||||
|
"license": None,
|
||||||
|
"coco_url": None,
|
||||||
|
}
|
||||||
|
for (id, image_id), row in (
|
||||||
|
annotations.groupby(["image_id", "ImageID"]).first().iterrows()
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Annotations
|
# Annotations
|
||||||
annots = []
|
annots = []
|
||||||
for i, row in annotations.iterrows():
|
for i, row in annotations.iterrows():
|
||||||
xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
|
xmin, ymin, xmax, ymax, img_w, img_h = [
|
||||||
x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
|
row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]
|
||||||
coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
|
]
|
||||||
coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
|
x, y, w, h = (
|
||||||
|
xmin * img_w,
|
||||||
|
ymin * img_h,
|
||||||
|
(xmax - xmin) * img_w,
|
||||||
|
(ymax - ymin) * img_h,
|
||||||
|
)
|
||||||
|
coco_annot = {
|
||||||
|
"id": int(i) + 1,
|
||||||
|
"image_id": int(row["image_id"] + 1),
|
||||||
|
"category_id": int(row["category_id"]),
|
||||||
|
"bbox": [x, y, w, h],
|
||||||
|
"area": w * h,
|
||||||
|
}
|
||||||
|
coco_annot.update(
|
||||||
|
{
|
||||||
|
k: row[k]
|
||||||
|
for k in [
|
||||||
|
"IsOccluded",
|
||||||
|
"IsInside",
|
||||||
|
"IsDepiction",
|
||||||
|
"IsTruncated",
|
||||||
|
"IsGroupOf",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
coco_annot["iscrowd"] = int(row["IsGroupOf"])
|
coco_annot["iscrowd"] = int(row["IsGroupOf"])
|
||||||
annots.append(coco_annot)
|
annots.append(coco_annot)
|
||||||
|
|
||||||
info = {"dataset": "openimages_mlperf", "version": "v6"}
|
info = {"dataset": "openimages_mlperf", "version": "v6"}
|
||||||
coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
|
coco_annotations = {
|
||||||
|
"info": info,
|
||||||
|
"licenses": [],
|
||||||
|
"categories": cats,
|
||||||
|
"images": imgs,
|
||||||
|
"annotations": annots,
|
||||||
|
}
|
||||||
with open(output_path, "w") as fp:
|
with open(output_path, "w") as fp:
|
||||||
json.dump(coco_annotations, fp)
|
json.dump(coco_annotations, fp)
|
||||||
|
|
||||||
|
|
||||||
def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
|
def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
|
||||||
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
|
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
|
||||||
image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
|
image_ids = annotations[np.isin(annotations["LabelName"], labels)][
|
||||||
|
"ImageID"
|
||||||
|
].unique()
|
||||||
return image_ids
|
return image_ids
|
||||||
|
|
||||||
|
|
||||||
def download_image(bucket, image_id, data_dir):
|
def download_image(bucket, image_id, data_dir):
|
||||||
try:
|
try:
|
||||||
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
|
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
|
||||||
except botocore.exceptions.ClientError as exception:
|
except botocore.exceptions.ClientError as exception:
|
||||||
sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")
|
sys.exit(
|
||||||
|
f"ERROR when downloading image `validation/{image_id}`: {str(exception)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_openimages(output_fn):
|
def fetch_openimages(output_fn):
|
||||||
bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)
|
bucket = boto3.resource(
|
||||||
|
"s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)
|
||||||
|
).Bucket(BUCKET_NAME)
|
||||||
|
|
||||||
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
|
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
|
||||||
annotations_dir.mkdir(parents=True, exist_ok=True)
|
annotations_dir.mkdir(parents=True, exist_ok=True)
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
|
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split("/")[-1]
|
||||||
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
|
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
|
||||||
annotations = pd.read_csv(annotations_fn)
|
annotations = pd.read_csv(annotations_fn)
|
||||||
|
|
||||||
classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
|
classmap_fn = annotations_dir / MAP_CLASSES_URL.split("/")[-1]
|
||||||
fetch(MAP_CLASSES_URL, classmap_fn)
|
fetch(MAP_CLASSES_URL, classmap_fn)
|
||||||
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
|
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
|
||||||
|
|
||||||
image_list = get_image_list(class_map, annotations)
|
image_list = get_image_list(class_map, annotations)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
|
futures = [
|
||||||
for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
|
executor.submit(download_image, bucket, image_id, data_dir)
|
||||||
|
for image_id in image_list
|
||||||
|
]
|
||||||
|
for future in (
|
||||||
|
t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))
|
||||||
|
):
|
||||||
t.set_description(f"Downloading images")
|
t.set_description(f"Downloading images")
|
||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
print("Converting annotations to COCO format...")
|
print("Converting annotations to COCO format...")
|
||||||
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
|
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
|
||||||
|
|
||||||
|
|
||||||
def image_load(fn):
|
def image_load(fn):
|
||||||
img_folder = BASEDIR / "validation/data"
|
img_folder = BASEDIR / "validation/data"
|
||||||
img = Image.open(img_folder / fn).convert('RGB')
|
img = Image.open(img_folder / fn).convert("RGB")
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
ret = F.resize(img, size=(800, 800))
|
ret = F.resize(img, size=(800, 800))
|
||||||
ret = np.array(ret)
|
ret = np.array(ret)
|
||||||
return ret, img.size[::-1]
|
return ret, img.size[::-1]
|
||||||
|
|
||||||
|
|
||||||
def prepare_target(annotations, img_id, img_size):
|
def prepare_target(annotations, img_id, img_size):
|
||||||
boxes = [annot["bbox"] for annot in annotations]
|
boxes = [annot["bbox"] for annot in annotations]
|
||||||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
|
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
|
||||||
|
@ -150,7 +458,13 @@ def prepare_target(annotations, img_id, img_size):
|
||||||
classes = [annot["category_id"] for annot in annotations]
|
classes = [annot["category_id"] for annot in annotations]
|
||||||
classes = np.array(classes, dtype=np.int64)
|
classes = np.array(classes, dtype=np.int64)
|
||||||
classes = classes[keep]
|
classes = classes[keep]
|
||||||
return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}
|
return {
|
||||||
|
"boxes": boxes,
|
||||||
|
"labels": classes,
|
||||||
|
"image_id": img_id,
|
||||||
|
"image_size": img_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def iterate(coco, bs=8):
|
def iterate(coco, bs=8):
|
||||||
image_ids = sorted(coco.imgs.keys())
|
image_ids = sorted(coco.imgs.keys())
|
||||||
|
|
|
@ -13,10 +13,15 @@ if __name__ == "__main__":
|
||||||
assert x.shape[0] == y.shape[0]
|
assert x.shape[0] == y.shape[0]
|
||||||
bs = x.shape[0]
|
bs = x.shape[0]
|
||||||
if X is None:
|
if X is None:
|
||||||
X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
|
X = Tensor.empty(
|
||||||
Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
|
sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8
|
||||||
|
)
|
||||||
|
Y = Tensor.empty(
|
||||||
|
sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64
|
||||||
|
)
|
||||||
print(X.shape, Y.shape)
|
print(X.shape, Y.shape)
|
||||||
X[idx : idx + bs].assign(x)
|
X[idx : idx + bs].assign(x)
|
||||||
Y[idx : idx + bs].assign(y)
|
Y[idx : idx + bs].assign(y)
|
||||||
idx += bs
|
idx += bs
|
||||||
if idx >= sz: break
|
if idx >= sz:
|
||||||
|
break
|
||||||
|
|
|
@ -6,9 +6,14 @@ import numpy as np
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
|
|
||||||
BASEDIR = Path(__file__).parent / "squad"
|
BASEDIR = Path(__file__).parent / "squad"
|
||||||
|
|
||||||
|
|
||||||
def init_dataset():
|
def init_dataset():
|
||||||
os.makedirs(BASEDIR, exist_ok=True)
|
os.makedirs(BASEDIR, exist_ok=True)
|
||||||
fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
|
fetch(
|
||||||
|
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
|
||||||
|
BASEDIR / "dev-v1.1.json",
|
||||||
|
)
|
||||||
with open(BASEDIR / "dev-v1.1.json") as f:
|
with open(BASEDIR / "dev-v1.1.json") as f:
|
||||||
data = json.load(f)["data"]
|
data = json.load(f)["data"]
|
||||||
|
|
||||||
|
@ -32,14 +37,17 @@ def init_dataset():
|
||||||
qa_id = qa["id"]
|
qa_id = qa["id"]
|
||||||
q_text = qa["question"]
|
q_text = qa["question"]
|
||||||
|
|
||||||
examples.append({
|
examples.append(
|
||||||
|
{
|
||||||
"id": qa_id,
|
"id": qa_id,
|
||||||
"question": q_text,
|
"question": q_text,
|
||||||
"context": doc_tokens,
|
"context": doc_tokens,
|
||||||
"answers": list(map(lambda x: x["text"], qa["answers"]))
|
"answers": list(map(lambda x: x["text"], qa["answers"])),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||||
best_score, best_span_index = None, None
|
best_score, best_span_index = None, None
|
||||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||||
|
@ -56,6 +64,7 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||||
best_span_index = di
|
best_span_index = di
|
||||||
return cur_span_index == best_span_index
|
return cur_span_index == best_span_index
|
||||||
|
|
||||||
|
|
||||||
def convert_example_to_features(example, tokenizer):
|
def convert_example_to_features(example, tokenizer):
|
||||||
query_tokens = tokenizer.tokenize(example["question"])
|
query_tokens = tokenizer.tokenize(example["question"])
|
||||||
|
|
||||||
|
@ -101,7 +110,9 @@ def convert_example_to_features(example, tokenizer):
|
||||||
for i in range(doc_length):
|
for i in range(doc_length):
|
||||||
split_token_index = doc_start + i
|
split_token_index = doc_start + i
|
||||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||||
token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
|
token_is_max_context[len(tokens)] = _check_is_max_context(
|
||||||
|
doc_spans, di, split_token_index
|
||||||
|
)
|
||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
segment_ids.append(1)
|
segment_ids.append(1)
|
||||||
tokens.append("[SEP]")
|
tokens.append("[SEP]")
|
||||||
|
@ -119,17 +130,24 @@ def convert_example_to_features(example, tokenizer):
|
||||||
assert len(input_mask) == 384
|
assert len(input_mask) == 384
|
||||||
assert len(segment_ids) == 384
|
assert len(segment_ids) == 384
|
||||||
|
|
||||||
outputs.append({
|
outputs.append(
|
||||||
|
{
|
||||||
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
|
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
|
||||||
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
|
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(
|
||||||
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
|
np.float32
|
||||||
|
),
|
||||||
|
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(
|
||||||
|
np.float32
|
||||||
|
),
|
||||||
"token_to_orig_map": token_to_orig_map,
|
"token_to_orig_map": token_to_orig_map,
|
||||||
"token_is_max_context": token_is_max_context,
|
"token_is_max_context": token_is_max_context,
|
||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def iterate(tokenizer, start=0):
|
def iterate(tokenizer, start=0):
|
||||||
examples = init_dataset()
|
examples = init_dataset()
|
||||||
print(f"there are {len(examples)} pairs in the dataset")
|
print(f"there are {len(examples)} pairs in the dataset")
|
||||||
|
@ -140,8 +158,11 @@ def iterate(tokenizer, start=0):
|
||||||
# we need to yield all features here as the f1 score is the maximum over all features
|
# we need to yield all features here as the f1 score is the maximum over all features
|
||||||
yield features, example
|
yield features, example
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))
|
tokenizer = BertTokenizer(
|
||||||
|
str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt")
|
||||||
|
)
|
||||||
|
|
||||||
X, Y = next(iterate(tokenizer))
|
X, Y = next(iterate(tokenizer))
|
||||||
print(" ".join(X[0]["tokens"]))
|
print(" ".join(X[0]["tokens"]))
|
||||||
|
|
|
@ -5,11 +5,13 @@ from tinygrad.helpers import DEBUG, getenv
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
# this needs to be called before everything else if you are using distributed
|
# this needs to be called before everything else if you are using distributed
|
||||||
def preinit():
|
def preinit():
|
||||||
os.environ["DELAYED_RUNTIME_INIT"] = "1"
|
os.environ["DELAYED_RUNTIME_INIT"] = "1"
|
||||||
mp.set_start_method("spawn")
|
mp.set_start_method("spawn")
|
||||||
|
|
||||||
|
|
||||||
# out-of-band communication/synchronization
|
# out-of-band communication/synchronization
|
||||||
class _OOB:
|
class _OOB:
|
||||||
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
|
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
|
||||||
|
@ -22,14 +24,18 @@ class _OOB:
|
||||||
# receive some data from a target rank, blocks until data is received
|
# receive some data from a target rank, blocks until data is received
|
||||||
def recv(self, target_rank: int) -> Any:
|
def recv(self, target_rank: int) -> Any:
|
||||||
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
||||||
|
|
||||||
|
|
||||||
OOB: Optional[_OOB] = None
|
OOB: Optional[_OOB] = None
|
||||||
|
|
||||||
|
|
||||||
def init_oob(world_size: int):
|
def init_oob(world_size: int):
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
|
|
||||||
global OOB
|
global OOB
|
||||||
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
|
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
|
||||||
|
|
||||||
|
|
||||||
# this runs in the spawned process so we can do all the delayed runtime initialization
|
# this runs in the spawned process so we can do all the delayed runtime initialization
|
||||||
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
|
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
|
||||||
# setup the rank
|
# setup the rank
|
||||||
|
@ -41,19 +47,27 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
|
||||||
|
|
||||||
# do specific runtime initialization for distributed
|
# do specific runtime initialization for distributed
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
|
|
||||||
|
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(
|
||||||
|
device.split(":")[-1]
|
||||||
|
)
|
||||||
if "GPU" in device:
|
if "GPU" in device:
|
||||||
from tinygrad.runtime.ops_gpu import CL
|
from tinygrad.runtime.ops_gpu import CL
|
||||||
|
|
||||||
CL.post_init(device_num)
|
CL.post_init(device_num)
|
||||||
elif "HIP" in device:
|
elif "HIP" in device:
|
||||||
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
|
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(
|
||||||
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
|
device_num
|
||||||
|
)
|
||||||
|
if DEBUG >= 1:
|
||||||
|
print(f"distributed process {rank} initialized runtime for device {device}")
|
||||||
|
|
||||||
# convert device to be process specific
|
# convert device to be process specific
|
||||||
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
|
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
|
||||||
|
|
||||||
fn(*args)
|
fn(*args)
|
||||||
|
|
||||||
|
|
||||||
# wrapper around mp.Process that initializes the runtime
|
# wrapper around mp.Process that initializes the runtime
|
||||||
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
|
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
|
||||||
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
|
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
|
||||||
|
|
|
@ -3,6 +3,7 @@ from tinygrad.helpers import getenv
|
||||||
|
|
||||||
from extra.dist import world
|
from extra.dist import world
|
||||||
|
|
||||||
|
|
||||||
def allreduce(t: Tensor) -> Tensor:
|
def allreduce(t: Tensor) -> Tensor:
|
||||||
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
|
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
|
||||||
|
|
||||||
|
@ -11,7 +12,9 @@ def allreduce(t:Tensor) -> Tensor:
|
||||||
|
|
||||||
# pad to evenly divide
|
# pad to evenly divide
|
||||||
if flattened.shape[0] % WORLD_SIZE != 0:
|
if flattened.shape[0] % WORLD_SIZE != 0:
|
||||||
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
|
flattened = Tensor.cat(
|
||||||
|
flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))
|
||||||
|
)
|
||||||
|
|
||||||
# chunk
|
# chunk
|
||||||
chunks = flattened.chunk(WORLD_SIZE, dim=0)
|
chunks = flattened.chunk(WORLD_SIZE, dim=0)
|
||||||
|
|
|
@ -4,15 +4,18 @@ from multiprocessing import shared_memory
|
||||||
from tinygrad.helpers import DEBUG, colored, getenv
|
from tinygrad.helpers import DEBUG, colored, getenv
|
||||||
from tinygrad.lazy import LazyBuffer
|
from tinygrad.lazy import LazyBuffer
|
||||||
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
|
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import gpuctypes.hip as hip
|
import gpuctypes.hip as hip
|
||||||
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
|
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
|
||||||
except: RawHIPBuffer = None
|
except:
|
||||||
|
RawHIPBuffer = None
|
||||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||||
from tinygrad.jit import CacheCollector
|
from tinygrad.jit import CacheCollector
|
||||||
from tinygrad.tensor import Tensor, Function
|
from tinygrad.tensor import Tensor, Function
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
# match the function signature of JITRunner so we can put it in the cache
|
# match the function signature of JITRunner so we can put it in the cache
|
||||||
def __send_rb(args, variables=None, wait=False, jit=False):
|
def __send_rb(args, variables=None, wait=False, jit=False):
|
||||||
x, target_rank, y = args[:3]
|
x, target_rank, y = args[:3]
|
||||||
|
@ -20,19 +23,31 @@ def __send_rb(args, variables=None, wait=False, jit=False):
|
||||||
check(hip.hipSetDevice(x._device))
|
check(hip.hipSetDevice(x._device))
|
||||||
check(hip.hipDeviceSynchronize())
|
check(hip.hipDeviceSynchronize())
|
||||||
else:
|
else:
|
||||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
if isinstance(x, RawBufferCopyInOut):
|
||||||
else: y.fromCPU(x.toCPU())
|
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||||
|
else:
|
||||||
|
y.fromCPU(x.toCPU())
|
||||||
dist.OOB.send(None, target_rank)
|
dist.OOB.send(None, target_rank)
|
||||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
if DEBUG >= 2:
|
||||||
|
print(
|
||||||
|
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def __recv_rb(args, variables=None, wait=False, jit=False):
|
def __recv_rb(args, variables=None, wait=False, jit=False):
|
||||||
x, target_rank, y = args[:3]
|
x, target_rank, y = args[:3]
|
||||||
dist.OOB.recv(target_rank)
|
dist.OOB.recv(target_rank)
|
||||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||||
x._transfer(y)
|
x._transfer(y)
|
||||||
elif isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
elif isinstance(x, RawBuffer):
|
||||||
else: x.fromCPU(y.toCPU())
|
x._copyin(y.toCPU())
|
||||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")
|
else:
|
||||||
|
x.fromCPU(y.toCPU())
|
||||||
|
if DEBUG >= 2:
|
||||||
|
print(
|
||||||
|
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# send a rawbuffer from out rank to the target rank
|
# send a rawbuffer from out rank to the target rank
|
||||||
def _send_rb(x: RawBuffer, target_rank: int):
|
def _send_rb(x: RawBuffer, target_rank: int):
|
||||||
|
@ -40,7 +55,11 @@ def _send_rb(x:RawBuffer, target_rank:int):
|
||||||
# send ipc handle
|
# send ipc handle
|
||||||
check(hip.hipSetDevice(x._device))
|
check(hip.hipSetDevice(x._device))
|
||||||
check(hip.hipDeviceSynchronize())
|
check(hip.hipDeviceSynchronize())
|
||||||
check(hip.hipIpcGetMemHandle(ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf))
|
check(
|
||||||
|
hip.hipIpcGetMemHandle(
|
||||||
|
ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf
|
||||||
|
)
|
||||||
|
)
|
||||||
dist.OOB.send((handle, x._device), target_rank)
|
dist.OOB.send((handle, x._device), target_rank)
|
||||||
|
|
||||||
# jit support
|
# jit support
|
||||||
|
@ -48,20 +67,26 @@ def _send_rb(x:RawBuffer, target_rank:int):
|
||||||
CacheCollector.add(__send_rb, [x, target_rank, None], {})
|
CacheCollector.add(__send_rb, [x, target_rank, None], {})
|
||||||
else:
|
else:
|
||||||
# create shared memory
|
# create shared memory
|
||||||
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
|
shm_name = (
|
||||||
|
s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)
|
||||||
|
).name
|
||||||
s.close()
|
s.close()
|
||||||
|
|
||||||
# copy the buffer into shared memory
|
# copy the buffer into shared memory
|
||||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||||
# fast path when we can directly copyout
|
# fast path when we can directly copyout
|
||||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
if isinstance(x, RawBufferCopyInOut):
|
||||||
else: y.fromCPU(x.toCPU())
|
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||||
|
else:
|
||||||
|
y.fromCPU(x.toCPU())
|
||||||
|
|
||||||
dist.OOB.send(shm_name, target_rank)
|
dist.OOB.send(shm_name, target_rank)
|
||||||
|
|
||||||
# jit support
|
# jit support
|
||||||
CacheCollector.add(__send_rb, [x, target_rank, y], {})
|
CacheCollector.add(__send_rb, [x, target_rank, y], {})
|
||||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
if DEBUG >= 2:
|
||||||
|
print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||||
|
|
||||||
|
|
||||||
# receive a rawbuffer from the target rank
|
# receive a rawbuffer from the target rank
|
||||||
def _recv_rb(x: RawBuffer, target_rank: int):
|
def _recv_rb(x: RawBuffer, target_rank: int):
|
||||||
|
@ -69,7 +94,9 @@ def _recv_rb(x:RawBuffer, target_rank:int):
|
||||||
# open ipc handle
|
# open ipc handle
|
||||||
handle, y_device = dist.OOB.recv(target_rank)
|
handle, y_device = dist.OOB.recv(target_rank)
|
||||||
check(hip.hipSetDevice(y_device))
|
check(hip.hipSetDevice(y_device))
|
||||||
check(hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0))
|
check(
|
||||||
|
hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0)
|
||||||
|
)
|
||||||
|
|
||||||
# build a new buffer
|
# build a new buffer
|
||||||
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
|
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
|
||||||
|
@ -81,34 +108,50 @@ def _recv_rb(x:RawBuffer, target_rank:int):
|
||||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||||
|
|
||||||
# fast path when we can directly copyin
|
# fast path when we can directly copyin
|
||||||
if isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
if isinstance(x, RawBuffer):
|
||||||
else: x.fromCPU(y.toCPU())
|
x._copyin(y.toCPU())
|
||||||
|
else:
|
||||||
|
x.fromCPU(y.toCPU())
|
||||||
|
|
||||||
# jit support
|
# jit support
|
||||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
if DEBUG >= 2:
|
||||||
|
print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
||||||
|
|
||||||
|
|
||||||
# sends a lazybuffer from our rank to the target rank
|
# sends a lazybuffer from our rank to the target rank
|
||||||
def _send_lb(x: LazyBuffer, target_rank: int) -> None:
|
def _send_lb(x: LazyBuffer, target_rank: int) -> None:
|
||||||
assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
|
assert (
|
||||||
|
x.st.contiguous and x.realized
|
||||||
|
), "sending buffer must be contiguous and realized"
|
||||||
_send_rb(x.realized, target_rank)
|
_send_rb(x.realized, target_rank)
|
||||||
|
|
||||||
|
|
||||||
# receive a lazybuffer from the target rank
|
# receive a lazybuffer from the target rank
|
||||||
def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||||
assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
|
assert (
|
||||||
|
x.st.contiguous and x.realized
|
||||||
|
), "receiving buffer must be contiguous and realized"
|
||||||
_recv_rb(x.realized, target_rank)
|
_recv_rb(x.realized, target_rank)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Send(Function):
|
class Send(Function):
|
||||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||||
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
|
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
|
||||||
_send_lb(x, target_rank)
|
_send_lb(x, target_rank)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Recv(Function):
|
class Recv(Function):
|
||||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||||
self.target_rank = target_rank
|
self.target_rank = target_rank
|
||||||
return _recv_lb(x, target_rank)
|
return _recv_lb(x, target_rank)
|
||||||
|
|
||||||
def send(x:Tensor, target_rank:int) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
|
||||||
def recv(x:Tensor, target_rank:int) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
def send(x: Tensor, target_rank: int) -> Tensor:
|
||||||
|
return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||||
|
|
||||||
|
|
||||||
|
def recv(x: Tensor, target_rank: int) -> Tensor:
|
||||||
|
return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||||
|
|
|
@ -17,5 +17,10 @@ if __name__ == "__main__":
|
||||||
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
||||||
for f in cur3.fetchall():
|
for f in cur3.fetchall():
|
||||||
v = pickle.loads(f[-1])
|
v = pickle.loads(f[-1])
|
||||||
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
|
print(
|
||||||
|
" ",
|
||||||
|
len(f[0]) if isinstance(f[0], str) else f[0],
|
||||||
|
f[1:-1],
|
||||||
|
str(v)[0:50],
|
||||||
|
)
|
||||||
# print(f"{len(k):10d}, {sk} -> {v}")
|
# print(f"{len(k):10d}, {sk} -> {v}")
|
||||||
|
|
|
@ -7,77 +7,190 @@ import json
|
||||||
|
|
||||||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
|
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
|
||||||
|
|
||||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
|
||||||
|
def compile_net(
|
||||||
|
run: TinyJit, special_names: Dict[int, str]
|
||||||
|
) -> Tuple[
|
||||||
|
Dict[str, str],
|
||||||
|
List[Tuple[str, List[str], List[int]]],
|
||||||
|
Dict[str, Tuple[int, DType, int]],
|
||||||
|
Dict[str, Tensor],
|
||||||
|
]:
|
||||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||||
for ji in run.jit_cache:
|
for ji in run.jit_cache:
|
||||||
fxn = ji.prg
|
fxn = ji.prg
|
||||||
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
|
functions[
|
||||||
|
fxn.name
|
||||||
|
] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||||
cargs = []
|
cargs = []
|
||||||
for i, arg in enumerate(ji.rawbufs):
|
for i, arg in enumerate(ji.rawbufs):
|
||||||
key = id(arg)
|
key = id(arg)
|
||||||
if key not in bufs:
|
if key not in bufs:
|
||||||
if key in special_names:
|
if key in special_names:
|
||||||
bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
|
bufs[key] = (
|
||||||
|
special_names[key],
|
||||||
|
arg.size * arg.dtype.itemsize,
|
||||||
|
arg.dtype,
|
||||||
|
key,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
|
bufs[key] = (
|
||||||
|
f"buf_{bufnum}",
|
||||||
|
arg.size * arg.dtype.itemsize,
|
||||||
|
arg.dtype,
|
||||||
|
key,
|
||||||
|
)
|
||||||
bufnum += 1
|
bufnum += 1
|
||||||
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
|
if i > 0:
|
||||||
|
bufs_to_save[
|
||||||
|
bufs[key][0]
|
||||||
|
] = arg # if first usage of a buffer is not an output, and it's not a special name
|
||||||
cargs.append(bufs[key][0])
|
cargs.append(bufs[key][0])
|
||||||
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
|
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
|
||||||
|
|
||||||
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
|
return (
|
||||||
|
functions,
|
||||||
|
statements,
|
||||||
|
{name: (size, dtype, key) for (name, size, dtype, key) in bufs.values()},
|
||||||
|
bufs_to_save,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]:
|
def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]:
|
||||||
assert hasattr(model, "forward") or callable(model), "model needs a forward function"
|
assert hasattr(model, "forward") or callable(
|
||||||
|
model
|
||||||
|
), "model needs a forward function"
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def run(*x):
|
def run(*x):
|
||||||
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
||||||
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
assert (
|
||||||
|
isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor)
|
||||||
|
), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||||
out = [out] if isinstance(out, Tensor) else out
|
out = [out] if isinstance(out, Tensor) else out
|
||||||
return [o.realize() for o in out]
|
return [o.realize() for o in out]
|
||||||
|
|
||||||
# twice to run the JIT
|
# twice to run the JIT
|
||||||
for _ in range(2): the_output = run(*args)
|
for _ in range(2):
|
||||||
|
the_output = run(*args)
|
||||||
special_names = {}
|
special_names = {}
|
||||||
|
|
||||||
# hack to put the inputs back
|
# hack to put the inputs back
|
||||||
for (j, i), idx in run.input_replace.items():
|
for (j, i), idx in run.input_replace.items():
|
||||||
realized_input = args[idx].lazydata.realized
|
realized_input = args[idx].lazydata.realized
|
||||||
run.jit_cache[j].rawbufs[i] = realized_input
|
run.jit_cache[j].rawbufs[i] = realized_input
|
||||||
special_names[id(realized_input)] = f'input{idx}'
|
special_names[id(realized_input)] = f"input{idx}"
|
||||||
|
|
||||||
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
||||||
for i, output in enumerate(the_output):
|
for i, output in enumerate(the_output):
|
||||||
special_names[id(output.lazydata.realized)] = f'output{i}'
|
special_names[id(output.lazydata.realized)] = f"output{i}"
|
||||||
return run, special_names
|
return run, special_names
|
||||||
|
|
||||||
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
|
|
||||||
|
def export_model_clang(
|
||||||
|
functions: Dict[str, str],
|
||||||
|
statements: Dict[str, Tuple[str, int, int]],
|
||||||
|
bufs: Dict[str, Tuple[str, int, int]],
|
||||||
|
bufs_to_save: Dict[str, Tensor],
|
||||||
|
input_names: List[str],
|
||||||
|
output_names: List[str],
|
||||||
|
) -> str:
|
||||||
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
|
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
|
||||||
|
|
||||||
cprog = [CLANG_PROGRAM_HEADER]
|
cprog = [CLANG_PROGRAM_HEADER]
|
||||||
|
|
||||||
for name, cl in bufs_to_save.items():
|
for name, cl in bufs_to_save.items():
|
||||||
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
|
weight = "".join(["\\x%02X" % x for x in bytes(cl._buf)])
|
||||||
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
|
cprog.append(f'unsigned char {name}_data[] = "{weight}";')
|
||||||
|
|
||||||
inputs = ", ".join([f'float* {input}' for input in input_names])
|
inputs = ", ".join([f"float* {input}" for input in input_names])
|
||||||
outputs = ", ".join([f'float* {output}' for output in output_names])
|
outputs = ", ".join([f"float* {output}" for output in output_names])
|
||||||
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
|
cprog += [
|
||||||
|
f"float {name}[{len}];"
|
||||||
|
if name not in bufs_to_save
|
||||||
|
else f"float *{name} = (float *){name}_data;"
|
||||||
|
for name, (len, dtype, _key) in bufs.items()
|
||||||
|
if name not in ["input", "outputs"]
|
||||||
|
]
|
||||||
cprog += list(functions.values())
|
cprog += list(functions.values())
|
||||||
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
|
cprog += (
|
||||||
return '\n'.join(cprog)
|
[f"void net({inputs}, {outputs}) {{"]
|
||||||
|
+ [
|
||||||
|
f"{name}({', '.join(args)});"
|
||||||
|
for (name, args, _global_size, _local_size) in statements
|
||||||
|
]
|
||||||
|
+ ["}"]
|
||||||
|
)
|
||||||
|
return "\n".join(cprog)
|
||||||
|
|
||||||
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
|
|
||||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
def export_model_webgpu(
|
||||||
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
|
functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names
|
||||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
) -> Tuple[str, int, int]:
|
||||||
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
kernel_code = "\n\n".join(
|
||||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
|
[
|
||||||
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
|
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||||
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
|
for key, code in functions.items()
|
||||||
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
|
]
|
||||||
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
|
)
|
||||||
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
|
kernel_names = ", ".join(
|
||||||
return f"""
|
[name for (name, _args, _global_size, _local_size) in statements]
|
||||||
|
)
|
||||||
|
kernel_calls = "\n ".join(
|
||||||
|
[
|
||||||
|
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||||
|
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_bufs = "\n ".join(
|
||||||
|
[
|
||||||
|
f"const {name} = "
|
||||||
|
+ (
|
||||||
|
f"createEmptyBuf(device, {size});"
|
||||||
|
if _key not in weight_names
|
||||||
|
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))"
|
||||||
|
)
|
||||||
|
+ ";"
|
||||||
|
for name, (size, dtype, _key) in bufs.items()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
gpu_write_bufs = "\n ".join(
|
||||||
|
[
|
||||||
|
f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||||
|
for i, input_name in enumerate(input_names)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_writers = "\n ".join(
|
||||||
|
[
|
||||||
|
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||||
|
+ f"_{inp_name});"
|
||||||
|
+ f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);"
|
||||||
|
for i, inp_name in enumerate(input_names)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
gpu_read_bufs = "\n ".join(
|
||||||
|
[
|
||||||
|
f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});"
|
||||||
|
for i, output_name in enumerate(output_names)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
outbuf_copies = "\n ".join(
|
||||||
|
[
|
||||||
|
f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);"
|
||||||
|
for i, output_name in enumerate(output_names)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
output_readers = "\n ".join(
|
||||||
|
[
|
||||||
|
f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();"
|
||||||
|
for i in range(len(output_names))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
output_return = "[{}]".format(
|
||||||
|
",".join([f"resultBuffer{i}" for i in range(len(output_names))])
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
f"""
|
||||||
const getTensorMetadata = (safetensorBuffer) => {{
|
const getTensorMetadata = (safetensorBuffer) => {{
|
||||||
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
||||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
|
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
|
||||||
|
@ -134,10 +247,15 @@ const setupNet = async (device, safetensor) => {{
|
||||||
return {output_return};
|
return {output_return};
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
"""
|
||||||
|
+ f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def export_model(model, target: str, *inputs):
|
def export_model(model, target: str, *inputs):
|
||||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
assert (
|
||||||
|
Device.DEFAULT in EXPORT_SUPPORTED_DEVICE
|
||||||
|
), "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||||
run, special_names = jit_model(model, *inputs)
|
run, special_names = jit_model(model, *inputs)
|
||||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||||
state = get_state_dict(model)
|
state = get_state_dict(model)
|
||||||
|
@ -146,34 +264,56 @@ def export_model(model, target:str, *inputs):
|
||||||
output_names = [name for _, name in special_names.items() if "output" in name]
|
output_names = [name for _, name in special_names.items() if "output" in name]
|
||||||
prg = ""
|
prg = ""
|
||||||
if target == "clang":
|
if target == "clang":
|
||||||
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
|
prg = export_model_clang(
|
||||||
|
functions, statements, bufs, bufs_to_save, input_names, output_names
|
||||||
|
)
|
||||||
elif target == "webgpu":
|
elif target == "webgpu":
|
||||||
prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
|
prg = export_model_webgpu(
|
||||||
|
functions,
|
||||||
|
statements,
|
||||||
|
bufs,
|
||||||
|
bufs_to_save,
|
||||||
|
weight_names,
|
||||||
|
input_names,
|
||||||
|
output_names,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prg = json.dumps({
|
prg = json.dumps(
|
||||||
|
{
|
||||||
"backend": Device.DEFAULT,
|
"backend": Device.DEFAULT,
|
||||||
"inputs": [{
|
"inputs": [
|
||||||
"size": bufs[name][0],
|
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||||
"dtype": bufs[name][1].name
|
for name in input_names
|
||||||
} for name in input_names],
|
],
|
||||||
"outputs": [{
|
"outputs": [
|
||||||
"size": bufs[name][0],
|
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||||
"dtype": bufs[name][1].name
|
for name in output_names
|
||||||
} for name in output_names],
|
],
|
||||||
"functions": functions,
|
"functions": functions,
|
||||||
"statements": [{
|
"statements": [
|
||||||
|
{
|
||||||
"kernel": kernel,
|
"kernel": kernel,
|
||||||
"args": args,
|
"args": args,
|
||||||
"global_size": global_size,
|
"global_size": global_size,
|
||||||
"local_size": local_size
|
"local_size": local_size,
|
||||||
} for (kernel, args, global_size, local_size) in statements],
|
}
|
||||||
|
for (kernel, args, global_size, local_size) in statements
|
||||||
|
],
|
||||||
"buffers": {
|
"buffers": {
|
||||||
name: {
|
name: {
|
||||||
"size": size,
|
"size": size,
|
||||||
"dtype": dtype.name,
|
"dtype": dtype.name,
|
||||||
"id": weight_names[_key] if _key in weight_names else ""
|
"id": weight_names[_key] if _key in weight_names else "",
|
||||||
} for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
|
|
||||||
}
|
}
|
||||||
})
|
for name, (size, dtype, _key) in bufs.items()
|
||||||
|
if name not in ["input", "outputs"]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
|
return (
|
||||||
|
prg,
|
||||||
|
{input: bufs[input][0] for input in input_names},
|
||||||
|
{output: bufs[output][0] for output in output_names},
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
np.set_printoptions(linewidth=160)
|
np.set_printoptions(linewidth=160)
|
||||||
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
|
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
|
||||||
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
|
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
|
||||||
|
@ -11,18 +12,61 @@ from llvmlite import ir # type: ignore
|
||||||
# https://github.com/corsix/amx/blob/main/Instructions.md
|
# https://github.com/corsix/amx/blob/main/Instructions.md
|
||||||
# 12 lines for AMX support
|
# 12 lines for AMX support
|
||||||
from functools import partialmethod
|
from functools import partialmethod
|
||||||
|
|
||||||
|
|
||||||
class AMX:
|
class AMX:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)
|
def nop_op_imm5(op, imm5, builder):
|
||||||
|
builder.asm(
|
||||||
|
ir.FunctionType(ir.VoidType(), []),
|
||||||
|
f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}",
|
||||||
|
"",
|
||||||
|
tuple(),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)
|
def op_gpr(op, builder, gpr):
|
||||||
|
builder.asm(
|
||||||
|
ir.FunctionType(ir.VoidType(), [ir.IntType(64)]),
|
||||||
|
f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0",
|
||||||
|
"r",
|
||||||
|
(gpr,),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
|
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
|
||||||
ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
|
ldx, ldy, stx, sty = (
|
||||||
ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
|
partialmethod(op_gpr, 0),
|
||||||
|
partialmethod(op_gpr, 1),
|
||||||
|
partialmethod(op_gpr, 2),
|
||||||
|
partialmethod(op_gpr, 3),
|
||||||
|
)
|
||||||
|
ldz, stz, ldzi, stzi = (
|
||||||
|
partialmethod(op_gpr, 4),
|
||||||
|
partialmethod(op_gpr, 5),
|
||||||
|
partialmethod(op_gpr, 6),
|
||||||
|
partialmethod(op_gpr, 7),
|
||||||
|
)
|
||||||
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
|
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
|
||||||
fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
|
fma64, fms64, fma32, fms32 = (
|
||||||
mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
|
partialmethod(op_gpr, 10),
|
||||||
vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
|
partialmethod(op_gpr, 11),
|
||||||
|
partialmethod(op_gpr, 12),
|
||||||
|
partialmethod(op_gpr, 13),
|
||||||
|
)
|
||||||
|
mac16, fma16, fms16 = (
|
||||||
|
partialmethod(op_gpr, 14),
|
||||||
|
partialmethod(op_gpr, 15),
|
||||||
|
partialmethod(op_gpr, 16),
|
||||||
|
)
|
||||||
|
vecint, vecfp, matint, matfp, genlut = (
|
||||||
|
partialmethod(op_gpr, 18),
|
||||||
|
partialmethod(op_gpr, 19),
|
||||||
|
partialmethod(op_gpr, 20),
|
||||||
|
partialmethod(op_gpr, 21),
|
||||||
|
partialmethod(op_gpr, 22),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
N = 4096
|
N = 4096
|
||||||
|
@ -54,7 +98,11 @@ c = LLVMBuffer.fromCPU(np.zeros(256))
|
||||||
bufs = [c, a, b]
|
bufs = [c, a, b]
|
||||||
|
|
||||||
module = ir.Module(name=__file__)
|
module = ir.Module(name=__file__)
|
||||||
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
|
func = ir.Function(
|
||||||
|
module,
|
||||||
|
ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()] * 3),
|
||||||
|
name="exec",
|
||||||
|
)
|
||||||
|
|
||||||
# load all
|
# load all
|
||||||
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
|
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
|
||||||
|
@ -69,7 +117,19 @@ y.add_incoming(int_const(0), entry._block)
|
||||||
yp = loop_1_exit.add(y, int_const(32 * 2))
|
yp = loop_1_exit.add(y, int_const(32 * 2))
|
||||||
y.add_incoming(yp, loop_1_exit._block)
|
y.add_incoming(yp, loop_1_exit._block)
|
||||||
|
|
||||||
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
|
prefetch_function = ir.Function(
|
||||||
|
module,
|
||||||
|
ir.FunctionType(
|
||||||
|
ir.VoidType(),
|
||||||
|
[
|
||||||
|
ir.PointerType(ir.FloatType()),
|
||||||
|
ir.IntType(32),
|
||||||
|
ir.IntType(32),
|
||||||
|
ir.IntType(32),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
name="llvm.prefetch",
|
||||||
|
)
|
||||||
|
|
||||||
xptr = y
|
xptr = y
|
||||||
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||||
|
@ -79,7 +139,12 @@ addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||||
|
|
||||||
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr))
|
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr))
|
||||||
xptr = loop_1_exit.add(xptr, int_const(32))
|
xptr = loop_1_exit.add(xptr, int_const(32))
|
||||||
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
|
AMX.ldy(
|
||||||
|
loop_1_exit,
|
||||||
|
loop_1_exit.add(
|
||||||
|
int_const(1 << 62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
|
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
|
||||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10))
|
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10))
|
||||||
|
@ -93,7 +158,9 @@ AMX.clr(exit)
|
||||||
|
|
||||||
entry.branch(loop_1._block)
|
entry.branch(loop_1._block)
|
||||||
loop_1.branch(loop_1_exit._block)
|
loop_1.branch(loop_1_exit._block)
|
||||||
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
|
loop_1_exit.cbranch(
|
||||||
|
loop_1_exit.icmp_unsigned("==", yp, int_const(N * N)), exit._block, loop_1._block
|
||||||
|
)
|
||||||
exit.ret(int_const(0))
|
exit.ret(int_const(0))
|
||||||
|
|
||||||
cfunc = LLVM().exec(module, bufs, N**2)
|
cfunc = LLVM().exec(module, bufs, N**2)
|
||||||
|
@ -185,4 +252,3 @@ np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4)
|
||||||
print(cn.astype(np.int64))
|
print(cn.astype(np.int64))
|
||||||
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
|
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
os.environ["CUDA"] = "1"
|
os.environ["CUDA"] = "1"
|
||||||
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
|
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
|
||||||
|
|
||||||
|
@ -21,7 +22,10 @@ c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
|
||||||
FLOPS = N * N * N * 2
|
FLOPS = N * N * N * 2
|
||||||
BW = N * N * 3 * 4
|
BW = N * N * 3 * 4
|
||||||
|
|
||||||
prog = CUDAProgram("wmma_example", compile_cuda(f"""
|
prog = CUDAProgram(
|
||||||
|
"wmma_example",
|
||||||
|
compile_cuda(
|
||||||
|
f"""
|
||||||
#include <mma.h>
|
#include <mma.h>
|
||||||
using namespace nvcuda;
|
using namespace nvcuda;
|
||||||
|
|
||||||
|
@ -88,10 +92,23 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
"""))
|
"""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1]
|
global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1]
|
||||||
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
|
tm = min(
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
[
|
||||||
|
prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)
|
||||||
|
for _ in range(20)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||||
|
)
|
||||||
|
|
||||||
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
|
np.testing.assert_allclose(
|
||||||
|
na.T.astype(np.float32) @ nb.T.astype(np.float32),
|
||||||
|
c.toCPU().reshape((N, N)).T,
|
||||||
|
atol=1e-2,
|
||||||
|
)
|
||||||
|
|
|
@ -15,6 +15,7 @@ from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, d
|
||||||
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
|
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
|
||||||
from tinygrad.ops import LoadOps, ReduceOps
|
from tinygrad.ops import LoadOps, ReduceOps
|
||||||
|
|
||||||
|
|
||||||
def single_kernel():
|
def single_kernel():
|
||||||
# single kernel
|
# single kernel
|
||||||
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
|
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
|
||||||
|
@ -22,11 +23,17 @@ def single_kernel():
|
||||||
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
|
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
|
||||||
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
|
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
|
||||||
|
|
||||||
old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read())
|
old = CLProgram(
|
||||||
old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)]
|
"r_32_16_16_64_4_4_4",
|
||||||
|
open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read(),
|
||||||
|
)
|
||||||
|
old_tms = [
|
||||||
|
old([1, 1, 32], [16, 16, 1], out, x, w, wait=True) * 1e6 for _ in range(5)
|
||||||
|
]
|
||||||
print(old_tms, 67.107 / min(old_tms) * 1e3)
|
print(old_tms, 67.107 / min(old_tms) * 1e3)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
|
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# single_kernel()
|
# single_kernel()
|
||||||
|
@ -43,7 +50,11 @@ if __name__ == "__main__":
|
||||||
out = x.sequential([c1, c2, c3, c4, c5])
|
out = x.sequential([c1, c2, c3, c4, c5])
|
||||||
schedule = out.lazydata.schedule()
|
schedule = out.lazydata.schedule()
|
||||||
|
|
||||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.get_lazyops()))
|
schedule, schedule_input = partition(
|
||||||
|
schedule,
|
||||||
|
lambda x: x.ast.op not in LoadOps
|
||||||
|
and any(y.op in ReduceOps for y in x.ast.get_lazyops()),
|
||||||
|
)
|
||||||
run_schedule(schedule_input)
|
run_schedule(schedule_input)
|
||||||
run_schedule(schedule[: getenv("CONV")])
|
run_schedule(schedule[: getenv("CONV")])
|
||||||
print("*** init done ***")
|
print("*** init done ***")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# os.environ['OMP_NUM_THREADS'] = '1'
|
# os.environ['OMP_NUM_THREADS'] = '1'
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
|
@ -78,5 +78,3 @@ if __name__ == "__main__":
|
||||||
new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True))
|
new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True))
|
||||||
|
|
||||||
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
|
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,12 +30,21 @@ a = hipallocator.alloc(N*N*4)
|
||||||
b = hipallocator.alloc(N * N * 2)
|
b = hipallocator.alloc(N * N * 2)
|
||||||
c = hipallocator.alloc(N * N * 2)
|
c = hipallocator.alloc(N * N * 2)
|
||||||
na = np.empty(N * N, np.float32)
|
na = np.empty(N * N, np.float32)
|
||||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
nb = (
|
||||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
np.random.default_rng()
|
||||||
|
.standard_normal(size=(N, N), dtype=np.float32)
|
||||||
|
.astype(np.float16)
|
||||||
|
)
|
||||||
|
nc = (
|
||||||
|
np.random.default_rng()
|
||||||
|
.standard_normal(size=(N, N), dtype=np.float32)
|
||||||
|
.astype(np.float16)
|
||||||
|
)
|
||||||
hipallocator.copyin(b, bytearray(nb))
|
hipallocator.copyin(b, bytearray(nb))
|
||||||
hipallocator.copyin(c, bytearray(nc))
|
hipallocator.copyin(c, bytearray(nc))
|
||||||
|
|
||||||
lib = compile_hip(f"""
|
lib = compile_hip(
|
||||||
|
f"""
|
||||||
#define F32
|
#define F32
|
||||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||||
|
@ -92,10 +101,12 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a,
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
}}""")
|
}}"""
|
||||||
|
)
|
||||||
|
|
||||||
prog = HIPProgram(device, "test", lib)
|
prog = HIPProgram(device, "test", lib)
|
||||||
|
|
||||||
|
|
||||||
def timeit(fxn):
|
def timeit(fxn):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
et = fxn()
|
et = fxn()
|
||||||
|
@ -103,11 +114,28 @@ def timeit(fxn):
|
||||||
# print(f"{ret*1e6:.2f} us")
|
# print(f"{ret*1e6:.2f} us")
|
||||||
return et
|
return et
|
||||||
|
|
||||||
|
|
||||||
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
|
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
|
||||||
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
|
print(
|
||||||
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)])
|
"global/local size",
|
||||||
|
global_size,
|
||||||
|
local_size,
|
||||||
|
f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}",
|
||||||
|
)
|
||||||
|
tm = min(
|
||||||
|
[
|
||||||
|
timeit(
|
||||||
|
lambda: prog(
|
||||||
|
a, b, c, global_size=global_size, local_size=local_size, wait=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(1000)
|
||||||
|
]
|
||||||
|
)
|
||||||
hipallocator.copyout(flat_mv(na.data), a)
|
hipallocator.copyout(flat_mv(na.data), a)
|
||||||
na = na.reshape(N, N)
|
na = na.reshape(N, N)
|
||||||
comp = nb.astype(np.float32) @ nc.astype(np.float32)
|
comp = nb.astype(np.float32) @ nc.astype(np.float32)
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||||
|
)
|
||||||
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)
|
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)
|
||||||
|
|
|
@ -14,7 +14,12 @@ A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
|
||||||
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
|
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
|
||||||
|
|
||||||
OPS = DEVICES * BS * N * N * N * 2
|
OPS = DEVICES * BS * N * N * N * 2
|
||||||
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
|
|
||||||
|
|
||||||
|
def matmul(A, B):
|
||||||
|
return jnp.matmul(A, B, preferred_element_type=jnp.float32)
|
||||||
|
|
||||||
|
|
||||||
pmatmul = jax.pmap(matmul)
|
pmatmul = jax.pmap(matmul)
|
||||||
|
|
||||||
MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
||||||
|
@ -23,5 +28,6 @@ for i in range(10):
|
||||||
C = pmatmul(A, B).block_until_ready()
|
C = pmatmul(A, B).block_until_ready()
|
||||||
et = time.perf_counter() - st
|
et = time.perf_counter() - st
|
||||||
tflops = (OPS * 1e-12) / et
|
tflops = (OPS * 1e-12) / et
|
||||||
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
|
print(
|
||||||
|
f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}"
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# os.environ["METAL"] = "1"
|
# os.environ["METAL"] = "1"
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -18,14 +19,16 @@ nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import time, torch, torch.mps
|
import time, torch, torch.mps
|
||||||
b = torch.from_numpy(nb).to('mps')
|
|
||||||
c = torch.from_numpy(nc).to('mps')
|
b = torch.from_numpy(nb).to("mps")
|
||||||
|
c = torch.from_numpy(nc).to("mps")
|
||||||
|
|
||||||
def torch_prog(b, c):
|
def torch_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
|
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
|
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
@ -34,16 +37,23 @@ except RuntimeError:
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
|
|
||||||
b = Tensor(nb)
|
b = Tensor(nb)
|
||||||
c = Tensor(nc)
|
c = Tensor(nc)
|
||||||
|
|
||||||
|
|
||||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def tiny_jit(b, c):
|
def tiny_jit(b, c):
|
||||||
return b.conv2d(c, padding=PADDING).realize()
|
return b.conv2d(c, padding=PADDING).realize()
|
||||||
|
|
||||||
|
|
||||||
def tiny_prog(b, c):
|
def tiny_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = tiny_jit(b, c)
|
a = tiny_jit(b, c)
|
||||||
Device[a.device].synchronize()
|
Device[a.device].synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([tiny_prog(b, c) for _ in range(5)])
|
tm = min([tiny_prog(b, c) for _ in range(5)])
|
||||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
|
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["METAL"] = "1"
|
os.environ["METAL"] = "1"
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -10,15 +11,22 @@ LID = 2
|
||||||
|
|
||||||
a = RawMetalBuffer(N * N, dtypes.float32)
|
a = RawMetalBuffer(N * N, dtypes.float32)
|
||||||
|
|
||||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
nb = np.random.default_rng().standard_normal(
|
||||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
size=(N, N), dtype=np.float32
|
||||||
|
) # .astype(np.int32).astype(np.float32)
|
||||||
|
nc = np.random.default_rng().standard_normal(
|
||||||
|
size=(N, N), dtype=np.float32
|
||||||
|
) # .astype(np.int32).astype(np.float32)
|
||||||
b = RawMetalBuffer.fromCPU(nb)
|
b = RawMetalBuffer.fromCPU(nb)
|
||||||
c = RawMetalBuffer.fromCPU(nc)
|
c = RawMetalBuffer.fromCPU(nc)
|
||||||
|
|
||||||
FLOPS = N * N * N * 2
|
FLOPS = N * N * N * 2
|
||||||
BW = N * N * 3 * 4
|
BW = N * N * 3 * 4
|
||||||
|
|
||||||
prog = MetalProgram("test", compile_metal(f"""
|
prog = MetalProgram(
|
||||||
|
"test",
|
||||||
|
compile_metal(
|
||||||
|
f"""
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
|
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
@ -80,46 +88,83 @@ kernel void test(device float *a, device const float *data1, device const float
|
||||||
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
|
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
|
||||||
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
|
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
|
||||||
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
|
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
|
||||||
}}"""))
|
}}"""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def timeit(fxn):
|
def timeit(fxn):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
et = fxn()
|
et = fxn()
|
||||||
# NOTE: et doesn't contain the launch overhead
|
# NOTE: et doesn't contain the launch overhead
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
|
|
||||||
|
|
||||||
|
tm = min(
|
||||||
|
[
|
||||||
|
timeit(
|
||||||
|
lambda: prog(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
c,
|
||||||
|
global_size=[N // (8 * 4), N // (8 * 4 * LID), 1],
|
||||||
|
local_size=[32, LID, 1],
|
||||||
|
wait=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(20)
|
||||||
|
]
|
||||||
|
)
|
||||||
na = a.toCPU().reshape(N, N)
|
na = a.toCPU().reshape(N, N)
|
||||||
comp = nb @ nc
|
comp = nb @ nc
|
||||||
if N <= 32:
|
if N <= 32:
|
||||||
print(na)
|
print(na)
|
||||||
print(comp)
|
print(comp)
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||||
|
)
|
||||||
np.testing.assert_allclose(na, comp, atol=1e-3)
|
np.testing.assert_allclose(na, comp, atol=1e-3)
|
||||||
|
|
||||||
import torch, torch.mps
|
import torch, torch.mps
|
||||||
b = torch.from_numpy(nb).to('mps')
|
|
||||||
c = torch.from_numpy(nc).to('mps')
|
b = torch.from_numpy(nb).to("mps")
|
||||||
|
c = torch.from_numpy(nc).to("mps")
|
||||||
|
|
||||||
|
|
||||||
def torch_prog(b, c):
|
def torch_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = b @ c
|
a = b @ c
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch"
|
||||||
|
)
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
from tinygrad.runtime.ops_metal import METAL
|
from tinygrad.runtime.ops_metal import METAL
|
||||||
|
|
||||||
b = Tensor(nb)
|
b = Tensor(nb)
|
||||||
c = Tensor(nc)
|
c = Tensor(nc)
|
||||||
|
|
||||||
|
|
||||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def tiny_jit(b, c):
|
def tiny_jit(b, c):
|
||||||
return (b @ c).realize()
|
return (b @ c).realize()
|
||||||
|
|
||||||
|
|
||||||
def tiny_prog(b, c):
|
def tiny_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = tiny_jit(b, c)
|
a = tiny_jit(b, c)
|
||||||
METAL.synchronize()
|
METAL.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([tiny_prog(b, c) for _ in range(20)])
|
tm = min([tiny_prog(b, c) for _ in range(20)])
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad"
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# os.environ["METAL"] = "1"
|
# os.environ["METAL"] = "1"
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time, torch, torch.mps
|
import time, torch, torch.mps
|
||||||
|
@ -10,6 +11,7 @@ from tinygrad import Device
|
||||||
from tinygrad.helpers import colored, getenv, CI
|
from tinygrad.helpers import colored, getenv, CI
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["METAL"] = "1"
|
os.environ["METAL"] = "1"
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -20,27 +22,38 @@ N = 16384
|
||||||
M = 4096
|
M = 4096
|
||||||
FLOPS = N * M * 2
|
FLOPS = N * M * 2
|
||||||
|
|
||||||
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
nb = np.random.default_rng().standard_normal(
|
||||||
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
size=(N), dtype=np.float32
|
||||||
|
) # .astype(np.int32).astype(np.float32)
|
||||||
|
nc = np.random.default_rng().standard_normal(
|
||||||
|
size=(N, M), dtype=np.float32
|
||||||
|
) # .astype(np.int32).astype(np.float32)
|
||||||
|
|
||||||
import torch, torch.mps
|
import torch, torch.mps
|
||||||
b = torch.from_numpy(nb).to('mps')
|
|
||||||
c = torch.from_numpy(nc).to('mps')
|
b = torch.from_numpy(nb).to("mps")
|
||||||
|
c = torch.from_numpy(nc).to("mps")
|
||||||
|
|
||||||
|
|
||||||
def torch_prog(b, c):
|
def torch_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = b @ c
|
a = b @ c
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([torch_prog(b, c) for _ in range(200)])
|
tm = min([torch_prog(b, c) for _ in range(200)])
|
||||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
|
print(
|
||||||
|
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch"
|
||||||
|
)
|
||||||
torch_a = (b @ c).cpu()
|
torch_a = (b @ c).cpu()
|
||||||
|
|
||||||
WORKSIZE_ROW = 16
|
WORKSIZE_ROW = 16
|
||||||
WORKSIZE_COL = 1
|
WORKSIZE_COL = 1
|
||||||
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
|
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
|
||||||
GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1]
|
GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1]
|
||||||
prog = compile_metal(f"""
|
prog = compile_metal(
|
||||||
|
f"""
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
|
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
|
||||||
|
@ -86,41 +99,59 @@ kernel void test(device float* data0, const device float* data1, const device fl
|
||||||
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
|
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
prog = MetalProgram("test", prog)
|
prog = MetalProgram("test", prog)
|
||||||
# print(prog_string)
|
# print(prog_string)
|
||||||
na = np.zeros(M, dtype=np.float32)
|
na = np.zeros(M, dtype=np.float32)
|
||||||
b = RawMetalBuffer.fromCPU(nb)
|
b = RawMetalBuffer.fromCPU(nb)
|
||||||
c = RawMetalBuffer.fromCPU(nc)
|
c = RawMetalBuffer.fromCPU(nc)
|
||||||
|
|
||||||
|
|
||||||
def metalrun():
|
def metalrun():
|
||||||
a = RawMetalBuffer.fromCPU(na)
|
a = RawMetalBuffer.fromCPU(na)
|
||||||
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
def timeit(fxn):
|
def timeit(fxn):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
et = fxn()
|
et = fxn()
|
||||||
# NOTE: et doesn't contain the launch overhead
|
# NOTE: et doesn't contain the launch overhead
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([timeit(metalrun) for _ in range(200)])
|
tm = min([timeit(metalrun) for _ in range(200)])
|
||||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
|
print(
|
||||||
|
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal"
|
||||||
|
)
|
||||||
metal_a = metalrun().toCPU().reshape(M)
|
metal_a = metalrun().toCPU().reshape(M)
|
||||||
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
|
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
|
||||||
|
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.jit import TinyJit
|
from tinygrad.jit import TinyJit
|
||||||
from tinygrad.runtime.ops_metal import METAL
|
from tinygrad.runtime.ops_metal import METAL
|
||||||
|
|
||||||
b = Tensor(nb)
|
b = Tensor(nb)
|
||||||
c = Tensor(nc)
|
c = Tensor(nc)
|
||||||
|
|
||||||
|
|
||||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def tiny_jit(b, c):
|
def tiny_jit(b, c):
|
||||||
return (b @ c).realize()
|
return (b @ c).realize()
|
||||||
|
|
||||||
|
|
||||||
def tiny_prog(b, c):
|
def tiny_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = tiny_jit(b, c)
|
a = tiny_jit(b, c)
|
||||||
METAL.synchronize()
|
METAL.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
|
|
||||||
tm = min([tiny_prog(b, c) for _ in range(200)])
|
tm = min([tiny_prog(b, c) for _ in range(200)])
|
||||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
|
print(
|
||||||
|
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad"
|
||||||
|
)
|
||||||
tiny_a = tiny_jit(b, c).numpy()
|
tiny_a = tiny_jit(b, c).numpy()
|
||||||
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
|
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
|
|
@ -2,14 +2,28 @@ import numpy as np
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes
|
||||||
|
|
||||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
||||||
N = getenv("N", 4096)
|
N = getenv("N", 4096)
|
||||||
CNT = getenv("CNT", 10)
|
CNT = getenv("CNT", 10)
|
||||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
a, b = (
|
||||||
|
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||||
|
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||||
|
)
|
||||||
for i in range(CNT):
|
for i in range(CNT):
|
||||||
if i > 0 and getenv("RAND", 0) != 0:
|
if i > 0 and getenv("RAND", 0) != 0:
|
||||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
a, b = (
|
||||||
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
|
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||||
|
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||||
|
)
|
||||||
|
c = (
|
||||||
|
(a.reshape(N, 1, N) * b.permute(1, 0).reshape(1, N, N))
|
||||||
|
.float()
|
||||||
|
.sum(axis=2)
|
||||||
|
.realize()
|
||||||
|
if getenv("ACCUM_FP32")
|
||||||
|
else (a @ b).realize()
|
||||||
|
)
|
||||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||||
nc = c.numpy()
|
nc = c.numpy()
|
||||||
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
|
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
import time
|
import time
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
gpus = tf.config.list_physical_devices('GPU')
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
if gpus:
|
if gpus:
|
||||||
try:
|
try:
|
||||||
# Currently, memory growth needs to be the same across GPUs
|
# Currently, memory growth needs to be the same across GPUs
|
||||||
for gpu in gpus:
|
for gpu in gpus:
|
||||||
tf.config.experimental.set_memory_growth(gpu, True)
|
tf.config.experimental.set_memory_growth(gpu, True)
|
||||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||||
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
# Memory growth must be set before GPUs have been initialized
|
# Memory growth must be set before GPUs have been initialized
|
||||||
|
@ -26,8 +26,12 @@ for dtype in [tf.float16, tf.float32]:
|
||||||
def tf_prog(b, c):
|
def tf_prog(b, c):
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
a = tf.matmul(b, c)
|
a = tf.matmul(b, c)
|
||||||
tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done.
|
tf.debugging.check_numerics(
|
||||||
|
a, "Nan or Inf in result"
|
||||||
|
) # Ensures that the calculation is done.
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
tm = min([tf_prog(b, c) for _ in range(20)])
|
tm = min([tf_prog(b, c) for _ in range(20)])
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||||
|
)
|
||||||
|
|
|
@ -13,5 +13,8 @@ for dtype in [torch.float16, torch.float32]:
|
||||||
a = b @ c
|
a = b @ c
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return time.perf_counter() - st
|
return time.perf_counter() - st
|
||||||
|
|
||||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
print(
|
||||||
|
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||||
|
)
|
||||||
|
|
|
@ -5,6 +5,7 @@ M, N, K = 1024, 1024, 1024
|
||||||
try:
|
try:
|
||||||
import tvm
|
import tvm
|
||||||
from tvm import te
|
from tvm import te
|
||||||
|
|
||||||
# print(tvm.target.Target.list_kinds())
|
# print(tvm.target.Target.list_kinds())
|
||||||
|
|
||||||
# c, opencl
|
# c, opencl
|
||||||
|
@ -39,9 +40,13 @@ C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||||
sched = C.lazydata.schedule()
|
sched = C.lazydata.schedule()
|
||||||
from tinygrad.codegen.linearizer import Linearizer
|
from tinygrad.codegen.linearizer import Linearizer
|
||||||
from tinygrad.codegen.kernel import LinearizerOptions
|
from tinygrad.codegen.kernel import LinearizerOptions
|
||||||
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
|
|
||||||
|
lin = Linearizer(
|
||||||
|
sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False)
|
||||||
|
)
|
||||||
# lin.hand_coded_optimizations()
|
# lin.hand_coded_optimizations()
|
||||||
lin.linearize()
|
lin.linearize()
|
||||||
from tinygrad.runtime.ops_clang import renderer
|
from tinygrad.runtime.ops_clang import renderer
|
||||||
|
|
||||||
src = renderer("mmult", lin.uops)
|
src = renderer("mmult", lin.uops)
|
||||||
print(src)
|
print(src)
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
def mask_like(like, mask_inx, mask_value=1.0):
|
def mask_like(like, mask_inx, mask_value=1.0):
|
||||||
mask = np.zeros_like(like).reshape(-1)
|
mask = np.zeros_like(like).reshape(-1)
|
||||||
mask[mask_inx] = mask_value
|
mask[mask_inx] = mask_value
|
||||||
return mask.reshape(like.shape)
|
return mask.reshape(like.shape)
|
||||||
|
|
||||||
|
|
||||||
def jacobian(func, input):
|
def jacobian(func, input):
|
||||||
output = func(input)
|
output = func(input)
|
||||||
|
|
||||||
|
@ -19,13 +21,14 @@ def jacobian(func, input):
|
||||||
|
|
||||||
# tinygrad doesn't support slicing, tiny-hack to select
|
# tinygrad doesn't support slicing, tiny-hack to select
|
||||||
# the needed scalar an backpropagate only through it
|
# the needed scalar an backpropagate only through it
|
||||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
|
o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum()
|
||||||
o_scalar.backward()
|
o_scalar.backward()
|
||||||
|
|
||||||
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
||||||
J[o, i] = grad
|
J[o, i] = grad
|
||||||
return J
|
return J
|
||||||
|
|
||||||
|
|
||||||
def numerical_jacobian(func, input, eps=1e-3):
|
def numerical_jacobian(func, input, eps=1e-3):
|
||||||
output = func(input)
|
output = func(input)
|
||||||
|
|
||||||
|
@ -36,14 +39,19 @@ def numerical_jacobian(func, input, eps = 1e-3):
|
||||||
for i in range(ji):
|
for i in range(ji):
|
||||||
eps_perturb = mask_like(input.numpy(), i, mask_value=eps)
|
eps_perturb = mask_like(input.numpy(), i, mask_value=eps)
|
||||||
|
|
||||||
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
output_perturb_add = (
|
||||||
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||||
|
)
|
||||||
|
output_perturb_sub = (
|
||||||
|
func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||||
|
)
|
||||||
|
|
||||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps)
|
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps)
|
||||||
|
|
||||||
NJ[:, i] = grad_approx
|
NJ[:, i] = grad_approx
|
||||||
return NJ
|
return NJ
|
||||||
|
|
||||||
|
|
||||||
def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3):
|
def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3):
|
||||||
NJ = numerical_jacobian(func, input, eps)
|
NJ = numerical_jacobian(func, input, eps)
|
||||||
J = jacobian(func, input)
|
J = jacobian(func, input)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import multiprocessing, subprocess
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def _early_exec_process(qin, qout):
|
def _early_exec_process(qin, qout):
|
||||||
while True:
|
while True:
|
||||||
path, inp = qin.get()
|
path, inp = qin.get()
|
||||||
|
@ -10,41 +11,62 @@ def _early_exec_process(qin, qout):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
qout.put(e)
|
qout.put(e)
|
||||||
|
|
||||||
|
|
||||||
def enable_early_exec():
|
def enable_early_exec():
|
||||||
qin: multiprocessing.Queue = multiprocessing.Queue()
|
qin: multiprocessing.Queue = multiprocessing.Queue()
|
||||||
qout: multiprocessing.Queue = multiprocessing.Queue()
|
qout: multiprocessing.Queue = multiprocessing.Queue()
|
||||||
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
||||||
p.daemon = True
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
|
||||||
def early_exec(x):
|
def early_exec(x):
|
||||||
qin.put(x)
|
qin.put(x)
|
||||||
ret = qout.get()
|
ret = qout.get()
|
||||||
if isinstance(ret, Exception): raise ret
|
if isinstance(ret, Exception):
|
||||||
else: return ret
|
raise ret
|
||||||
|
else:
|
||||||
|
return ret
|
||||||
|
|
||||||
return early_exec
|
return early_exec
|
||||||
|
|
||||||
|
|
||||||
def proc(itermaker, q) -> None:
|
def proc(itermaker, q) -> None:
|
||||||
try:
|
try:
|
||||||
for x in itermaker(): q.put(x)
|
for x in itermaker():
|
||||||
|
q.put(x)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
q.put(e)
|
q.put(e)
|
||||||
finally:
|
finally:
|
||||||
q.put(None)
|
q.put(None)
|
||||||
q.close()
|
q.close()
|
||||||
|
|
||||||
|
|
||||||
class _CloudpickleFunctionWrapper:
|
class _CloudpickleFunctionWrapper:
|
||||||
def __init__(self, fn): self.fn = fn
|
def __init__(self, fn):
|
||||||
def __getstate__(self): return cloudpickle.dumps(self.fn)
|
self.fn = fn
|
||||||
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
|
|
||||||
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
|
def __getstate__(self):
|
||||||
|
return cloudpickle.dumps(self.fn)
|
||||||
|
|
||||||
|
def __setstate__(self, pfn):
|
||||||
|
self.fn = cloudpickle.loads(pfn)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
return self.fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def cross_process(itermaker, maxsize=16):
|
def cross_process(itermaker, maxsize=16):
|
||||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
p = multiprocessing.Process(
|
||||||
|
target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q)
|
||||||
|
)
|
||||||
p.start()
|
p.start()
|
||||||
while True:
|
while True:
|
||||||
ret = q.get()
|
ret = q.get()
|
||||||
if isinstance(ret, Exception): raise ret
|
if isinstance(ret, Exception):
|
||||||
elif ret is None: break
|
raise ret
|
||||||
else: yield ret
|
elif ret is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
yield ret
|
||||||
|
|
|
@ -6,6 +6,7 @@ from tinygrad.lazy import LazyBuffer
|
||||||
from tinygrad.runtime.ops_gpu import CLBuffer
|
from tinygrad.runtime.ops_gpu import CLBuffer
|
||||||
from tinygrad.helpers import GlobalCounters
|
from tinygrad.helpers import GlobalCounters
|
||||||
|
|
||||||
|
|
||||||
def print_objects():
|
def print_objects():
|
||||||
# gc.collect()
|
# gc.collect()
|
||||||
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
|
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
|
||||||
|
@ -15,8 +16,12 @@ def print_objects():
|
||||||
realized_buffers = [x.realized for x in lazybuffers if x.realized]
|
realized_buffers = [x.realized for x in lazybuffers if x.realized]
|
||||||
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
|
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
|
||||||
|
|
||||||
print(f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB")
|
print(
|
||||||
print(f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers")
|
f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers"
|
||||||
|
)
|
||||||
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
|
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
|
||||||
|
|
||||||
cnt = 0
|
cnt = 0
|
||||||
|
@ -33,11 +38,14 @@ def print_objects():
|
||||||
cnt += 1
|
cnt += 1
|
||||||
|
|
||||||
for x in gpubuffers_orphaned:
|
for x in gpubuffers_orphaned:
|
||||||
if getattr(x, '_buf', None): del x._buf
|
if getattr(x, "_buf", None):
|
||||||
if getattr(x, '_image', None): del x._image
|
del x._buf
|
||||||
|
if getattr(x, "_image", None):
|
||||||
|
del x._image
|
||||||
|
|
||||||
return len(gpubuffers_orphaned)
|
return len(gpubuffers_orphaned)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
|
|
@ -7,39 +7,44 @@ from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
from google.protobuf.internal import builder as _builder
|
from google.protobuf.internal import builder as _builder
|
||||||
|
|
||||||
# @@protoc_insertion_point(imports)
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||||
|
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
|
)
|
||||||
|
|
||||||
_globals = globals()
|
_globals = globals()
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
|
||||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||||
_globals['DESCRIPTOR']._options = None
|
_globals["DESCRIPTOR"]._options = None
|
||||||
_globals['DESCRIPTOR']._serialized_options = b'H\003'
|
_globals["DESCRIPTOR"]._serialized_options = b"H\003"
|
||||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
|
_globals["_TRAINERSPEC"].fields_by_name["mining_sentence_size"]._options = None
|
||||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
|
_globals["_TRAINERSPEC"].fields_by_name[
|
||||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
|
"mining_sentence_size"
|
||||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
|
]._serialized_options = b"\030\001"
|
||||||
_globals['_TRAINERSPEC']._serialized_start=45
|
_globals["_TRAINERSPEC"].fields_by_name["training_sentence_size"]._options = None
|
||||||
_globals['_TRAINERSPEC']._serialized_end=1581
|
_globals["_TRAINERSPEC"].fields_by_name[
|
||||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
|
"training_sentence_size"
|
||||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
|
]._serialized_options = b"\030\001"
|
||||||
_globals['_NORMALIZERSPEC']._serialized_start=1584
|
_globals["_TRAINERSPEC"]._serialized_start = 45
|
||||||
_globals['_NORMALIZERSPEC']._serialized_end=1793
|
_globals["_TRAINERSPEC"]._serialized_end = 1581
|
||||||
_globals['_SELFTESTDATA']._serialized_start=1795
|
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
|
||||||
_globals['_SELFTESTDATA']._serialized_end=1916
|
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
|
||||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
|
_globals["_NORMALIZERSPEC"]._serialized_start = 1584
|
||||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
|
_globals["_NORMALIZERSPEC"]._serialized_end = 1793
|
||||||
_globals['_MODELPROTO']._serialized_start=1919
|
_globals["_SELFTESTDATA"]._serialized_start = 1795
|
||||||
_globals['_MODELPROTO']._serialized_end=2429
|
_globals["_SELFTESTDATA"]._serialized_end = 1916
|
||||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
|
_globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
|
||||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
|
_globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
|
||||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
|
_globals["_MODELPROTO"]._serialized_start = 1919
|
||||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
|
_globals["_MODELPROTO"]._serialized_end = 2429
|
||||||
|
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
|
||||||
|
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
|
||||||
|
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
|
||||||
|
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
|
||||||
# @@protoc_insertion_point(module_scope)
|
# @@protoc_insertion_point(module_scope)
|
||||||
|
|
|
@ -3,17 +3,22 @@ from typing import List
|
||||||
from tinygrad.nn.optim import Optimizer
|
from tinygrad.nn.optim import Optimizer
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
class LR_Scheduler:
|
class LR_Scheduler:
|
||||||
def __init__(self, optimizer: Optimizer):
|
def __init__(self, optimizer: Optimizer):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
|
self.epoch_counter = Tensor(
|
||||||
|
[0], requires_grad=False, device=self.optimizer.device
|
||||||
|
)
|
||||||
|
|
||||||
def get_lr(self): pass
|
def get_lr(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def step(self) -> None:
|
def step(self) -> None:
|
||||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||||
self.optimizer.lr.assign(self.get_lr()).realize()
|
self.optimizer.lr.assign(self.get_lr()).realize()
|
||||||
|
|
||||||
|
|
||||||
class MultiStepLR(LR_Scheduler):
|
class MultiStepLR(LR_Scheduler):
|
||||||
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
|
@ -25,18 +30,38 @@ class MultiStepLR(LR_Scheduler):
|
||||||
return self.optimizer.lr
|
return self.optimizer.lr
|
||||||
return self.optimizer.lr * self.gamma
|
return self.optimizer.lr * self.gamma
|
||||||
|
|
||||||
|
|
||||||
class ReduceLROnPlateau(LR_Scheduler):
|
class ReduceLROnPlateau(LR_Scheduler):
|
||||||
def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
mode="min",
|
||||||
|
factor=0.1,
|
||||||
|
patience=10,
|
||||||
|
threshold=1e-4,
|
||||||
|
threshold_mode="rel",
|
||||||
|
):
|
||||||
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
|
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode
|
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = (
|
||||||
self.best = float('inf') if mode == "min" else float('-inf')
|
mode,
|
||||||
|
factor,
|
||||||
|
patience,
|
||||||
|
threshold,
|
||||||
|
threshold_mode,
|
||||||
|
)
|
||||||
|
self.best = float("inf") if mode == "min" else float("-inf")
|
||||||
self.bad_epoch = 0
|
self.bad_epoch = 0
|
||||||
|
|
||||||
if mode == "min": self.threshold *= -1
|
if mode == "min":
|
||||||
|
self.threshold *= -1
|
||||||
|
|
||||||
def is_better(self, current: float) -> bool:
|
def is_better(self, current: float) -> bool:
|
||||||
dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold
|
dynamic_threshold = (
|
||||||
|
self.best * (1 + self.threshold)
|
||||||
|
if self.threshold_mode == "rel"
|
||||||
|
else self.best + self.threshold
|
||||||
|
)
|
||||||
if self.mode == "min":
|
if self.mode == "min":
|
||||||
return current < dynamic_threshold
|
return current < dynamic_threshold
|
||||||
return current > dynamic_threshold
|
return current > dynamic_threshold
|
||||||
|
@ -53,6 +78,7 @@ class ReduceLROnPlateau(LR_Scheduler):
|
||||||
self.optimizer.lr *= self.factor
|
self.optimizer.lr *= self.factor
|
||||||
self.bad_epoch = 0
|
self.bad_epoch = 0
|
||||||
|
|
||||||
|
|
||||||
class CosineAnnealingLR(LR_Scheduler):
|
class CosineAnnealingLR(LR_Scheduler):
|
||||||
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
|
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
|
@ -61,26 +87,54 @@ class CosineAnnealingLR(LR_Scheduler):
|
||||||
self.eta_max = optimizer.lr.numpy()[0]
|
self.eta_max = optimizer.lr.numpy()[0]
|
||||||
|
|
||||||
def get_lr(self) -> Tensor:
|
def get_lr(self) -> Tensor:
|
||||||
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
|
return Tensor(
|
||||||
|
[
|
||||||
|
self.eta_min
|
||||||
|
+ 0.5
|
||||||
|
* (self.eta_max - self.eta_min)
|
||||||
|
* (1 + math.cos((self.epoch_counter.numpy()[0] / self.T_max) * math.pi))
|
||||||
|
],
|
||||||
|
device=self.optimizer.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLR(LR_Scheduler):
|
class OneCycleLR(LR_Scheduler):
|
||||||
def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
|
def __init__(
|
||||||
anneal_strategy: str = 'linear', cycle_momentum: bool = False):
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
max_lr: float,
|
||||||
|
div_factor: float,
|
||||||
|
final_div_factor: float,
|
||||||
|
total_steps: int,
|
||||||
|
pct_start: float,
|
||||||
|
anneal_strategy: str = "linear",
|
||||||
|
cycle_momentum: bool = False,
|
||||||
|
):
|
||||||
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
|
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
|
||||||
self.max_lr = Tensor([max_lr]).contiguous()
|
self.max_lr = Tensor([max_lr]).contiguous()
|
||||||
self.min_lr = self.initial_lr / final_div_factor
|
self.min_lr = self.initial_lr / final_div_factor
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
self.total_steps = total_steps
|
self.total_steps = total_steps
|
||||||
self.pct_start = pct_start
|
self.pct_start = pct_start
|
||||||
assert anneal_strategy == 'linear', 'only linear annealing supported'
|
assert anneal_strategy == "linear", "only linear annealing supported"
|
||||||
assert not cycle_momentum, 'cycle momentum not supported'
|
assert not cycle_momentum, "cycle momentum not supported"
|
||||||
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
|
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: return ((end - start) * pct + start)
|
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor:
|
||||||
|
return (end - start) * pct + start
|
||||||
|
|
||||||
def get_lr(self) -> Tensor:
|
def get_lr(self) -> Tensor:
|
||||||
return (self.epoch_counter < self.total_steps * self.pct_start).where(
|
return (self.epoch_counter < self.total_steps * self.pct_start).where(
|
||||||
self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
|
self._annealing_linear(
|
||||||
self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
|
self.initial_lr,
|
||||||
|
self.max_lr,
|
||||||
|
self.epoch_counter / (self.total_steps * self.pct_start),
|
||||||
|
),
|
||||||
|
self._annealing_linear(
|
||||||
|
self.max_lr,
|
||||||
|
self.min_lr,
|
||||||
|
(self.epoch_counter - (self.total_steps * self.pct_start))
|
||||||
|
/ (self.total_steps * (1 - self.pct_start)),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,8 +5,29 @@ from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class BertForQuestionAnswering:
|
class BertForQuestionAnswering:
|
||||||
def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
|
def __init__(
|
||||||
self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
|
self,
|
||||||
|
hidden_size=1024,
|
||||||
|
intermediate_size=4096,
|
||||||
|
max_position_embeddings=512,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
type_vocab_size=2,
|
||||||
|
vocab_size=30522,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
hidden_dropout_prob=0.1,
|
||||||
|
):
|
||||||
|
self.bert = Bert(
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
max_position_embeddings,
|
||||||
|
num_attention_heads,
|
||||||
|
num_hidden_layers,
|
||||||
|
type_vocab_size,
|
||||||
|
vocab_size,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
)
|
||||||
self.qa_outputs = Linear(hidden_size, 2)
|
self.qa_outputs = Linear(hidden_size, 2)
|
||||||
|
|
||||||
def load_from_pretrained(self):
|
def load_from_pretrained(self):
|
||||||
|
@ -16,15 +37,20 @@ class BertForQuestionAnswering:
|
||||||
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
with open(fn, "rb") as f:
|
with open(fn, "rb") as f:
|
||||||
state_dict = torch.load(f, map_location="cpu")
|
state_dict = torch.load(f, map_location="cpu")
|
||||||
|
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if "dropout" in k: continue # skip dropout
|
if "dropout" in k:
|
||||||
if "pooler" in k: continue # skip pooler
|
continue # skip dropout
|
||||||
|
if "pooler" in k:
|
||||||
|
continue # skip pooler
|
||||||
get_child(self, k).assign(v.numpy()).realize()
|
get_child(self, k).assign(v.numpy()).realize()
|
||||||
|
|
||||||
def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
|
def __call__(
|
||||||
|
self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor
|
||||||
|
):
|
||||||
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
|
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
|
||||||
logits = self.qa_outputs(sequence_output)
|
logits = self.qa_outputs(sequence_output)
|
||||||
start_logits, end_logits = logits.chunk(2, dim=-1)
|
start_logits, end_logits = logits.chunk(2, dim=-1)
|
||||||
|
@ -33,10 +59,35 @@ class BertForQuestionAnswering:
|
||||||
|
|
||||||
return Tensor.stack([start_logits, end_logits])
|
return Tensor.stack([start_logits, end_logits])
|
||||||
|
|
||||||
|
|
||||||
class Bert:
|
class Bert:
|
||||||
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
|
def __init__(
|
||||||
self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
|
self,
|
||||||
self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
max_position_embeddings,
|
||||||
|
num_attention_heads,
|
||||||
|
num_hidden_layers,
|
||||||
|
type_vocab_size,
|
||||||
|
vocab_size,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
):
|
||||||
|
self.embeddings = BertEmbeddings(
|
||||||
|
hidden_size,
|
||||||
|
max_position_embeddings,
|
||||||
|
type_vocab_size,
|
||||||
|
vocab_size,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
)
|
||||||
|
self.encoder = BertEncoder(
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
num_attention_heads,
|
||||||
|
num_hidden_layers,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids):
|
def __call__(self, input_ids, attention_mask, token_type_ids):
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
@ -47,8 +98,16 @@ class Bert:
|
||||||
|
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
|
|
||||||
class BertEmbeddings:
|
class BertEmbeddings:
|
||||||
def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size,
|
||||||
|
max_position_embeddings,
|
||||||
|
type_vocab_size,
|
||||||
|
vocab_size,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
):
|
||||||
self.word_embeddings = Embedding(vocab_size, hidden_size)
|
self.word_embeddings = Embedding(vocab_size, hidden_size)
|
||||||
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
|
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
|
||||||
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
|
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
|
||||||
|
@ -59,7 +118,11 @@ class BertEmbeddings:
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
|
position_ids = (
|
||||||
|
Tensor.arange(seq_length, requires_grad=False)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(*input_shape)
|
||||||
|
)
|
||||||
words_embeddings = self.word_embeddings(input_ids)
|
words_embeddings = self.word_embeddings(input_ids)
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
@ -69,18 +132,49 @@ class BertEmbeddings:
|
||||||
embeddings = embeddings.dropout(self.dropout)
|
embeddings = embeddings.dropout(self.dropout)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder:
|
class BertEncoder:
|
||||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
|
def __init__(
|
||||||
self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]
|
self,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
num_attention_heads,
|
||||||
|
num_hidden_layers,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
):
|
||||||
|
self.layer = [
|
||||||
|
BertLayer(
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
)
|
||||||
|
for _ in range(num_hidden_layers)
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask):
|
def __call__(self, hidden_states, attention_mask):
|
||||||
for layer in self.layer:
|
for layer in self.layer:
|
||||||
hidden_states = layer(hidden_states, attention_mask)
|
hidden_states = layer(hidden_states, attention_mask)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class BertLayer:
|
class BertLayer:
|
||||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
def __init__(
|
||||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
|
self,
|
||||||
|
hidden_size,
|
||||||
|
intermediate_size,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
):
|
||||||
|
self.attention = BertAttention(
|
||||||
|
hidden_size,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
)
|
||||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
|
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
|
||||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
|
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
|
||||||
|
|
||||||
|
@ -90,6 +184,7 @@ class BertLayer:
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
layer_output = self.output(intermediate_output, attention_output)
|
||||||
return layer_output
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
class BertOutput:
|
class BertOutput:
|
||||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
|
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
|
||||||
self.dense = Linear(intermediate_size, hidden_size)
|
self.dense = Linear(intermediate_size, hidden_size)
|
||||||
|
@ -102,10 +197,21 @@ class BertOutput:
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# approximation of the error function
|
# approximation of the error function
|
||||||
def erf(x):
|
def erf(x):
|
||||||
t = (1 + 0.3275911 * x.abs()).reciprocal()
|
t = (1 + 0.3275911 * x.abs()).reciprocal()
|
||||||
return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
|
return x.sign() * (
|
||||||
|
1
|
||||||
|
- (
|
||||||
|
(((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736)
|
||||||
|
* t
|
||||||
|
+ 0.254829592
|
||||||
|
)
|
||||||
|
* t
|
||||||
|
* (-(x.square())).exp()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BertIntermediate:
|
class BertIntermediate:
|
||||||
def __init__(self, hidden_size, intermediate_size):
|
def __init__(self, hidden_size, intermediate_size):
|
||||||
|
@ -116,9 +222,18 @@ class BertIntermediate:
|
||||||
# tinygrad gelu is openai gelu but we need the original bert gelu
|
# tinygrad gelu is openai gelu but we need the original bert gelu
|
||||||
return x * 0.5 * (1.0 + erf(x / 1.41421))
|
return x * 0.5 * (1.0 + erf(x / 1.41421))
|
||||||
|
|
||||||
|
|
||||||
class BertAttention:
|
class BertAttention:
|
||||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
def __init__(
|
||||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
|
self,
|
||||||
|
hidden_size,
|
||||||
|
num_attention_heads,
|
||||||
|
attention_probs_dropout_prob,
|
||||||
|
hidden_dropout_prob,
|
||||||
|
):
|
||||||
|
self.self = BertSelfAttention(
|
||||||
|
hidden_size, num_attention_heads, attention_probs_dropout_prob
|
||||||
|
)
|
||||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
|
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask):
|
def __call__(self, hidden_states, attention_mask):
|
||||||
|
@ -126,6 +241,7 @@ class BertAttention:
|
||||||
attention_output = self.output(self_output, hidden_states)
|
attention_output = self.output(self_output, hidden_states)
|
||||||
return attention_output
|
return attention_output
|
||||||
|
|
||||||
|
|
||||||
class BertSelfAttention:
|
class BertSelfAttention:
|
||||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
|
@ -147,17 +263,24 @@ class BertSelfAttention:
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||||
|
|
||||||
context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)
|
context_layer = Tensor.scaled_dot_product_attention(
|
||||||
|
query_layer, key_layer, value_layer, attention_mask, self.dropout
|
||||||
|
)
|
||||||
|
|
||||||
context_layer = context_layer.transpose(1, 2)
|
context_layer = context_layer.transpose(1, 2)
|
||||||
context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
|
context_layer = context_layer.reshape(
|
||||||
|
context_layer.shape[0], context_layer.shape[1], self.all_head_size
|
||||||
|
)
|
||||||
|
|
||||||
return context_layer
|
return context_layer
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
|
x = x.reshape(
|
||||||
|
x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size
|
||||||
|
)
|
||||||
return x.transpose(1, 2)
|
return x.transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
class BertSelfOutput:
|
class BertSelfOutput:
|
||||||
def __init__(self, hidden_size, hidden_dropout_prob):
|
def __init__(self, hidden_size, hidden_dropout_prob):
|
||||||
self.dense = Linear(hidden_size, hidden_size)
|
self.dense = Linear(hidden_size, hidden_size)
|
||||||
|
|
|
@ -2,6 +2,7 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||||
from tinygrad.helpers import fetch, get_child
|
from tinygrad.helpers import fetch, get_child
|
||||||
|
|
||||||
|
|
||||||
class Block:
|
class Block:
|
||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
|
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
|
||||||
|
@ -11,18 +12,43 @@ class Block:
|
||||||
self.gamma = Tensor.ones(dim)
|
self.gamma = Tensor.ones(dim)
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
return x + x.sequential([
|
return x + x.sequential(
|
||||||
self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
|
[
|
||||||
self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
|
self.dwconv,
|
||||||
])
|
lambda x: x.permute(0, 2, 3, 1),
|
||||||
|
self.norm,
|
||||||
|
self.pwconv1,
|
||||||
|
Tensor.gelu,
|
||||||
|
self.pwconv2,
|
||||||
|
lambda x: (self.gamma * x).permute(0, 3, 1, 2),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConvNeXt:
|
class ConvNeXt:
|
||||||
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_chans=3,
|
||||||
|
num_classes=1000,
|
||||||
|
depths=[3, 3, 9, 3],
|
||||||
|
dims=[96, 192, 384, 768],
|
||||||
|
):
|
||||||
self.downsample_layers = [
|
self.downsample_layers = [
|
||||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
|
[
|
||||||
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
||||||
|
LayerNorm2d(dims[0], eps=1e-6),
|
||||||
|
],
|
||||||
|
*[
|
||||||
|
[
|
||||||
|
LayerNorm2d(dims[i], eps=1e-6),
|
||||||
|
Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
||||||
|
]
|
||||||
|
for i in range(len(dims) - 1)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
self.stages = [
|
||||||
|
[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))
|
||||||
]
|
]
|
||||||
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
|
|
||||||
self.norm = LayerNorm(dims[-1])
|
self.norm = LayerNorm(dims[-1])
|
||||||
self.head = Linear(dims[-1], num_classes)
|
self.head = Linear(dims[-1], num_classes)
|
||||||
|
|
||||||
|
@ -31,6 +57,7 @@ class ConvNeXt:
|
||||||
x = x.sequential(downsample).sequential(stage)
|
x = x.sequential(downsample).sequential(stage)
|
||||||
return x.mean([-2, -1]).sequential([self.norm, self.head])
|
return x.mean([-2, -1]).sequential([self.norm, self.head])
|
||||||
|
|
||||||
|
|
||||||
# *** model definition is done ***
|
# *** model definition is done ***
|
||||||
|
|
||||||
versions = {
|
versions = {
|
||||||
|
@ -38,24 +65,32 @@ versions = {
|
||||||
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
|
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
|
||||||
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
|
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
|
||||||
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
|
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
|
||||||
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
|
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_model(version, load_weights=False):
|
def get_model(version, load_weights=False):
|
||||||
model = ConvNeXt(**versions[version])
|
model = ConvNeXt(**versions[version])
|
||||||
if load_weights:
|
if load_weights:
|
||||||
from tinygrad.nn.state import torch_load
|
from tinygrad.nn.state import torch_load
|
||||||
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
|
|
||||||
|
weights = torch_load(
|
||||||
|
fetch(
|
||||||
|
f"https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth"
|
||||||
|
)
|
||||||
|
)["model"]
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
mv = get_child(model, k)
|
mv = get_child(model, k)
|
||||||
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
model = get_model("tiny", True)
|
model = get_model("tiny", True)
|
||||||
|
|
||||||
# load image
|
# load image
|
||||||
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
|
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
|
||||||
|
|
||||||
img = Tensor(preprocess(chicken_img))
|
img = Tensor(preprocess(chicken_img))
|
||||||
|
|
||||||
Tensor.training = False
|
Tensor.training = False
|
||||||
|
|
|
@ -4,8 +4,19 @@ from tinygrad.nn import BatchNorm2d
|
||||||
from tinygrad.helpers import get_child, fetch
|
from tinygrad.helpers import get_child, fetch
|
||||||
from tinygrad.nn.state import torch_load
|
from tinygrad.nn.state import torch_load
|
||||||
|
|
||||||
|
|
||||||
class MBConvBlock:
|
class MBConvBlock:
|
||||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
kernel_size,
|
||||||
|
strides,
|
||||||
|
expand_ratio,
|
||||||
|
input_filters,
|
||||||
|
output_filters,
|
||||||
|
se_ratio,
|
||||||
|
has_se,
|
||||||
|
track_running_stats=True,
|
||||||
|
):
|
||||||
oup = expand_ratio * input_filters
|
oup = expand_ratio * input_filters
|
||||||
if expand_ratio != 1:
|
if expand_ratio != 1:
|
||||||
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
|
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
|
||||||
|
@ -37,12 +48,19 @@ class MBConvBlock:
|
||||||
x = inputs
|
x = inputs
|
||||||
if self._expand_conv:
|
if self._expand_conv:
|
||||||
x = self._bn0(x.conv2d(self._expand_conv)).swish()
|
x = self._bn0(x.conv2d(self._expand_conv)).swish()
|
||||||
x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
|
x = x.conv2d(
|
||||||
|
self._depthwise_conv,
|
||||||
|
padding=self.pad,
|
||||||
|
stride=self.strides,
|
||||||
|
groups=self._depthwise_conv.shape[0],
|
||||||
|
)
|
||||||
x = self._bn1(x).swish()
|
x = self._bn1(x).swish()
|
||||||
|
|
||||||
if self.has_se:
|
if self.has_se:
|
||||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
x_squeezed = x_squeezed.conv2d(
|
||||||
|
self._se_reduce, self._se_reduce_bias
|
||||||
|
).swish()
|
||||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||||
x = x.mul(x_squeezed.sigmoid())
|
x = x.mul(x_squeezed.sigmoid())
|
||||||
|
|
||||||
|
@ -51,8 +69,17 @@ class MBConvBlock:
|
||||||
x = x.add(inputs)
|
x = x.add(inputs)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class EfficientNet:
|
class EfficientNet:
|
||||||
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
number=0,
|
||||||
|
classes=1000,
|
||||||
|
has_se=True,
|
||||||
|
track_running_stats=True,
|
||||||
|
input_channels=3,
|
||||||
|
has_fc_output=True,
|
||||||
|
):
|
||||||
self.number = number
|
self.number = number
|
||||||
global_params = [
|
global_params = [
|
||||||
# width, depth
|
# width, depth
|
||||||
|
@ -106,10 +133,31 @@ class EfficientNet:
|
||||||
]
|
]
|
||||||
|
|
||||||
self._blocks = []
|
self._blocks = []
|
||||||
for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
|
for (
|
||||||
input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
|
num_repeats,
|
||||||
|
kernel_size,
|
||||||
|
strides,
|
||||||
|
expand_ratio,
|
||||||
|
input_filters,
|
||||||
|
output_filters,
|
||||||
|
se_ratio,
|
||||||
|
) in blocks_args:
|
||||||
|
input_filters, output_filters = round_filters(input_filters), round_filters(
|
||||||
|
output_filters
|
||||||
|
)
|
||||||
for n in range(round_repeats(num_repeats)):
|
for n in range(round_repeats(num_repeats)):
|
||||||
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
|
self._blocks.append(
|
||||||
|
MBConvBlock(
|
||||||
|
kernel_size,
|
||||||
|
strides,
|
||||||
|
expand_ratio,
|
||||||
|
input_filters,
|
||||||
|
output_filters,
|
||||||
|
se_ratio,
|
||||||
|
has_se=has_se,
|
||||||
|
track_running_stats=track_running_stats,
|
||||||
|
)
|
||||||
|
)
|
||||||
input_filters = output_filters
|
input_filters = output_filters
|
||||||
strides = (1, 1)
|
strides = (1, 1)
|
||||||
|
|
||||||
|
@ -140,25 +188,34 @@ class EfficientNet:
|
||||||
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||||
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||||
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
|
||||||
}
|
}
|
||||||
|
|
||||||
b0 = torch_load(fetch(model_urls[self.number]))
|
b0 = torch_load(fetch(model_urls[self.number]))
|
||||||
for k, v in b0.items():
|
for k, v in b0.items():
|
||||||
if k.endswith("num_batches_tracked"): continue
|
if k.endswith("num_batches_tracked"):
|
||||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
continue
|
||||||
|
for cat in [
|
||||||
|
"_conv_head",
|
||||||
|
"_conv_stem",
|
||||||
|
"_depthwise_conv",
|
||||||
|
"_expand_conv",
|
||||||
|
"_fc",
|
||||||
|
"_project_conv",
|
||||||
|
"_se_reduce",
|
||||||
|
"_se_expand",
|
||||||
|
]:
|
||||||
if cat in k:
|
if cat in k:
|
||||||
k = k.replace('.bias', '_bias')
|
k = k.replace(".bias", "_bias")
|
||||||
k = k.replace('.weight', '')
|
k = k.replace(".weight", "")
|
||||||
|
|
||||||
# print(k, v.shape)
|
# print(k, v.shape)
|
||||||
mv = get_child(self, k)
|
mv = get_child(self, k)
|
||||||
vnp = v # .astype(np.float32)
|
vnp = v # .astype(np.float32)
|
||||||
vnp = vnp if k != '_fc' else vnp.cpu().T
|
vnp = vnp if k != "_fc" else vnp.cpu().T
|
||||||
# vnp = vnp if vnp.shape != () else np.array([vnp])
|
# vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||||
|
|
||||||
if mv.shape == vnp.shape:
|
if mv.shape == vnp.shape:
|
||||||
mv.assign(vnp.to(mv.device))
|
mv.assign(vnp.to(mv.device))
|
||||||
else:
|
else:
|
||||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,15 @@ from typing import Tuple, Union, Optional, Dict
|
||||||
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
|
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
|
||||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||||
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
|
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(
|
||||||
|
1, end, 1, dim // 2, 2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||||
def complex_mult(A, c, d):
|
def complex_mult(A, c, d):
|
||||||
|
@ -15,20 +19,33 @@ def complex_mult(A, c, d):
|
||||||
co = a * d + b * c
|
co = a * d + b * c
|
||||||
return ro.cat(co, dim=-1)
|
return ro.cat(co, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
|
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
|
||||||
assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
assert (
|
||||||
|
freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1]
|
||||||
|
), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||||
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
||||||
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
||||||
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
|
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
|
||||||
c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
|
c, d = (
|
||||||
|
freqs_cis[:, : xq.shape[1], :, :, 0:1],
|
||||||
|
freqs_cis[:, : xq.shape[1], :, :, 1:2],
|
||||||
|
)
|
||||||
xq_out = complex_mult(xq, c, d)
|
xq_out = complex_mult(xq, c, d)
|
||||||
xk_out = complex_mult(xk, c, d)
|
xk_out = complex_mult(xk, c, d)
|
||||||
return xq_out.flatten(3), xk_out.flatten(3)
|
return xq_out.flatten(3), xk_out.flatten(3)
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||||
if n_rep == 1: return x
|
if n_rep == 1:
|
||||||
return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
return x
|
||||||
|
return (
|
||||||
|
x.reshape(bs, seqlen, n_kv_heads, 1, head_dim)
|
||||||
|
.expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
|
||||||
|
.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm:
|
class RMSNorm:
|
||||||
def __init__(self, dim, eps=1e-6):
|
def __init__(self, dim, eps=1e-6):
|
||||||
|
@ -39,10 +56,13 @@ class RMSNorm:
|
||||||
# TODO: convert to float?
|
# TODO: convert to float?
|
||||||
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
|
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
|
||||||
|
|
||||||
|
|
||||||
class Attention:
|
class Attention:
|
||||||
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
self.n_kv_heads = (
|
||||||
|
n_kv_heads if n_kv_heads is not None else n_heads
|
||||||
|
) # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
self.n_rep = self.n_heads // self.n_kv_heads
|
self.n_rep = self.n_heads // self.n_kv_heads
|
||||||
self.max_context = max_context
|
self.max_context = max_context
|
||||||
|
@ -52,7 +72,13 @@ class Attention:
|
||||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||||
|
|
||||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
start_pos: Union[Variable, int],
|
||||||
|
freqs_cis: Tensor,
|
||||||
|
mask: Optional[Tensor],
|
||||||
|
) -> Tensor:
|
||||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||||
|
@ -62,21 +88,40 @@ class Attention:
|
||||||
|
|
||||||
# create kv cache
|
# create kv cache
|
||||||
if not hasattr(self, "cache_k"):
|
if not hasattr(self, "cache_k"):
|
||||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
self.cache_k, self.cache_v = Tensor.zeros(
|
||||||
|
bsz, self.max_context, self.n_kv_heads, self.head_dim
|
||||||
|
), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||||
|
|
||||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||||
|
|
||||||
# update the cache
|
# update the cache
|
||||||
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
self.cache_k.assign(
|
||||||
self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
keys.pad(
|
||||||
|
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||||
|
).contiguous()
|
||||||
|
).realize()
|
||||||
|
self.cache_v.assign(
|
||||||
|
values.pad(
|
||||||
|
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||||
|
).contiguous()
|
||||||
|
).realize()
|
||||||
|
|
||||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||||
|
|
||||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
xq, keys, values = (
|
||||||
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
|
xq.transpose(1, 2),
|
||||||
|
keys.transpose(1, 2),
|
||||||
|
values.transpose(1, 2),
|
||||||
|
)
|
||||||
|
attn = (
|
||||||
|
xq.scaled_dot_product_attention(keys, values, mask)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(bsz, seqlen, -1)
|
||||||
|
)
|
||||||
return self.wo(attn)
|
return self.wo(attn)
|
||||||
|
|
||||||
|
|
||||||
class FeedForward:
|
class FeedForward:
|
||||||
def __init__(self, dim, hidden_dim, linear=nn.Linear):
|
def __init__(self, dim, hidden_dim, linear=nn.Linear):
|
||||||
self.w1 = linear(dim, hidden_dim, bias=False)
|
self.w1 = linear(dim, hidden_dim, bias=False)
|
||||||
|
@ -84,36 +129,88 @@ class FeedForward:
|
||||||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||||
|
|
||||||
def __call__(self, x: Tensor) -> Tensor:
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
return self.w2(
|
||||||
|
self.w1(x).silu() * self.w3(x)
|
||||||
|
) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock:
|
class TransformerBlock:
|
||||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: float,
|
||||||
|
max_context: int,
|
||||||
|
linear=nn.Linear,
|
||||||
|
):
|
||||||
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
||||||
self.feed_forward = FeedForward(dim, hidden_dim, linear)
|
self.feed_forward = FeedForward(dim, hidden_dim, linear)
|
||||||
self.attention_norm = RMSNorm(dim, norm_eps)
|
self.attention_norm = RMSNorm(dim, norm_eps)
|
||||||
self.ffn_norm = RMSNorm(dim, norm_eps)
|
self.ffn_norm = RMSNorm(dim, norm_eps)
|
||||||
|
|
||||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
start_pos: Union[Variable, int],
|
||||||
|
freqs_cis: Tensor,
|
||||||
|
mask: Optional[Tensor],
|
||||||
|
):
|
||||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||||
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
||||||
|
|
||||||
|
|
||||||
class Transformer:
|
class Transformer:
|
||||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
|
def __init__(
|
||||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
|
self,
|
||||||
|
dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
n_layers: int,
|
||||||
|
norm_eps: float,
|
||||||
|
vocab_size,
|
||||||
|
linear=nn.Linear,
|
||||||
|
n_kv_heads=None,
|
||||||
|
rope_theta=10000,
|
||||||
|
max_context=1024,
|
||||||
|
jit=True,
|
||||||
|
):
|
||||||
|
self.layers = [
|
||||||
|
TransformerBlock(
|
||||||
|
dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear
|
||||||
|
)
|
||||||
|
for _ in range(n_layers)
|
||||||
|
]
|
||||||
self.norm = RMSNorm(dim, norm_eps)
|
self.norm = RMSNorm(dim, norm_eps)
|
||||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||||
self.output = linear(dim, vocab_size, bias=False)
|
self.output = linear(dim, vocab_size, bias=False)
|
||||||
self.max_context = max_context
|
self.max_context = max_context
|
||||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
|
self.freqs_cis = precompute_freqs_cis(
|
||||||
|
dim // n_heads, self.max_context * 2, rope_theta
|
||||||
|
)
|
||||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||||
|
|
||||||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
|
def forward(
|
||||||
|
self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float = 0.0
|
||||||
|
):
|
||||||
_bsz, seqlen = tokens.shape
|
_bsz, seqlen = tokens.shape
|
||||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
|
freqs_cis = self.freqs_cis.shrink(
|
||||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
(None, (start_pos, start_pos + seqlen), None, None, None)
|
||||||
|
)
|
||||||
|
mask = (
|
||||||
|
Tensor.full(
|
||||||
|
(1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32
|
||||||
|
)
|
||||||
|
.triu(start_pos + 1)
|
||||||
|
.realize()
|
||||||
|
if seqlen > 1
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
h = self.tok_embeddings(tokens)
|
h = self.tok_embeddings(tokens)
|
||||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
for layer in self.layers:
|
||||||
|
h = layer(h, start_pos, freqs_cis, mask)
|
||||||
logits = self.output(self.norm(h))
|
logits = self.output(self.norm(h))
|
||||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize()
|
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize()
|
||||||
|
|
||||||
|
@ -121,27 +218,54 @@ class Transformer:
|
||||||
# TODO: better way to handle the first call v.s. the rest?
|
# TODO: better way to handle the first call v.s. the rest?
|
||||||
if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1):
|
if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1):
|
||||||
assert start_pos > 0
|
assert start_pos > 0
|
||||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
|
return self.forward_jit(
|
||||||
|
tokens,
|
||||||
|
Variable("start_pos", 1, self.max_context).bind(start_pos),
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
return self.forward(tokens, start_pos, temperature)
|
return self.forward(tokens, start_pos, temperature)
|
||||||
|
|
||||||
|
|
||||||
# *** helpers ***
|
# *** helpers ***
|
||||||
|
|
||||||
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
|
|
||||||
|
def convert_from_huggingface(
|
||||||
|
weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int
|
||||||
|
):
|
||||||
def permute(v: Tensor, n_heads: int):
|
def permute(v: Tensor, n_heads: int):
|
||||||
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
|
return (
|
||||||
|
v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1])
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(*v.shape[:2])
|
||||||
|
)
|
||||||
|
|
||||||
keymap = {
|
keymap = {
|
||||||
"model.embed_tokens.weight": "tok_embeddings.weight",
|
"model.embed_tokens.weight": "tok_embeddings.weight",
|
||||||
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
|
**{
|
||||||
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
|
f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
|
||||||
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
|
for l in range(len(model.layers))
|
||||||
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
|
},
|
||||||
|
**{
|
||||||
|
f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
|
||||||
|
for x in ["q", "k", "v", "o"]
|
||||||
|
for l in range(len(model.layers))
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
|
||||||
|
for l in range(len(model.layers))
|
||||||
|
},
|
||||||
|
**{
|
||||||
|
f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
|
||||||
|
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
|
||||||
|
for l in range(len(model.layers))
|
||||||
|
},
|
||||||
"model.norm.weight": "norm.weight",
|
"model.norm.weight": "norm.weight",
|
||||||
"lm_head.weight": "output.weight",
|
"lm_head.weight": "output.weight",
|
||||||
}
|
}
|
||||||
sd = {}
|
sd = {}
|
||||||
for k, v in weights.items():
|
for k, v in weights.items():
|
||||||
if ".rotary_emb." in k: continue
|
if ".rotary_emb." in k:
|
||||||
|
continue
|
||||||
v = v.to(Device.DEFAULT)
|
v = v.to(Device.DEFAULT)
|
||||||
if "model.layers" in k:
|
if "model.layers" in k:
|
||||||
if "q_proj" in k:
|
if "q_proj" in k:
|
||||||
|
|
|
@ -10,64 +10,95 @@ from tinygrad.nn.state import torch_load
|
||||||
from extra.models.resnet import ResNet
|
from extra.models.resnet import ResNet
|
||||||
from extra.models.retinanet import nms as _box_nms
|
from extra.models.retinanet import nms as _box_nms
|
||||||
|
|
||||||
USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'
|
USE_NP_GATHER = os.getenv("FULL_TINYGRAD", "0") == "0"
|
||||||
|
|
||||||
|
|
||||||
def rint(tensor):
|
def rint(tensor):
|
||||||
x = (tensor * 2).cast(dtypes.int32).contiguous().cast(dtypes.float32) / 2
|
x = (tensor * 2).cast(dtypes.int32).contiguous().cast(dtypes.float32) / 2
|
||||||
return (x < 0).where(x.floor(), x.ceil())
|
return (x < 0).where(x.floor(), x.ceil())
|
||||||
|
|
||||||
|
|
||||||
def nearest_interpolate(tensor, scale_factor):
|
def nearest_interpolate(tensor, scale_factor):
|
||||||
bs, c, py, px = tensor.shape
|
bs, c, py, px = tensor.shape
|
||||||
return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
|
return (
|
||||||
|
tensor.reshape(bs, c, py, 1, px, 1)
|
||||||
|
.expand(bs, c, py, scale_factor, px, scale_factor)
|
||||||
|
.reshape(bs, c, py * scale_factor, px * scale_factor)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def meshgrid(x, y):
|
def meshgrid(x, y):
|
||||||
grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
|
grid_x = Tensor.cat(
|
||||||
|
*[x[idx : idx + 1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]
|
||||||
|
)
|
||||||
grid_y = Tensor.cat(*[y.unsqueeze(0)] * x.shape[0])
|
grid_y = Tensor.cat(*[y.unsqueeze(0)] * x.shape[0])
|
||||||
return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
|
return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
|
||||||
|
|
||||||
|
|
||||||
def topk(input_, k, dim=-1, largest=True, sorted=False):
|
def topk(input_, k, dim=-1, largest=True, sorted=False):
|
||||||
k = min(k, input_.shape[dim] - 1)
|
k = min(k, input_.shape[dim] - 1)
|
||||||
input_ = input_.numpy()
|
input_ = input_.numpy()
|
||||||
if largest: input_ *= -1
|
if largest:
|
||||||
|
input_ *= -1
|
||||||
ind = np.argpartition(input_, k, axis=dim)
|
ind = np.argpartition(input_, k, axis=dim)
|
||||||
if largest: input_ *= -1
|
if largest:
|
||||||
|
input_ *= -1
|
||||||
ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
|
ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices
|
||||||
input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
|
input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values
|
||||||
if not sorted: return Tensor(input_), ind
|
if not sorted:
|
||||||
if largest: input_ *= -1
|
return Tensor(input_), ind
|
||||||
|
if largest:
|
||||||
|
input_ *= -1
|
||||||
ind_part = np.argsort(input_, axis=dim)
|
ind_part = np.argsort(input_, axis=dim)
|
||||||
ind = np.take_along_axis(ind, ind_part, axis=dim)
|
ind = np.take_along_axis(ind, ind_part, axis=dim)
|
||||||
if largest: input_ *= -1
|
if largest:
|
||||||
|
input_ *= -1
|
||||||
val = np.take_along_axis(input_, ind_part, axis=dim)
|
val = np.take_along_axis(input_, ind_part, axis=dim)
|
||||||
return Tensor(val), ind
|
return Tensor(val), ind
|
||||||
|
|
||||||
|
|
||||||
# This is very slow for large arrays, or indices
|
# This is very slow for large arrays, or indices
|
||||||
def _gather(array, indices):
|
def _gather(array, indices):
|
||||||
indices = indices.float().to(array.device)
|
indices = indices.float().to(array.device)
|
||||||
reshape_arg = [1] * array.ndim + [array.shape[-1]]
|
reshape_arg = [1] * array.ndim + [array.shape[-1]]
|
||||||
return Tensor.where(
|
return Tensor.where(
|
||||||
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
|
indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1])
|
||||||
array, 0,
|
== Tensor.arange(array.shape[-1])
|
||||||
|
.reshape(*reshape_arg)
|
||||||
|
.expand(*indices.shape, array.shape[-1]),
|
||||||
|
array,
|
||||||
|
0,
|
||||||
).sum(indices.ndim)
|
).sum(indices.ndim)
|
||||||
|
|
||||||
|
|
||||||
# TODO: replace npgather with a faster gather using tinygrad only
|
# TODO: replace npgather with a faster gather using tinygrad only
|
||||||
# NOTE: this blocks the gradient
|
# NOTE: this blocks the gradient
|
||||||
def npgather(array, indices):
|
def npgather(array, indices):
|
||||||
if isinstance(array, Tensor): array = array.numpy()
|
if isinstance(array, Tensor):
|
||||||
if isinstance(indices, Tensor): indices = indices.numpy()
|
array = array.numpy()
|
||||||
if isinstance(indices, list): indices = np.asarray(indices)
|
if isinstance(indices, Tensor):
|
||||||
|
indices = indices.numpy()
|
||||||
|
if isinstance(indices, list):
|
||||||
|
indices = np.asarray(indices)
|
||||||
return Tensor(array[indices.astype(int)])
|
return Tensor(array[indices.astype(int)])
|
||||||
|
|
||||||
|
|
||||||
def get_strides(shape):
|
def get_strides(shape):
|
||||||
prod = [1]
|
prod = [1]
|
||||||
for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
|
for idx in range(len(shape) - 1, -1, -1):
|
||||||
|
prod.append(prod[-1] * shape[idx])
|
||||||
# something about ints is broken with gpu, cuda
|
# something about ints is broken with gpu, cuda
|
||||||
return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu()
|
return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu()
|
||||||
|
|
||||||
|
|
||||||
# with keys as integer array for all axes
|
# with keys as integer array for all axes
|
||||||
def tensor_getitem(tensor, *keys):
|
def tensor_getitem(tensor, *keys):
|
||||||
# something about ints is broken with gpu, cuda
|
# something about ints is broken with gpu, cuda
|
||||||
flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32)
|
flat_keys = (
|
||||||
|
Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1)
|
||||||
|
.cpu()
|
||||||
|
.cast(dtypes.int32)
|
||||||
|
)
|
||||||
strides = get_strides(tensor.shape)
|
strides = get_strides(tensor.shape)
|
||||||
idxs = (flat_keys * strides).sum(1)
|
idxs = (flat_keys * strides).sum(1)
|
||||||
gatherer = npgather if USE_NP_GATHER else _gather
|
gatherer = npgather if USE_NP_GATHER else _gather
|
||||||
|
@ -97,7 +128,8 @@ def tensor_gather(tensor, indices):
|
||||||
|
|
||||||
|
|
||||||
class LastLevelMaxPool:
|
class LastLevelMaxPool:
|
||||||
def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]
|
def __call__(self, x):
|
||||||
|
return [Tensor.max_pool2d(x, 1, 2)]
|
||||||
|
|
||||||
|
|
||||||
# transpose
|
# transpose
|
||||||
|
@ -117,9 +149,7 @@ class BoxList:
|
||||||
if not isinstance(bbox, Tensor):
|
if not isinstance(bbox, Tensor):
|
||||||
bbox = Tensor(bbox)
|
bbox = Tensor(bbox)
|
||||||
if bbox.ndim != 2:
|
if bbox.ndim != 2:
|
||||||
raise ValueError(
|
raise ValueError("bbox should have 2 dimensions, got {}".format(bbox.ndim))
|
||||||
"bbox should have 2 dimensions, got {}".format(bbox.ndim)
|
|
||||||
)
|
|
||||||
if bbox.shape[-1] != 4:
|
if bbox.shape[-1] != 4:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"last dimenion of bbox should have a "
|
"last dimenion of bbox should have a "
|
||||||
|
@ -145,7 +175,9 @@ class BoxList:
|
||||||
box = self.bbox
|
box = self.bbox
|
||||||
if self.mode == "xyxy":
|
if self.mode == "xyxy":
|
||||||
TO_REMOVE = 1
|
TO_REMOVE = 1
|
||||||
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
|
area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (
|
||||||
|
box[:, 3] - box[:, 1] + TO_REMOVE
|
||||||
|
)
|
||||||
elif self.mode == "xywh":
|
elif self.mode == "xywh":
|
||||||
area = box[:, 2] * box[:, 3]
|
area = box[:, 2] * box[:, 3]
|
||||||
return area
|
return area
|
||||||
|
@ -241,7 +273,8 @@ class BoxList:
|
||||||
transposed_ymax = image_height - ymin
|
transposed_ymax = image_height - ymin
|
||||||
|
|
||||||
transposed_boxes = Tensor.cat(
|
transposed_boxes = Tensor.cat(
|
||||||
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
|
*(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax),
|
||||||
|
dim=-1,
|
||||||
)
|
)
|
||||||
bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
|
bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
|
||||||
for k, v in self.extra_fields.items():
|
for k, v in self.extra_fields.items():
|
||||||
|
@ -289,7 +322,11 @@ def cat_boxlist(bboxes):
|
||||||
else:
|
else:
|
||||||
cat_boxes = BoxList(bboxes[0].bbox, size, mode)
|
cat_boxes = BoxList(bboxes[0].bbox, size, mode)
|
||||||
for field in fields:
|
for field in fields:
|
||||||
cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]
|
cat_field_list = [
|
||||||
|
bbox.get_field(field)
|
||||||
|
for bbox in bboxes
|
||||||
|
if bbox.get_field(field).shape[0] > 0
|
||||||
|
]
|
||||||
|
|
||||||
if len(cat_box_list) > 0:
|
if len(cat_box_list) > 0:
|
||||||
data = Tensor.cat(*cat_field_list, dim=0)
|
data = Tensor.cat(*cat_field_list, dim=0)
|
||||||
|
@ -305,8 +342,12 @@ class FPN:
|
||||||
def __init__(self, in_channels_list, out_channels):
|
def __init__(self, in_channels_list, out_channels):
|
||||||
self.inner_blocks, self.layer_blocks = [], []
|
self.inner_blocks, self.layer_blocks = [], []
|
||||||
for in_channels in in_channels_list:
|
for in_channels in in_channels_list:
|
||||||
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
self.inner_blocks.append(
|
||||||
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
|
nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||||
|
)
|
||||||
|
self.layer_blocks.append(
|
||||||
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||||
|
)
|
||||||
self.top_block = LastLevelMaxPool()
|
self.top_block = LastLevelMaxPool()
|
||||||
|
|
||||||
def __call__(self, x: Tensor):
|
def __call__(self, x: Tensor):
|
||||||
|
@ -357,9 +398,7 @@ class AnchorGenerator:
|
||||||
):
|
):
|
||||||
if len(anchor_strides) == 1:
|
if len(anchor_strides) == 1:
|
||||||
anchor_stride = anchor_strides[0]
|
anchor_stride = anchor_strides[0]
|
||||||
cell_anchors = [
|
cell_anchors = [generate_anchors(anchor_stride, sizes, aspect_ratios)]
|
||||||
generate_anchors(anchor_stride, sizes, aspect_ratios)
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
if len(anchor_strides) != len(sizes):
|
if len(anchor_strides) != len(sizes):
|
||||||
raise RuntimeError("FPN should have #anchor_strides == #sizes")
|
raise RuntimeError("FPN should have #anchor_strides == #sizes")
|
||||||
|
@ -368,7 +407,7 @@ class AnchorGenerator:
|
||||||
generate_anchors(
|
generate_anchors(
|
||||||
anchor_stride,
|
anchor_stride,
|
||||||
size if isinstance(size, (tuple, list)) else (size,),
|
size if isinstance(size, (tuple, list)) else (size,),
|
||||||
aspect_ratios
|
aspect_ratios,
|
||||||
)
|
)
|
||||||
for anchor_stride, size in zip(anchor_strides, sizes)
|
for anchor_stride, size in zip(anchor_strides, sizes)
|
||||||
]
|
]
|
||||||
|
@ -387,10 +426,18 @@ class AnchorGenerator:
|
||||||
grid_height, grid_width = size
|
grid_height, grid_width = size
|
||||||
device = base_anchors.device
|
device = base_anchors.device
|
||||||
shifts_x = Tensor.arange(
|
shifts_x = Tensor.arange(
|
||||||
start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
|
start=0,
|
||||||
|
stop=grid_width * stride,
|
||||||
|
step=stride,
|
||||||
|
dtype=dtypes.float32,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
shifts_y = Tensor.arange(
|
shifts_y = Tensor.arange(
|
||||||
start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
|
start=0,
|
||||||
|
stop=grid_height * stride,
|
||||||
|
step=stride,
|
||||||
|
dtype=dtypes.float32,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
shift_y, shift_x = meshgrid(shifts_y, shifts_x)
|
shift_y, shift_x = meshgrid(shifts_y, shifts_x)
|
||||||
shift_x = shift_x.reshape(-1)
|
shift_x = shift_x.reshape(-1)
|
||||||
|
@ -398,7 +445,9 @@ class AnchorGenerator:
|
||||||
shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
||||||
|
|
||||||
anchors.append(
|
anchors.append(
|
||||||
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
|
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(
|
||||||
|
-1, 4
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return anchors
|
return anchors
|
||||||
|
@ -415,14 +464,16 @@ class AnchorGenerator:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
device = anchors.device
|
device = anchors.device
|
||||||
inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
|
inds_inside = Tensor.ones(
|
||||||
|
anchors.shape[0], dtype=dtypes.uint8, device=device
|
||||||
|
)
|
||||||
boxlist.add_field("visibility", inds_inside)
|
boxlist.add_field("visibility", inds_inside)
|
||||||
|
|
||||||
def __call__(self, image_list, feature_maps):
|
def __call__(self, image_list, feature_maps):
|
||||||
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
|
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
|
||||||
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
|
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
|
||||||
anchors = []
|
anchors = []
|
||||||
for (image_height, image_width) in image_list.image_sizes:
|
for image_height, image_width in image_list.image_sizes:
|
||||||
anchors_in_image = []
|
anchors_in_image = []
|
||||||
for anchors_per_feature_map in anchors_over_all_feature_maps:
|
for anchors_per_feature_map in anchors_over_all_feature_maps:
|
||||||
boxlist = BoxList(
|
boxlist = BoxList(
|
||||||
|
@ -437,14 +488,19 @@ class AnchorGenerator:
|
||||||
def generate_anchors(
|
def generate_anchors(
|
||||||
stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
|
stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
|
||||||
):
|
):
|
||||||
return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))
|
return _generate_anchors(
|
||||||
|
stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_anchors(base_size, scales, aspect_ratios):
|
def _generate_anchors(base_size, scales, aspect_ratios):
|
||||||
anchor = Tensor([1, 1, base_size, base_size]) - 1
|
anchor = Tensor([1, 1, base_size, base_size]) - 1
|
||||||
anchors = _ratio_enum(anchor, aspect_ratios)
|
anchors = _ratio_enum(anchor, aspect_ratios)
|
||||||
anchors = Tensor.cat(
|
anchors = Tensor.cat(
|
||||||
*[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
|
*[
|
||||||
|
_scale_enum(anchors[i, :], scales).reshape(-1, 4)
|
||||||
|
for i in range(anchors.shape[0])
|
||||||
|
]
|
||||||
)
|
)
|
||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
|
@ -460,12 +516,15 @@ def _whctrs(anchor):
|
||||||
def _mkanchors(ws, hs, x_ctr, y_ctr):
|
def _mkanchors(ws, hs, x_ctr, y_ctr):
|
||||||
ws = ws[:, None]
|
ws = ws[:, None]
|
||||||
hs = hs[:, None]
|
hs = hs[:, None]
|
||||||
anchors = Tensor.cat(*(
|
anchors = Tensor.cat(
|
||||||
|
*(
|
||||||
x_ctr - 0.5 * (ws - 1),
|
x_ctr - 0.5 * (ws - 1),
|
||||||
y_ctr - 0.5 * (hs - 1),
|
y_ctr - 0.5 * (hs - 1),
|
||||||
x_ctr + 0.5 * (ws - 1),
|
x_ctr + 0.5 * (ws - 1),
|
||||||
y_ctr + 0.5 * (hs - 1),
|
y_ctr + 0.5 * (hs - 1),
|
||||||
), dim=1)
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
|
|
||||||
|
@ -504,7 +563,7 @@ class RPNHead:
|
||||||
|
|
||||||
|
|
||||||
class BoxCoder(object):
|
class BoxCoder(object):
|
||||||
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
|
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
|
||||||
self.weights = weights
|
self.weights = weights
|
||||||
self.bbox_xform_clip = bbox_xform_clip
|
self.bbox_xform_clip = bbox_xform_clip
|
||||||
|
|
||||||
|
@ -557,7 +616,11 @@ class BoxCoder(object):
|
||||||
y = pred_ctr_y - 0.5 * pred_h
|
y = pred_ctr_y - 0.5 * pred_h
|
||||||
w = pred_ctr_x + 0.5 * pred_w - 1
|
w = pred_ctr_x + 0.5 * pred_w - 1
|
||||||
h = pred_ctr_y + 0.5 * pred_h - 1
|
h = pred_ctr_y + 0.5 * pred_h - 1
|
||||||
pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
|
pred_boxes = (
|
||||||
|
Tensor.stack([x, y, w, h])
|
||||||
|
.permute(1, 2, 0)
|
||||||
|
.reshape(rel_codes.shape[0], rel_codes.shape[1])
|
||||||
|
)
|
||||||
return pred_boxes
|
return pred_boxes
|
||||||
|
|
||||||
|
|
||||||
|
@ -578,9 +641,7 @@ def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
|
||||||
def remove_small_boxes(boxlist, min_size):
|
def remove_small_boxes(boxlist, min_size):
|
||||||
xywh_boxes = boxlist.convert("xywh").bbox
|
xywh_boxes = boxlist.convert("xywh").bbox
|
||||||
_, _, ws, hs = xywh_boxes.chunk(4, dim=1)
|
_, _, ws, hs = xywh_boxes.chunk(4, dim=1)
|
||||||
keep = ((
|
keep = (((ws >= min_size) * (hs >= min_size)) > 0).reshape(-1)
|
||||||
(ws >= min_size) * (hs >= min_size)
|
|
||||||
) > 0).reshape(-1)
|
|
||||||
if keep.sum().numpy() == len(boxlist):
|
if keep.sum().numpy() == len(boxlist):
|
||||||
return boxlist
|
return boxlist
|
||||||
else:
|
else:
|
||||||
|
@ -630,8 +691,12 @@ class RPNPostProcessor:
|
||||||
box_regression_list = []
|
box_regression_list = []
|
||||||
concat_anchors_list = []
|
concat_anchors_list = []
|
||||||
for batch_idx in range(N):
|
for batch_idx in range(N):
|
||||||
box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
|
box_regression_list.append(
|
||||||
concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))
|
tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])
|
||||||
|
)
|
||||||
|
concat_anchors_list.append(
|
||||||
|
tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])
|
||||||
|
)
|
||||||
|
|
||||||
box_regression = Tensor.stack(box_regression_list)
|
box_regression = Tensor.stack(box_regression_list)
|
||||||
concat_anchors = Tensor.stack(concat_anchors_list)
|
concat_anchors = Tensor.stack(concat_anchors_list)
|
||||||
|
@ -677,9 +742,7 @@ class RPNPostProcessor:
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
objectness = boxlists[i].get_field("objectness")
|
objectness = boxlists[i].get_field("objectness")
|
||||||
post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
|
post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
|
||||||
_, inds_sorted = topk(objectness,
|
_, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
|
||||||
post_nms_top_n, dim=0, sorted=False
|
|
||||||
)
|
|
||||||
boxlists[i] = boxlists[i][inds_sorted]
|
boxlists[i] = boxlists[i][inds_sorted]
|
||||||
return boxlists
|
return boxlists
|
||||||
|
|
||||||
|
@ -689,9 +752,7 @@ class RPN:
|
||||||
self.anchor_generator = AnchorGenerator()
|
self.anchor_generator = AnchorGenerator()
|
||||||
|
|
||||||
in_channels = 256
|
in_channels = 256
|
||||||
head = RPNHead(
|
head = RPNHead(in_channels, self.anchor_generator.num_anchors_per_location()[0])
|
||||||
in_channels, self.anchor_generator.num_anchors_per_location()[0]
|
|
||||||
)
|
|
||||||
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
|
rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
|
||||||
box_selector_test = RPNPostProcessor(
|
box_selector_test = RPNPostProcessor(
|
||||||
pre_nms_top_n=1000,
|
pre_nms_top_n=1000,
|
||||||
|
@ -699,7 +760,7 @@ class RPN:
|
||||||
nms_thresh=0.7,
|
nms_thresh=0.7,
|
||||||
min_size=0,
|
min_size=0,
|
||||||
box_coder=rpn_box_coder,
|
box_coder=rpn_box_coder,
|
||||||
fpn_post_nms_top_n=1000
|
fpn_post_nms_top_n=1000,
|
||||||
)
|
)
|
||||||
self.head = head
|
self.head = head
|
||||||
self.box_selector_test = box_selector_test
|
self.box_selector_test = box_selector_test
|
||||||
|
@ -725,7 +786,7 @@ def make_conv3x3(
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=dilation,
|
padding=dilation,
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
bias=False if use_gn else True
|
bias=False if use_gn else True,
|
||||||
)
|
)
|
||||||
return conv
|
return conv
|
||||||
|
|
||||||
|
@ -746,10 +807,18 @@ class MaskRCNNFPNFeatureExtractor:
|
||||||
use_gn = False
|
use_gn = False
|
||||||
layers = (256, 256, 256, 256)
|
layers = (256, 256, 256, 256)
|
||||||
dilation = 1
|
dilation = 1
|
||||||
self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
|
self.mask_fcn1 = make_conv3x3(
|
||||||
self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
|
input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn
|
||||||
self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
|
)
|
||||||
self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
|
self.mask_fcn2 = make_conv3x3(
|
||||||
|
layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn
|
||||||
|
)
|
||||||
|
self.mask_fcn3 = make_conv3x3(
|
||||||
|
layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn
|
||||||
|
)
|
||||||
|
self.mask_fcn4 = make_conv3x3(
|
||||||
|
layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn
|
||||||
|
)
|
||||||
self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
|
self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]
|
||||||
|
|
||||||
def __call__(self, x, proposals):
|
def __call__(self, x, proposals):
|
||||||
|
@ -833,7 +902,9 @@ def _bilinear_interpolate(
|
||||||
y = Tensor.where(ymask[:, None, :], y, 0)
|
y = Tensor.where(ymask[:, None, :], y, 0)
|
||||||
x = Tensor.where(xmask[:, None, :], x, 0)
|
x = Tensor.where(xmask[:, None, :], x, 0)
|
||||||
key1 = roi_batch_ind[:, None, None, None, None, None]
|
key1 = roi_batch_ind[:, None, None, None, None, None]
|
||||||
key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
|
key2 = Tensor.arange(channels, device=input.device)[
|
||||||
|
None, :, None, None, None, None
|
||||||
|
]
|
||||||
key3 = y[:, None, :, None, :, None]
|
key3 = y[:, None, :, None, :, None]
|
||||||
key4 = x[:, None, None, :, None, :]
|
key4 = x[:, None, None, :, None, :]
|
||||||
return tensor_getitem(input, key1, key2, key3, key4) # [K, C, PH, PW, IY, IX]
|
return tensor_getitem(input, key1, key2, key3, key4) # [K, C, PH, PW, IY, IX]
|
||||||
|
@ -855,8 +926,11 @@ def _bilinear_interpolate(
|
||||||
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
|
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
|
# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
|
||||||
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
|
def _roi_align(
|
||||||
|
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned
|
||||||
|
):
|
||||||
orig_dtype = input.dtype
|
orig_dtype = input.dtype
|
||||||
_, _, height, width = input.shape
|
_, _, height, width = input.shape
|
||||||
ph = Tensor.arange(pooled_height, device=input.device)
|
ph = Tensor.arange(pooled_height, device=input.device)
|
||||||
|
@ -879,8 +953,12 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
||||||
bin_size_w = roi_width / pooled_width
|
bin_size_w = roi_width / pooled_width
|
||||||
|
|
||||||
exact_sampling = sampling_ratio > 0
|
exact_sampling = sampling_ratio > 0
|
||||||
roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
|
roi_bin_grid_h = (
|
||||||
roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
|
sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
|
||||||
|
)
|
||||||
|
roi_bin_grid_w = (
|
||||||
|
sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()
|
||||||
|
)
|
||||||
|
|
||||||
if exact_sampling:
|
if exact_sampling:
|
||||||
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
|
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
|
||||||
|
@ -923,6 +1001,7 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
|
||||||
output = output.cast(orig_dtype)
|
output = output.cast(orig_dtype)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class ROIAlign:
|
class ROIAlign:
|
||||||
def __init__(self, output_size, spatial_scale, sampling_ratio):
|
def __init__(self, output_size, spatial_scale, sampling_ratio):
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
|
@ -931,7 +1010,13 @@ class ROIAlign:
|
||||||
|
|
||||||
def __call__(self, input, rois):
|
def __call__(self, input, rois):
|
||||||
output = _roi_align(
|
output = _roi_align(
|
||||||
input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
|
input,
|
||||||
|
rois,
|
||||||
|
self.spatial_scale,
|
||||||
|
self.output_size[0],
|
||||||
|
self.output_size[1],
|
||||||
|
self.sampling_ratio,
|
||||||
|
aligned=False,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -1002,7 +1087,16 @@ class Pooler:
|
||||||
all_idxs.extend(idx_in_level)
|
all_idxs.extend(idx_in_level)
|
||||||
results.append(pooler_output)
|
results.append(pooler_output)
|
||||||
|
|
||||||
return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
|
return tensor_gather(
|
||||||
|
Tensor.cat(*results),
|
||||||
|
[
|
||||||
|
x[0]
|
||||||
|
for x in sorted(
|
||||||
|
{i: idx for i, idx in enumerate(all_idxs)}.items(),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FPNPredictor:
|
class FPNPredictor:
|
||||||
|
@ -1027,13 +1121,13 @@ class PostProcessor:
|
||||||
nms=0.5,
|
nms=0.5,
|
||||||
detections_per_img=100,
|
detections_per_img=100,
|
||||||
box_coder=None,
|
box_coder=None,
|
||||||
cls_agnostic_bbox_reg=False
|
cls_agnostic_bbox_reg=False,
|
||||||
):
|
):
|
||||||
self.score_thresh = score_thresh
|
self.score_thresh = score_thresh
|
||||||
self.nms = nms
|
self.nms = nms
|
||||||
self.detections_per_img = detections_per_img
|
self.detections_per_img = detections_per_img
|
||||||
if box_coder is None:
|
if box_coder is None:
|
||||||
box_coder = BoxCoder(weights=(10., 10., 5., 5.))
|
box_coder = BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
|
||||||
self.box_coder = box_coder
|
self.box_coder = box_coder
|
||||||
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
|
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
|
||||||
|
|
||||||
|
@ -1090,9 +1184,7 @@ class PostProcessor:
|
||||||
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
|
boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
|
||||||
boxlist_for_class.add_field("scores", scores_j)
|
boxlist_for_class.add_field("scores", scores_j)
|
||||||
if len(boxlist_for_class):
|
if len(boxlist_for_class):
|
||||||
boxlist_for_class = boxlist_nms(
|
boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
|
||||||
boxlist_for_class, self.nms
|
|
||||||
)
|
|
||||||
num_labels = len(boxlist_for_class)
|
num_labels = len(boxlist_for_class)
|
||||||
boxlist_for_class.add_field(
|
boxlist_for_class.add_field(
|
||||||
"labels", Tensor.full((num_labels,), j, device=device)
|
"labels", Tensor.full((num_labels,), j, device=device)
|
||||||
|
@ -1119,8 +1211,8 @@ class RoIBoxHead:
|
||||||
score_thresh=0.05,
|
score_thresh=0.05,
|
||||||
nms=0.5,
|
nms=0.5,
|
||||||
detections_per_img=100,
|
detections_per_img=100,
|
||||||
box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
|
box_coder=BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)),
|
||||||
cls_agnostic_bbox_reg=False
|
cls_agnostic_bbox_reg=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, features, proposals, targets=None):
|
def __call__(self, features, proposals, targets=None):
|
||||||
|
@ -1210,7 +1302,6 @@ def to_image_list(tensors, size_divisible=32):
|
||||||
elif isinstance(tensors, (tuple, list)):
|
elif isinstance(tensors, (tuple, list)):
|
||||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
|
||||||
if size_divisible > 0:
|
if size_divisible > 0:
|
||||||
|
|
||||||
stride = size_divisible
|
stride = size_divisible
|
||||||
max_size = list(max_size)
|
max_size = list(max_size)
|
||||||
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
|
max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
|
||||||
|
@ -1237,10 +1328,13 @@ class MaskRCNN:
|
||||||
self.roi_heads = RoIHeads(self.backbone.out_channels)
|
self.roi_heads = RoIHeads(self.backbone.out_channels)
|
||||||
|
|
||||||
def load_from_pretrained(self):
|
def load_from_pretrained(self):
|
||||||
fn = Path('./') / "weights/maskrcnn.pt"
|
fn = Path("./") / "weights/maskrcnn.pt"
|
||||||
fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)
|
fetch(
|
||||||
|
"https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth",
|
||||||
|
fn,
|
||||||
|
)
|
||||||
|
|
||||||
state_dict = torch_load(fn)['model']
|
state_dict = torch_load(fn)["model"]
|
||||||
loaded_keys = []
|
loaded_keys = []
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if "module." in k:
|
if "module." in k:
|
||||||
|
@ -1265,7 +1359,7 @@ class MaskRCNN:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||||
model = MaskRCNN(backbone=resnet)
|
model = MaskRCNN(backbone=resnet)
|
||||||
model.load_from_pretrained()
|
model.load_from_pretrained()
|
||||||
|
|
|
@ -3,20 +3,33 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn.state import torch_load
|
from tinygrad.nn.state import torch_load
|
||||||
from tinygrad.helpers import fetch, get_child
|
from tinygrad.helpers import fetch, get_child
|
||||||
|
|
||||||
|
|
||||||
class BasicBlock:
|
class BasicBlock:
|
||||||
expansion = 1
|
expansion = 1
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
|
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
|
||||||
assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
|
assert (
|
||||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
groups == 1 and base_width == 64
|
||||||
|
), "BasicBlock only supports groups=1 and base_width=64"
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||||
|
)
|
||||||
self.bn1 = nn.BatchNorm2d(planes)
|
self.bn1 = nn.BatchNorm2d(planes)
|
||||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
|
self.conv2 = nn.Conv2d(
|
||||||
|
planes, planes, kernel_size=3, padding=1, stride=1, bias=False
|
||||||
|
)
|
||||||
self.bn2 = nn.BatchNorm2d(planes)
|
self.bn2 = nn.BatchNorm2d(planes)
|
||||||
self.downsample = []
|
self.downsample = []
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.downsample = [
|
self.downsample = [
|
||||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
nn.Conv2d(
|
||||||
nn.BatchNorm2d(self.expansion*planes)
|
in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
|
@ -31,20 +44,44 @@ class Bottleneck:
|
||||||
# NOTE: stride_in_1x1=False, this is the v1.5 variant
|
# NOTE: stride_in_1x1=False, this is the v1.5 variant
|
||||||
expansion = 4
|
expansion = 4
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
|
def __init__(
|
||||||
|
self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64
|
||||||
|
):
|
||||||
width = int(planes * (base_width / 64.0)) * groups
|
width = int(planes * (base_width / 64.0)) * groups
|
||||||
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
|
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
|
||||||
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
|
self.conv1 = nn.Conv2d(
|
||||||
|
in_planes,
|
||||||
|
width,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride if stride_in_1x1 else 1,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
self.bn1 = nn.BatchNorm2d(width)
|
self.bn1 = nn.BatchNorm2d(width)
|
||||||
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
|
self.conv2 = nn.Conv2d(
|
||||||
|
width,
|
||||||
|
width,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
stride=1 if stride_in_1x1 else stride,
|
||||||
|
groups=groups,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
self.bn2 = nn.BatchNorm2d(width)
|
self.bn2 = nn.BatchNorm2d(width)
|
||||||
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
|
self.conv3 = nn.Conv2d(
|
||||||
|
width, self.expansion * planes, kernel_size=1, bias=False
|
||||||
|
)
|
||||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||||
self.downsample = []
|
self.downsample = []
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.downsample = [
|
self.downsample = [
|
||||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
nn.Conv2d(
|
||||||
nn.BatchNorm2d(self.expansion*planes)
|
in_planes,
|
||||||
|
self.expansion * planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
|
@ -55,15 +92,18 @@ class Bottleneck:
|
||||||
out = out.relu()
|
out = out.relu()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ResNet:
|
class ResNet:
|
||||||
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
|
def __init__(
|
||||||
|
self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False
|
||||||
|
):
|
||||||
self.num = num
|
self.num = num
|
||||||
self.block = {
|
self.block = {
|
||||||
18: BasicBlock,
|
18: BasicBlock,
|
||||||
34: BasicBlock,
|
34: BasicBlock,
|
||||||
50: Bottleneck,
|
50: Bottleneck,
|
||||||
101: Bottleneck,
|
101: Bottleneck,
|
||||||
152: Bottleneck
|
152: Bottleneck,
|
||||||
}[num]
|
}[num]
|
||||||
|
|
||||||
self.num_blocks = {
|
self.num_blocks = {
|
||||||
|
@ -71,7 +111,7 @@ class ResNet:
|
||||||
34: [3, 4, 6, 3],
|
34: [3, 4, 6, 3],
|
||||||
50: [3, 4, 6, 3],
|
50: [3, 4, 6, 3],
|
||||||
101: [3, 4, 23, 3],
|
101: [3, 4, 23, 3],
|
||||||
152: [3,8,36,3]
|
152: [3, 8, 36, 3],
|
||||||
}[num]
|
}[num]
|
||||||
|
|
||||||
self.in_planes = 64
|
self.in_planes = 64
|
||||||
|
@ -80,36 +120,64 @@ class ResNet:
|
||||||
self.base_width = width_per_group
|
self.base_width = width_per_group
|
||||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
||||||
self.bn1 = nn.BatchNorm2d(64)
|
self.bn1 = nn.BatchNorm2d(64)
|
||||||
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
|
self.layer1 = self._make_layer(
|
||||||
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
|
self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1
|
||||||
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
|
)
|
||||||
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
|
self.layer2 = self._make_layer(
|
||||||
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
|
self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1
|
||||||
|
)
|
||||||
|
self.layer3 = self._make_layer(
|
||||||
|
self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1
|
||||||
|
)
|
||||||
|
self.layer4 = self._make_layer(
|
||||||
|
self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1
|
||||||
|
)
|
||||||
|
self.fc = (
|
||||||
|
nn.Linear(512 * self.block.expansion, num_classes)
|
||||||
|
if num_classes is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
|
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
|
||||||
strides = [stride] + [1] * (num_blocks - 1)
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
layers = []
|
layers = []
|
||||||
for stride in strides:
|
for stride in strides:
|
||||||
if block == Bottleneck:
|
if block == Bottleneck:
|
||||||
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
|
layers.append(
|
||||||
|
block(
|
||||||
|
self.in_planes,
|
||||||
|
planes,
|
||||||
|
stride,
|
||||||
|
stride_in_1x1,
|
||||||
|
self.groups,
|
||||||
|
self.base_width,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
|
layers.append(
|
||||||
|
block(self.in_planes, planes, stride, self.groups, self.base_width)
|
||||||
|
)
|
||||||
self.in_planes = planes * block.expansion
|
self.in_planes = planes * block.expansion
|
||||||
return layers
|
return layers
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
is_feature_only = self.fc is None
|
is_feature_only = self.fc is None
|
||||||
if is_feature_only: features = []
|
if is_feature_only:
|
||||||
|
features = []
|
||||||
out = self.bn1(self.conv1(x)).relu()
|
out = self.bn1(self.conv1(x)).relu()
|
||||||
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
|
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
|
||||||
out = out.sequential(self.layer1)
|
out = out.sequential(self.layer1)
|
||||||
if is_feature_only: features.append(out)
|
if is_feature_only:
|
||||||
|
features.append(out)
|
||||||
out = out.sequential(self.layer2)
|
out = out.sequential(self.layer2)
|
||||||
if is_feature_only: features.append(out)
|
if is_feature_only:
|
||||||
|
features.append(out)
|
||||||
out = out.sequential(self.layer3)
|
out = out.sequential(self.layer3)
|
||||||
if is_feature_only: features.append(out)
|
if is_feature_only:
|
||||||
|
features.append(out)
|
||||||
out = out.sequential(self.layer4)
|
out = out.sequential(self.layer4)
|
||||||
if is_feature_only: features.append(out)
|
if is_feature_only:
|
||||||
|
features.append(out)
|
||||||
if not is_feature_only:
|
if not is_feature_only:
|
||||||
out = out.mean([2, 3])
|
out = out.mean([2, 3])
|
||||||
out = self.fc(out).log_softmax()
|
out = self.fc(out).log_softmax()
|
||||||
|
@ -123,12 +191,16 @@ class ResNet:
|
||||||
# TODO replace with fake torch load
|
# TODO replace with fake torch load
|
||||||
|
|
||||||
model_urls = {
|
model_urls = {
|
||||||
(18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
(18, 1, 64): "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
||||||
(34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
(34, 1, 64): "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
||||||
(50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
(50, 1, 64): "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
||||||
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
(
|
||||||
(101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
50,
|
||||||
(152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
32,
|
||||||
|
4,
|
||||||
|
): "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||||
|
(101, 1, 64): "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
||||||
|
(152, 1, 64): "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||||
|
@ -136,17 +208,24 @@ class ResNet:
|
||||||
obj: Tensor = get_child(self, k)
|
obj: Tensor = get_child(self, k)
|
||||||
dat = v.detach().numpy()
|
dat = v.detach().numpy()
|
||||||
|
|
||||||
if 'fc.' in k and obj.shape != dat.shape:
|
if "fc." in k and obj.shape != dat.shape:
|
||||||
print("skipping fully connected layer")
|
print("skipping fully connected layer")
|
||||||
continue # Skip FC if transfer learning
|
continue # Skip FC if transfer learning
|
||||||
|
|
||||||
# TODO: remove or when #777 is merged
|
# TODO: remove or when #777 is merged
|
||||||
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
|
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (
|
||||||
|
k,
|
||||||
|
obj.shape,
|
||||||
|
dat.shape,
|
||||||
|
)
|
||||||
obj.assign(dat)
|
obj.assign(dat)
|
||||||
|
|
||||||
|
|
||||||
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
||||||
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
||||||
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
|
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
|
||||||
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
|
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
|
||||||
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
|
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
|
||||||
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
|
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(
|
||||||
|
50, num_classes=num_classes, groups=32, width_per_group=4
|
||||||
|
)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import tinygrad.nn as nn
|
||||||
from extra.models.resnet import ResNet
|
from extra.models.resnet import ResNet
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def nms(boxes, scores, thresh=0.5):
|
def nms(boxes, scores, thresh=0.5):
|
||||||
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
|
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
|
||||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
@ -15,11 +16,14 @@ def nms(boxes, scores, thresh=0.5):
|
||||||
inter_y1 = np.maximum(y1[cur], y1[to_process])
|
inter_y1 = np.maximum(y1[cur], y1[to_process])
|
||||||
inter_x2 = np.minimum(x2[cur], x2[to_process])
|
inter_x2 = np.minimum(x2[cur], x2[to_process])
|
||||||
inter_y2 = np.minimum(y2[cur], y2[to_process])
|
inter_y2 = np.minimum(y2[cur], y2[to_process])
|
||||||
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
|
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(
|
||||||
|
0, inter_y2 - inter_y1 + 1
|
||||||
|
)
|
||||||
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
|
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
|
||||||
to_process = to_process[np.where(iou <= thresh)[0]]
|
to_process = to_process[np.where(iou <= thresh)[0]]
|
||||||
return keep
|
return keep
|
||||||
|
|
||||||
|
|
||||||
def decode_bbox(offsets, anchors):
|
def decode_bbox(offsets, anchors):
|
||||||
dx, dy, dw, dh = np.rollaxis(offsets, 1)
|
dx, dy, dw, dh = np.rollaxis(offsets, 1)
|
||||||
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
|
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
|
||||||
|
@ -30,6 +34,7 @@ def decode_bbox(offsets, anchors):
|
||||||
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
|
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
|
||||||
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
|
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
||||||
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
|
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
|
||||||
anchors = []
|
anchors = []
|
||||||
|
@ -41,39 +46,87 @@ def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
||||||
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
|
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
|
||||||
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
|
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
|
||||||
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
|
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
|
||||||
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
|
shifts_x, shifts_y = np.meshgrid(
|
||||||
|
np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h
|
||||||
|
)
|
||||||
shifts_x = shifts_x.reshape(-1)
|
shifts_x = shifts_x.reshape(-1)
|
||||||
shifts_y = shifts_y.reshape(-1)
|
shifts_y = shifts_y.reshape(-1)
|
||||||
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
|
shifts = np.stack(
|
||||||
|
[shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32
|
||||||
|
)
|
||||||
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
|
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
|
||||||
return anchors
|
return anchors
|
||||||
|
|
||||||
|
|
||||||
class RetinaNet:
|
class RetinaNet:
|
||||||
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbone: ResNet,
|
||||||
|
num_classes=264,
|
||||||
|
num_anchors=9,
|
||||||
|
scales=None,
|
||||||
|
aspect_ratios=None,
|
||||||
|
):
|
||||||
assert isinstance(backbone, ResNet)
|
assert isinstance(backbone, ResNet)
|
||||||
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
|
scales = (
|
||||||
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
tuple(
|
||||||
|
(i, int(i * 2 ** (1 / 3)), int(i * 2 ** (2 / 3)))
|
||||||
|
for i in 2 ** np.arange(5, 10)
|
||||||
|
)
|
||||||
|
if scales is None
|
||||||
|
else scales
|
||||||
|
)
|
||||||
|
aspect_ratios = (
|
||||||
|
((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
||||||
|
)
|
||||||
self.num_anchors, self.num_classes = num_anchors, num_classes
|
self.num_anchors, self.num_classes = num_anchors, num_classes
|
||||||
assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
|
assert len(scales) == len(aspect_ratios) and all(
|
||||||
|
self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)
|
||||||
|
)
|
||||||
|
|
||||||
self.backbone = ResNetFPN(backbone)
|
self.backbone = ResNetFPN(backbone)
|
||||||
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
|
self.head = RetinaHead(
|
||||||
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
|
self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes
|
||||||
|
)
|
||||||
|
self.anchor_gen = lambda input_size: generate_anchors(
|
||||||
|
input_size,
|
||||||
|
self.backbone.compute_grid_sizes(input_size),
|
||||||
|
scales,
|
||||||
|
aspect_ratios,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return self.forward(x)
|
return self.forward(x)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.head(self.backbone(x))
|
return self.head(self.backbone(x))
|
||||||
|
|
||||||
def load_from_pretrained(self):
|
def load_from_pretrained(self):
|
||||||
model_urls = {
|
model_urls = {
|
||||||
(50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
(
|
||||||
(50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
50,
|
||||||
|
1,
|
||||||
|
64,
|
||||||
|
): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
||||||
|
(
|
||||||
|
50,
|
||||||
|
32,
|
||||||
|
4,
|
||||||
|
): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
||||||
}
|
}
|
||||||
self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
|
self.url = model_urls[
|
||||||
|
(
|
||||||
|
self.backbone.body.num,
|
||||||
|
self.backbone.body.groups,
|
||||||
|
self.backbone.body.base_width,
|
||||||
|
)
|
||||||
|
]
|
||||||
from torch.hub import load_state_dict_from_url
|
from torch.hub import load_state_dict_from_url
|
||||||
state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
|
|
||||||
state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
|
state_dict = load_state_dict_from_url(
|
||||||
|
self.url, progress=True, map_location="cpu"
|
||||||
|
)
|
||||||
|
state_dict = state_dict["model"] if "model" in state_dict.keys() else state_dict
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
obj = get_child(self, k)
|
obj = get_child(self, k)
|
||||||
dat = v.detach().numpy()
|
dat = v.detach().numpy()
|
||||||
|
@ -81,10 +134,21 @@ class RetinaNet:
|
||||||
obj.assign(dat)
|
obj.assign(dat)
|
||||||
|
|
||||||
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
|
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
|
||||||
def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
|
def postprocess_detections(
|
||||||
|
self,
|
||||||
|
predictions,
|
||||||
|
input_size=(800, 800),
|
||||||
|
image_sizes=None,
|
||||||
|
orig_image_sizes=None,
|
||||||
|
score_thresh=0.05,
|
||||||
|
topk_candidates=1000,
|
||||||
|
nms_thresh=0.5,
|
||||||
|
):
|
||||||
anchors = self.anchor_gen(input_size)
|
anchors = self.anchor_gen(input_size)
|
||||||
grid_sizes = self.backbone.compute_grid_sizes(input_size)
|
grid_sizes = self.backbone.compute_grid_sizes(input_size)
|
||||||
split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
|
split_idx = np.cumsum(
|
||||||
|
[int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]
|
||||||
|
)
|
||||||
detections = []
|
detections = []
|
||||||
for i, predictions_per_image in enumerate(predictions):
|
for i, predictions_per_image in enumerate(predictions):
|
||||||
h, w = input_size if image_sizes is None else image_sizes[i]
|
h, w = input_size if image_sizes is None else image_sizes[i]
|
||||||
|
@ -94,7 +158,9 @@ class RetinaNet:
|
||||||
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
|
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
|
||||||
|
|
||||||
image_boxes, image_scores, image_labels = [], [], []
|
image_boxes, image_scores, image_labels = [], [], []
|
||||||
for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
|
for offsets_per_level, scores_per_level, anchors_per_level in zip(
|
||||||
|
offsets_per_image, scores_per_image, anchors
|
||||||
|
):
|
||||||
# remove low scoring boxes
|
# remove low scoring boxes
|
||||||
scores_per_level = scores_per_level.flatten()
|
scores_per_level = scores_per_level.flatten()
|
||||||
keep_idxs = scores_per_level > score_thresh
|
keep_idxs = scores_per_level > score_thresh
|
||||||
|
@ -104,16 +170,23 @@ class RetinaNet:
|
||||||
topk_idxs = np.where(keep_idxs)[0]
|
topk_idxs = np.where(keep_idxs)[0]
|
||||||
num_topk = min(len(topk_idxs), topk_candidates)
|
num_topk = min(len(topk_idxs), topk_candidates)
|
||||||
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
|
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
|
||||||
topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]
|
topk_idxs, scores_per_level = (
|
||||||
|
topk_idxs[sort_idxs],
|
||||||
|
scores_per_level[sort_idxs],
|
||||||
|
)
|
||||||
|
|
||||||
# bbox coords from offsets
|
# bbox coords from offsets
|
||||||
anchor_idxs = topk_idxs // self.num_classes
|
anchor_idxs = topk_idxs // self.num_classes
|
||||||
labels_per_level = topk_idxs % self.num_classes
|
labels_per_level = topk_idxs % self.num_classes
|
||||||
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
|
boxes_per_level = decode_bbox(
|
||||||
|
offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
|
||||||
|
)
|
||||||
# clip to image size
|
# clip to image size
|
||||||
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
|
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
|
||||||
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
|
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
|
||||||
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)
|
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(
|
||||||
|
-1, 4
|
||||||
|
)
|
||||||
|
|
||||||
image_boxes.append(boxes_per_level)
|
image_boxes.append(boxes_per_level)
|
||||||
image_scores.append(scores_per_level)
|
image_scores.append(scores_per_level)
|
||||||
|
@ -127,7 +200,9 @@ class RetinaNet:
|
||||||
keep_mask = np.zeros_like(image_scores, dtype=bool)
|
keep_mask = np.zeros_like(image_scores, dtype=bool)
|
||||||
for class_id in np.unique(image_labels):
|
for class_id in np.unique(image_labels):
|
||||||
curr_indices = np.where(image_labels == class_id)[0]
|
curr_indices = np.where(image_labels == class_id)[0]
|
||||||
curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
|
curr_keep_indices = nms(
|
||||||
|
image_boxes[curr_indices], image_scores[curr_indices], nms_thresh
|
||||||
|
)
|
||||||
keep_mask[curr_indices[curr_keep_indices]] = True
|
keep_mask[curr_indices[curr_keep_indices]] = True
|
||||||
keep = np.where(keep_mask)[0]
|
keep = np.where(keep_mask)[0]
|
||||||
keep = keep[image_scores[keep].argsort()[::-1]]
|
keep = keep[image_scores[keep].argsort()[::-1]]
|
||||||
|
@ -139,42 +214,91 @@ class RetinaNet:
|
||||||
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
|
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
|
||||||
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
|
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
|
||||||
# xywh format
|
# xywh format
|
||||||
image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)
|
image_boxes = np.concatenate(
|
||||||
|
[image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]})
|
detections.append(
|
||||||
|
{
|
||||||
|
"boxes": image_boxes,
|
||||||
|
"scores": image_scores[keep],
|
||||||
|
"labels": image_labels[keep],
|
||||||
|
}
|
||||||
|
)
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
|
||||||
class ClassificationHead:
|
class ClassificationHead:
|
||||||
def __init__(self, in_channels, num_anchors, num_classes):
|
def __init__(self, in_channels, num_anchors, num_classes):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
self.conv = flatten(
|
||||||
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
|
[
|
||||||
|
(
|
||||||
|
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||||
|
lambda x: x.relu(),
|
||||||
|
)
|
||||||
|
for _ in range(4)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.cls_logits = nn.Conv2d(
|
||||||
|
in_channels, num_anchors * num_classes, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
|
out = [
|
||||||
|
self.cls_logits(feat.sequential(self.conv))
|
||||||
|
.permute(0, 2, 3, 1)
|
||||||
|
.reshape(feat.shape[0], -1, self.num_classes)
|
||||||
|
for feat in x
|
||||||
|
]
|
||||||
return out[0].cat(*out[1:], dim=1).sigmoid()
|
return out[0].cat(*out[1:], dim=1).sigmoid()
|
||||||
|
|
||||||
|
|
||||||
class RegressionHead:
|
class RegressionHead:
|
||||||
def __init__(self, in_channels, num_anchors):
|
def __init__(self, in_channels, num_anchors):
|
||||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
self.conv = flatten(
|
||||||
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
|
[
|
||||||
|
(
|
||||||
|
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||||
|
lambda x: x.relu(),
|
||||||
|
)
|
||||||
|
for _ in range(4)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.bbox_reg = nn.Conv2d(
|
||||||
|
in_channels, num_anchors * 4, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
|
out = [
|
||||||
|
self.bbox_reg(feat.sequential(self.conv))
|
||||||
|
.permute(0, 2, 3, 1)
|
||||||
|
.reshape(feat.shape[0], -1, 4)
|
||||||
|
for feat in x
|
||||||
|
]
|
||||||
return out[0].cat(*out[1:], dim=1)
|
return out[0].cat(*out[1:], dim=1)
|
||||||
|
|
||||||
|
|
||||||
class RetinaHead:
|
class RetinaHead:
|
||||||
def __init__(self, in_channels, num_anchors, num_classes):
|
def __init__(self, in_channels, num_anchors, num_classes):
|
||||||
self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
|
self.classification_head = ClassificationHead(
|
||||||
|
in_channels, num_anchors, num_classes
|
||||||
|
)
|
||||||
self.regression_head = RegressionHead(in_channels, num_anchors)
|
self.regression_head = RegressionHead(in_channels, num_anchors)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
|
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
|
||||||
out = pred_bbox.cat(pred_class, dim=-1)
|
out = pred_bbox.cat(pred_class, dim=-1)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ResNetFPN:
|
class ResNetFPN:
|
||||||
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
|
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.body = resnet
|
self.body = resnet
|
||||||
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
|
in_channels_list = [
|
||||||
|
(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers
|
||||||
|
]
|
||||||
self.fpn = FPN(in_channels_list, out_channels)
|
self.fpn = FPN(in_channels_list, out_channels)
|
||||||
|
|
||||||
# this is needed to decouple inference from postprocessing (anchors generation)
|
# this is needed to decouple inference from postprocessing (anchors generation)
|
||||||
|
@ -190,10 +314,15 @@ class ResNetFPN:
|
||||||
p5 = p4.sequential(self.body.layer4)
|
p5 = p4.sequential(self.body.layer4)
|
||||||
return self.fpn([p3, p4, p5])
|
return self.fpn([p3, p4, p5])
|
||||||
|
|
||||||
|
|
||||||
class ExtraFPNBlock:
|
class ExtraFPNBlock:
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
self.p6 = nn.Conv2d(
|
||||||
self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
in_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
|
self.p7 = nn.Conv2d(
|
||||||
|
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||||
|
)
|
||||||
self.use_P5 = in_channels == out_channels
|
self.use_P5 = in_channels == out_channels
|
||||||
|
|
||||||
def __call__(self, p, c):
|
def __call__(self, p, c):
|
||||||
|
@ -204,13 +333,20 @@ class ExtraFPNBlock:
|
||||||
p.extend([p6, p7])
|
p.extend([p6, p7])
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
class FPN:
|
class FPN:
|
||||||
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
|
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
|
||||||
self.inner_blocks, self.layer_blocks = [], []
|
self.inner_blocks, self.layer_blocks = [], []
|
||||||
for in_channels in in_channels_list:
|
for in_channels in in_channels_list:
|
||||||
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
self.inner_blocks.append(
|
||||||
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
|
nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||||
self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
)
|
||||||
|
self.layer_blocks.append(
|
||||||
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||||
|
)
|
||||||
|
self.extra_blocks = (
|
||||||
|
ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
last_inner = self.inner_blocks[-1](x[-1])
|
last_inner = self.inner_blocks[-1](x[-1])
|
||||||
|
@ -219,9 +355,17 @@ class FPN:
|
||||||
inner_lateral = self.inner_blocks[idx](x[idx])
|
inner_lateral = self.inner_blocks[idx](x[idx])
|
||||||
|
|
||||||
# upsample to inner_lateral's shape
|
# upsample to inner_lateral's shape
|
||||||
(ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
|
(ih, iw), (oh, ow), prefix = (
|
||||||
|
last_inner.shape[-2:],
|
||||||
|
inner_lateral.shape[-2:],
|
||||||
|
last_inner.shape[:-2],
|
||||||
|
)
|
||||||
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
|
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
|
||||||
inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]
|
inner_top_down = (
|
||||||
|
last_inner.reshape(*prefix, ih, 1, iw, 1)
|
||||||
|
.expand(*prefix, ih, eh, iw, ew)
|
||||||
|
.reshape(*prefix, ih * eh, iw * ew)[:, :, :oh, :ow]
|
||||||
|
)
|
||||||
|
|
||||||
last_inner = inner_lateral + inner_top_down
|
last_inner = inner_lateral + inner_top_down
|
||||||
results.insert(0, self.layer_blocks[idx](last_inner))
|
results.insert(0, self.layer_blocks[idx](last_inner))
|
||||||
|
@ -229,8 +373,10 @@ class FPN:
|
||||||
results = self.extra_blocks(results, x)
|
results = self.extra_blocks(results, x)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from extra.models.resnet import ResNeXt50_32X4D
|
from extra.models.resnet import ResNeXt50_32X4D
|
||||||
|
|
||||||
backbone = ResNeXt50_32X4D()
|
backbone = ResNeXt50_32X4D()
|
||||||
retina = RetinaNet(backbone)
|
retina = RetinaNet(backbone)
|
||||||
retina.load_from_pretrained()
|
retina.load_from_pretrained()
|
||||||
|
|
|
@ -7,10 +7,31 @@ from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class RNNT:
|
class RNNT:
|
||||||
def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
|
def __init__(
|
||||||
self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
|
self,
|
||||||
|
input_features=240,
|
||||||
|
vocab_size=29,
|
||||||
|
enc_hidden_size=1024,
|
||||||
|
pred_hidden_size=320,
|
||||||
|
joint_hidden_size=512,
|
||||||
|
pre_enc_layers=2,
|
||||||
|
post_enc_layers=3,
|
||||||
|
pred_layers=2,
|
||||||
|
stack_time_factor=2,
|
||||||
|
dropout=0.32,
|
||||||
|
):
|
||||||
|
self.encoder = Encoder(
|
||||||
|
input_features,
|
||||||
|
enc_hidden_size,
|
||||||
|
pre_enc_layers,
|
||||||
|
post_enc_layers,
|
||||||
|
stack_time_factor,
|
||||||
|
dropout,
|
||||||
|
)
|
||||||
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
|
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
|
||||||
self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)
|
self.joint = Joint(
|
||||||
|
vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||||
|
)
|
||||||
|
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def __call__(self, x, y, hc=None):
|
def __call__(self, x, y, hc=None):
|
||||||
|
@ -30,7 +51,12 @@ class RNNT:
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def _greedy_decode(self, logits, logit_len):
|
def _greedy_decode(self, logits, logit_len):
|
||||||
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
|
hc = Tensor.zeros(
|
||||||
|
self.prediction.rnn.layers,
|
||||||
|
2,
|
||||||
|
self.prediction.hidden_size,
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
labels = []
|
labels = []
|
||||||
label = Tensor.zeros(1, 1, requires_grad=False)
|
label = Tensor.zeros(1, 1, requires_grad=False)
|
||||||
mask = Tensor.zeros(1, requires_grad=False)
|
mask = Tensor.zeros(1, requires_grad=False)
|
||||||
|
@ -41,7 +67,14 @@ class RNNT:
|
||||||
while not_blank and added < 30:
|
while not_blank and added < 30:
|
||||||
if len(labels) > 0:
|
if len(labels) > 0:
|
||||||
mask = (mask + 1).clip(0, 1)
|
mask = (mask + 1).clip(0, 1)
|
||||||
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
|
label = (
|
||||||
|
Tensor(
|
||||||
|
[[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]],
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
- 1
|
||||||
|
)
|
||||||
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
||||||
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
||||||
not_blank = k != 28
|
not_blank = k != 28
|
||||||
|
@ -61,31 +94,59 @@ class RNNT:
|
||||||
|
|
||||||
def load_from_pretrained(self):
|
def load_from_pretrained(self):
|
||||||
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
||||||
fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)
|
fetch(
|
||||||
|
"https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1",
|
||||||
|
fn,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
with open(fn, "rb") as f:
|
with open(fn, "rb") as f:
|
||||||
state_dict = torch.load(f, map_location="cpu")["state_dict"]
|
state_dict = torch.load(f, map_location="cpu")["state_dict"]
|
||||||
|
|
||||||
# encoder
|
# encoder
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
|
self.encoder.pre_rnn.cells[i].weights_ih.assign(
|
||||||
self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
|
state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||||
self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
|
)
|
||||||
self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
|
self.encoder.pre_rnn.cells[i].weights_hh.assign(
|
||||||
|
state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.encoder.pre_rnn.cells[i].bias_ih.assign(
|
||||||
|
state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.encoder.pre_rnn.cells[i].bias_hh.assign(
|
||||||
|
state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
|
self.encoder.post_rnn.cells[i].weights_ih.assign(
|
||||||
self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
|
state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||||
self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
|
)
|
||||||
self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())
|
self.encoder.post_rnn.cells[i].weights_hh.assign(
|
||||||
|
state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.encoder.post_rnn.cells[i].bias_ih.assign(
|
||||||
|
state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.encoder.post_rnn.cells[i].bias_hh.assign(
|
||||||
|
state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
|
||||||
# prediction
|
# prediction
|
||||||
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
|
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
|
self.prediction.rnn.cells[i].weights_ih.assign(
|
||||||
self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
|
state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||||
self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
|
)
|
||||||
self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())
|
self.prediction.rnn.cells[i].weights_hh.assign(
|
||||||
|
state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.prediction.rnn.cells[i].bias_ih.assign(
|
||||||
|
state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
self.prediction.rnn.cells[i].bias_hh.assign(
|
||||||
|
state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||||
|
)
|
||||||
|
|
||||||
# joint
|
# joint
|
||||||
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
|
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
|
||||||
|
@ -104,7 +165,9 @@ class LSTMCell:
|
||||||
self.bias_hh = Tensor.uniform(hidden_size * 4)
|
self.bias_hh = Tensor.uniform(hidden_size * 4)
|
||||||
|
|
||||||
def __call__(self, x, hc):
|
def __call__(self, x, hc):
|
||||||
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)
|
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[: x.shape[0]].linear(
|
||||||
|
self.weights_hh.T, self.bias_hh
|
||||||
|
)
|
||||||
|
|
||||||
i, f, g, o = gates.chunk(4, 1)
|
i, f, g, o = gates.chunk(4, 1)
|
||||||
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
||||||
|
@ -121,7 +184,12 @@ class LSTM:
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.layers = layers
|
self.layers = layers
|
||||||
|
|
||||||
self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]
|
self.cells = [
|
||||||
|
LSTMCell(input_size, hidden_size, dropout)
|
||||||
|
if i == 0
|
||||||
|
else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0)
|
||||||
|
for i in range(layers)
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, x, hc):
|
def __call__(self, x, hc):
|
||||||
@TinyJit
|
@TinyJit
|
||||||
|
@ -129,7 +197,9 @@ class LSTM:
|
||||||
return self.do_step(x_, hc_)
|
return self.do_step(x_, hc_)
|
||||||
|
|
||||||
if hc is None:
|
if hc is None:
|
||||||
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
|
hc = Tensor.zeros(
|
||||||
|
self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
output = None
|
output = None
|
||||||
for t in range(x.shape[0]):
|
for t in range(x.shape[0]):
|
||||||
|
@ -159,10 +229,20 @@ class StackTime:
|
||||||
|
|
||||||
|
|
||||||
class Encoder:
|
class Encoder:
|
||||||
def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size,
|
||||||
|
hidden_size,
|
||||||
|
pre_layers,
|
||||||
|
post_layers,
|
||||||
|
stack_time_factor,
|
||||||
|
dropout,
|
||||||
|
):
|
||||||
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
|
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
|
||||||
self.stack_time = StackTime(stack_time_factor)
|
self.stack_time = StackTime(stack_time_factor)
|
||||||
self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)
|
self.post_rnn = LSTM(
|
||||||
|
stack_time_factor * hidden_size, hidden_size, post_layers, dropout
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x, x_lens):
|
def __call__(self, x, x_lens):
|
||||||
x, _ = self.pre_rnn(x, None)
|
x, _ = self.pre_rnn(x, None)
|
||||||
|
@ -185,7 +265,9 @@ class Prediction:
|
||||||
|
|
||||||
|
|
||||||
class Joint:
|
class Joint:
|
||||||
def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
|
def __init__(
|
||||||
|
self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||||
|
):
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
|
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
|
||||||
|
|
|
@ -1,8 +1,17 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
|
|
||||||
class TransformerBlock:
|
class TransformerBlock:
|
||||||
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
ff_dim,
|
||||||
|
prenorm=False,
|
||||||
|
act=lambda x: x.relu(),
|
||||||
|
dropout=0.1,
|
||||||
|
):
|
||||||
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
@ -10,11 +19,23 @@ class TransformerBlock:
|
||||||
self.prenorm, self.act = prenorm, act
|
self.prenorm, self.act = prenorm, act
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
self.query = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
self.query = (
|
||||||
self.key = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||||
self.value = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
Tensor.zeros(embed_dim),
|
||||||
|
)
|
||||||
|
self.key = (
|
||||||
|
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||||
|
Tensor.zeros(embed_dim),
|
||||||
|
)
|
||||||
|
self.value = (
|
||||||
|
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||||
|
Tensor.zeros(embed_dim),
|
||||||
|
)
|
||||||
|
|
||||||
self.out = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
self.out = (
|
||||||
|
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||||
|
Tensor.zeros(embed_dim),
|
||||||
|
)
|
||||||
|
|
||||||
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
|
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
|
||||||
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
|
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||||
|
@ -24,25 +45,41 @@ class TransformerBlock:
|
||||||
|
|
||||||
def attn(self, x):
|
def attn(self, x):
|
||||||
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
|
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
|
||||||
query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]]
|
query, key, value = [
|
||||||
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2)
|
x.linear(*y)
|
||||||
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
|
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size))
|
||||||
|
.transpose(1, 2)
|
||||||
|
for y in [self.query, self.key, self.value]
|
||||||
|
]
|
||||||
|
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
return attention.reshape(
|
||||||
|
shape=(x.shape[0], -1, self.num_heads * self.head_size)
|
||||||
|
).linear(*self.out)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
if self.prenorm:
|
if self.prenorm:
|
||||||
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
|
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
|
||||||
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(
|
||||||
|
*self.ff2
|
||||||
|
).dropout(self.dropout)
|
||||||
else:
|
else:
|
||||||
x = x + self.attn(x).dropout(self.dropout)
|
x = x + self.attn(x).dropout(self.dropout)
|
||||||
x = x.layernorm().linear(*self.ln1)
|
x = x.layernorm().linear(*self.ln1)
|
||||||
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(
|
||||||
|
self.dropout
|
||||||
|
)
|
||||||
x = x.layernorm().linear(*self.ln2)
|
x = x.layernorm().linear(*self.ln2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Transformer:
|
class Transformer:
|
||||||
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
||||||
self.maxlen, self.syms = maxlen, syms
|
self.maxlen, self.syms = maxlen, syms
|
||||||
self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False)
|
self.embed = Tensor.scaled_uniform(
|
||||||
|
maxlen + syms, embed_dim, requires_grad=False
|
||||||
|
)
|
||||||
self.tbs = []
|
self.tbs = []
|
||||||
for i in range(layers):
|
for i in range(layers):
|
||||||
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
|
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
|
||||||
|
@ -57,8 +94,11 @@ class Transformer:
|
||||||
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
||||||
onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms)
|
onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms)
|
||||||
|
|
||||||
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
|
x = (
|
||||||
|
Tensor(onehot, device=x.device)
|
||||||
|
.dot(self.embed)
|
||||||
|
.reshape(shape=(bs, x.shape[1], -1))
|
||||||
|
)
|
||||||
x = x.sequential(self.tbs)
|
x = x.sequential(self.tbs)
|
||||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
|
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
|
||||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||||
|
|
||||||
|
|
|
@ -4,25 +4,63 @@ from tinygrad import nn
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import fetch, get_child
|
from tinygrad.helpers import fetch, get_child
|
||||||
|
|
||||||
|
|
||||||
class DownsampleBlock:
|
class DownsampleBlock:
|
||||||
def __init__(self, c0, c1, stride=2):
|
def __init__(self, c0, c1, stride=2):
|
||||||
self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
self.conv1 = [
|
||||||
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
nn.Conv2d(
|
||||||
|
c0,
|
||||||
|
c1,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
stride=stride,
|
||||||
|
padding=(1, 1, 1, 1, 1, 1),
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.InstanceNorm(c1),
|
||||||
|
Tensor.relu,
|
||||||
|
]
|
||||||
|
self.conv2 = [
|
||||||
|
nn.Conv2d(
|
||||||
|
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
|
||||||
|
),
|
||||||
|
nn.InstanceNorm(c1),
|
||||||
|
Tensor.relu,
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
return x.sequential(self.conv1).sequential(self.conv2)
|
return x.sequential(self.conv1).sequential(self.conv2)
|
||||||
|
|
||||||
|
|
||||||
class UpsampleBlock:
|
class UpsampleBlock:
|
||||||
def __init__(self, c0, c1):
|
def __init__(self, c0, c1):
|
||||||
self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
|
self.upsample_conv = [
|
||||||
self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
nn.ConvTranspose2d(c0, c1, kernel_size=(2, 2, 2), stride=2)
|
||||||
self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
|
]
|
||||||
|
self.conv1 = [
|
||||||
|
nn.Conv2d(
|
||||||
|
2 * c1,
|
||||||
|
c1,
|
||||||
|
kernel_size=(3, 3, 3),
|
||||||
|
padding=(1, 1, 1, 1, 1, 1),
|
||||||
|
bias=False,
|
||||||
|
),
|
||||||
|
nn.InstanceNorm(c1),
|
||||||
|
Tensor.relu,
|
||||||
|
]
|
||||||
|
self.conv2 = [
|
||||||
|
nn.Conv2d(
|
||||||
|
c1, c1, kernel_size=(3, 3, 3), padding=(1, 1, 1, 1, 1, 1), bias=False
|
||||||
|
),
|
||||||
|
nn.InstanceNorm(c1),
|
||||||
|
Tensor.relu,
|
||||||
|
]
|
||||||
|
|
||||||
def __call__(self, x, skip):
|
def __call__(self, x, skip):
|
||||||
x = x.sequential(self.upsample_conv)
|
x = x.sequential(self.upsample_conv)
|
||||||
x = Tensor.cat(x, skip, dim=1)
|
x = Tensor.cat(x, skip, dim=1)
|
||||||
return x.sequential(self.conv1).sequential(self.conv2)
|
return x.sequential(self.conv1).sequential(self.conv2)
|
||||||
|
|
||||||
|
|
||||||
class UNet3D:
|
class UNet3D:
|
||||||
def __init__(self, in_channels=1, n_class=3):
|
def __init__(self, in_channels=1, n_class=3):
|
||||||
filters = [32, 64, 128, 256, 320]
|
filters = [32, 64, 128, 256, 320]
|
||||||
|
@ -30,7 +68,9 @@ class UNet3D:
|
||||||
self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
|
self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
|
||||||
self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
|
self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
|
||||||
self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
|
self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
|
||||||
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
|
self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [
|
||||||
|
UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])
|
||||||
|
]
|
||||||
self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
|
self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
|
@ -47,13 +87,17 @@ class UNet3D:
|
||||||
|
|
||||||
def load_from_pretrained(self):
|
def load_from_pretrained(self):
|
||||||
fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
|
fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
|
||||||
fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
|
fetch(
|
||||||
|
"https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1",
|
||||||
|
fn,
|
||||||
|
)
|
||||||
state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
|
state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
obj = get_child(self, k)
|
obj = get_child(self, k)
|
||||||
assert obj.shape == v.shape, (k, obj.shape, v.shape)
|
assert obj.shape == v.shape, (k, obj.shape, v.shape)
|
||||||
obj.assign(v.numpy())
|
obj.assign(v.numpy())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mdl = UNet3D()
|
mdl = UNet3D()
|
||||||
mdl.load_from_pretrained()
|
mdl.load_from_pretrained()
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue