Reformat, uh, everything, with black
parent
01503ca90d
commit
661dcc5ed0
|
@ -4,15 +4,19 @@ import pathlib
|
|||
from hexdump import hexdump
|
||||
|
||||
fxn = None
|
||||
|
||||
|
||||
def disasm(buf):
|
||||
global fxn
|
||||
if fxn is None:
|
||||
shared = pathlib.Path(__file__).parent / "disasm.so"
|
||||
if not shared.is_file():
|
||||
os.system(f'cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so')
|
||||
fxn = ctypes.CDLL(shared.as_posix())['disasm']
|
||||
#hexdump(buf)
|
||||
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
|
||||
buf = buf[0x510:] # this right?
|
||||
buf = buf.split(END)[0] + END
|
||||
fxn(buf, len(buf))
|
||||
global fxn
|
||||
if fxn is None:
|
||||
shared = pathlib.Path(__file__).parent / "disasm.so"
|
||||
if not shared.is_file():
|
||||
os.system(
|
||||
f"cd {pathlib.Path(__file__).parent} && gcc -shared disasm-a3xx.c -o disasm.so"
|
||||
)
|
||||
fxn = ctypes.CDLL(shared.as_posix())["disasm"]
|
||||
# hexdump(buf)
|
||||
END = b"\x00\x00\x00\x00\x00\x00\x00\x03"
|
||||
buf = buf[0x510:] # this right?
|
||||
buf = buf.split(END)[0] + END
|
||||
fxn(buf, len(buf))
|
||||
|
|
|
@ -23,88 +23,139 @@ from abc import ABC
|
|||
|
||||
# we will be using the clang backend
|
||||
from tinygrad import Device
|
||||
|
||||
Device.DEFAULT = "CLANG"
|
||||
|
||||
# first, 2+3 as a Tensor, the highest level
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
a = Tensor([2])
|
||||
b = Tensor([3])
|
||||
result = a + b
|
||||
print(f"{a.numpy()} + {b.numpy()} = {result.numpy()}")
|
||||
assert result.numpy()[0] == 5.
|
||||
assert result.numpy()[0] == 5.0
|
||||
|
||||
# %%
|
||||
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
|
||||
# it's worth reading tinygrad/tensor.py. it's pretty beautiful
|
||||
import tinygrad.mlops as mlops
|
||||
|
||||
|
||||
# this is the good old familiar Tensor class
|
||||
class Tensor:
|
||||
# these two are pretty straightforward
|
||||
grad: Optional[Tensor]
|
||||
requires_grad: Optional[bool]
|
||||
# these two are pretty straightforward
|
||||
grad: Optional[Tensor]
|
||||
requires_grad: Optional[bool]
|
||||
|
||||
# this is the graph for the autograd engine
|
||||
_ctx: Optional[Function]
|
||||
# this is the graph for the autograd engine
|
||||
_ctx: Optional[Function]
|
||||
|
||||
# this is where the data (and other tensor properties) actually live
|
||||
lazydata: LazyBuffer
|
||||
# this is where the data (and other tensor properties) actually live
|
||||
lazydata: LazyBuffer
|
||||
|
||||
# high level ops (hlops) are defined on this class. example: relu
|
||||
def relu(self): return self.maximum(0)
|
||||
# high level ops (hlops) are defined on this class. example: relu
|
||||
def relu(self):
|
||||
return self.maximum(0)
|
||||
|
||||
# log is an mlop, this is the wrapper function in Tensor
|
||||
def log(self):
|
||||
return mlops.Log.apply(self)
|
||||
|
||||
# log is an mlop, this is the wrapper function in Tensor
|
||||
def log(self): return mlops.Log.apply(self)
|
||||
|
||||
# all the definitions of the derivatives are subclasses of Function (like mlops.Log)
|
||||
# there's only 18 mlops for derivatives for everything (in tinygrad/mlops.py, code 9/10)
|
||||
# if you read one file, read mlops.py. if you read two files, also read tinygrad/tensor.py
|
||||
# you can differentiate the world using the chain rule
|
||||
class Function:
|
||||
# example types of forward and backward
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: pass
|
||||
def backward(self, x:LazyBuffer) -> LazyBuffer: pass
|
||||
# example types of forward and backward
|
||||
def forward(self, x: LazyBuffer) -> LazyBuffer:
|
||||
pass
|
||||
|
||||
def backward(self, x: LazyBuffer) -> LazyBuffer:
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
# == LazyBuffer (in tinygrad/lazy.py, code 5/10) ==
|
||||
from tinygrad.helpers import DType
|
||||
|
||||
|
||||
# this is where the properties live that you thought were a part of Tensor
|
||||
# LazyBuffer is like a Tensor without derivatives, at the mlop layer
|
||||
class LazyBuffer:
|
||||
# these three define the "type" of the buffer, and they are returned as Tensor properties
|
||||
device: str
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
# these three define the "type" of the buffer, and they are returned as Tensor properties
|
||||
device: str
|
||||
shape: Tuple[int, ...]
|
||||
dtype: DType
|
||||
|
||||
# a ShapeTracker is used to track things like reshapes and permutes
|
||||
# all MovementOps are zero copy in tinygrad!
|
||||
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
|
||||
# we'll come back to this later
|
||||
st: ShapeTracker
|
||||
# a ShapeTracker is used to track things like reshapes and permutes
|
||||
# all MovementOps are zero copy in tinygrad!
|
||||
# the ShapeTracker specifies how the data in the RawBuffer matches to the shape
|
||||
# we'll come back to this later
|
||||
st: ShapeTracker
|
||||
|
||||
# if the LazyBuffer is realized, it has a Buffer
|
||||
# we will come back to Buffer later
|
||||
realized: Optional[Buffer]
|
||||
# if the LazyBuffer is realized, it has a Buffer
|
||||
# we will come back to Buffer later
|
||||
realized: Optional[Buffer]
|
||||
|
||||
# if the lazybuffer is unrealized, it has a LazyOp
|
||||
# this LazyOp describes the computation needed to realize this LazyBuffer
|
||||
op: Optional[LazyOp]
|
||||
|
||||
# if the lazybuffer is unrealized, it has a LazyOp
|
||||
# this LazyOp describes the computation needed to realize this LazyBuffer
|
||||
op: Optional[LazyOp]
|
||||
|
||||
# LazyOp (in tinygrad/ops.py, code 4/10)
|
||||
# in a tree they form an Abstract Syntax Tree for a single GPU kernel
|
||||
class LazyOp:
|
||||
op: Op # the type of the compute
|
||||
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
|
||||
arg: Optional[Any] = None # and an optional static argument
|
||||
op: Op # the type of the compute
|
||||
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
|
||||
arg: Optional[Any] = None # and an optional static argument
|
||||
|
||||
|
||||
# there's currently 26 Ops you have to implement for an accelerator.
|
||||
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
|
||||
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto()
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
|
||||
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
|
||||
class UnaryOps(Enum):
|
||||
EXP2 = auto()
|
||||
LOG2 = auto()
|
||||
CAST = auto()
|
||||
SIN = auto()
|
||||
SQRT = auto()
|
||||
|
||||
|
||||
class BinaryOps(Enum):
|
||||
ADD = auto()
|
||||
SUB = auto()
|
||||
MUL = auto()
|
||||
DIV = auto()
|
||||
CMPLT = auto()
|
||||
MAX = auto()
|
||||
|
||||
|
||||
class ReduceOps(Enum):
|
||||
SUM = auto()
|
||||
MAX = auto()
|
||||
|
||||
|
||||
class MovementOps(Enum):
|
||||
RESHAPE = auto()
|
||||
PERMUTE = auto()
|
||||
EXPAND = auto()
|
||||
PAD = auto()
|
||||
SHRINK = auto()
|
||||
STRIDE = auto()
|
||||
|
||||
|
||||
class TernaryOps(Enum):
|
||||
MULACC = auto()
|
||||
WHERE = auto()
|
||||
|
||||
|
||||
class LoadOps(Enum):
|
||||
EMPTY = auto()
|
||||
CONST = auto()
|
||||
FROM = auto()
|
||||
CONTIGUOUS = auto()
|
||||
CUSTOM = auto()
|
||||
|
||||
|
||||
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
|
||||
# you do not need to implement the MovementOps
|
||||
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
|
||||
|
@ -135,14 +186,16 @@ assert len(lazyop.src) == 2
|
|||
# again, a LazyOp AST is like a GPU kernel. you have to copy the data on the device first
|
||||
assert lazyop.src[0].op.op == LoadOps.FROM
|
||||
assert lazyop.src[0].op.src[0].device == "CPU"
|
||||
assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert (
|
||||
lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2
|
||||
), "the src of the FROM LazyOP is a LazyBuffer on the CPU holding [2.]"
|
||||
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
||||
|
||||
# now we realize the LazyBuffer
|
||||
result.realize()
|
||||
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
|
||||
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
|
||||
#assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
|
||||
# assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
|
||||
# getting ahead of ourselves, but we can copy the DeviceBuffer toCPU
|
||||
assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU, it's 5"
|
||||
|
||||
|
@ -151,41 +204,58 @@ assert result.lazydata.realized.toCPU()[0] == 5, "when put in numpy with toCPU,
|
|||
|
||||
# Now you have a choice, you can either write a "Interpreted" backend or "Compiled" backend
|
||||
|
||||
|
||||
# Interpreted backends are very simple (example: CPU and TORCH)
|
||||
class Interpreted:
|
||||
# and they have a lookup table to functions for the Ops
|
||||
fxn_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.EXP2: lambda x: np.exp2(x),
|
||||
BinaryOps.ADD: lambda x,y: x+y}
|
||||
# and they have a lookup table to functions for the Ops
|
||||
fxn_for_op: Dict[Op, Callable] = {
|
||||
UnaryOps.EXP2: lambda x: np.exp2(x),
|
||||
BinaryOps.ADD: lambda x, y: x + y,
|
||||
}
|
||||
|
||||
|
||||
# Compiled backends take a little more (example: GPU and LLVM)
|
||||
class Compiled:
|
||||
# a code generator, which compiles the AST
|
||||
codegen: Type[Linearizer]
|
||||
# a code generator, which compiles the AST
|
||||
codegen: Type[Linearizer]
|
||||
|
||||
# and a runtime, which runs the generated code
|
||||
runtime: Type[Runtime]
|
||||
|
||||
# and a runtime, which runs the generated code
|
||||
runtime: Type[Runtime]
|
||||
|
||||
# Runtime is what actually runs the kernels for a compiled backend
|
||||
class Runtime(ABC):
|
||||
# `name` is the name of the function, and `prg` is the code
|
||||
# the constructor compiles the code
|
||||
def __init__(self, name:str, prg:str): pass
|
||||
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
|
||||
def __call__(self, *bufs:List[Buffer], global_size:Optional[List[int]], local_size:Optional[List[int]]): pass
|
||||
# `name` is the name of the function, and `prg` is the code
|
||||
# the constructor compiles the code
|
||||
def __init__(self, name: str, prg: str):
|
||||
pass
|
||||
|
||||
# call runs the code on the bufs. NOTE: the output is always bufs[0], but this is just a convention
|
||||
def __call__(
|
||||
self,
|
||||
*bufs: List[Buffer],
|
||||
global_size: Optional[List[int]],
|
||||
local_size: Optional[List[int]],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# %%
|
||||
# == Buffer (in tinygrad/device.py, code 6/10) ==
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Buffer is where the data is actually held. it's pretty close to just memory
|
||||
class Buffer(ABC):
|
||||
# create an empty rawbuffer that holds `size` elements of type `dtype`
|
||||
# `opaque` is an opaque container class
|
||||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None): pass
|
||||
# create an empty rawbuffer that holds `size` elements of type `dtype`
|
||||
# `opaque` is an opaque container class
|
||||
def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None):
|
||||
pass
|
||||
|
||||
# toCPU converts the RawBuffer to a numpy array with shape (size,)
|
||||
def toCPU(self) -> np.ndarray:
|
||||
pass
|
||||
|
||||
# toCPU converts the RawBuffer to a numpy array with shape (size,)
|
||||
def toCPU(self) -> np.ndarray: pass
|
||||
|
||||
# %%
|
||||
# == Example: 2+3 in raw clang ==
|
||||
|
@ -205,6 +275,7 @@ from tinygrad.runtime.ops_clang import ClangProgram, compile_clang
|
|||
# then we copy the numpy in to RawMallocBuffers
|
||||
# last, we create an empty output buffer
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
input_a, input_b = MallocAllocator.alloc(4), MallocAllocator.alloc(4)
|
||||
output = MallocAllocator.alloc(4)
|
||||
|
||||
|
@ -214,12 +285,14 @@ MallocAllocator.copyin(input_a, numpy_a.data.cast("B"))
|
|||
MallocAllocator.copyin(input_b, numpy_b.data.cast("B"))
|
||||
|
||||
# compile the program, run it, and 2+3 does indeed equal 5
|
||||
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
|
||||
program = ClangProgram(
|
||||
"add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
|
||||
)
|
||||
program(output, input_a, input_b)
|
||||
numpy_out = np.empty(1, dtype=np.float32)
|
||||
MallocAllocator.copyout(numpy_out.data.cast("B"), output)
|
||||
assert numpy_out[0] == 5, "it's still 5"
|
||||
np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
|
||||
np.testing.assert_allclose(numpy_out, numpy_a + numpy_b)
|
||||
|
||||
# %%
|
||||
# == Linearizer (in tinygrad/codegen/linearizer.py, code 4/10) ==
|
||||
|
@ -229,35 +302,52 @@ np.testing.assert_allclose(numpy_out, numpy_a+numpy_b)
|
|||
# the first step of transforming an AST into code is to "linearize" it, think like toposort on the AST
|
||||
# for that, we use the Linearizer, which turns an AST into a list of (linear) UOps
|
||||
|
||||
class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto();
|
||||
|
||||
class UOps(Enum):
|
||||
LOOP = auto()
|
||||
DEFINE_LOCAL = auto()
|
||||
LOAD = auto()
|
||||
ALU = auto()
|
||||
CONST = auto()
|
||||
ENDLOOP = auto()
|
||||
STORE = auto()
|
||||
|
||||
|
||||
class UOp:
|
||||
uop: UOps
|
||||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
num: int # UOps are unique
|
||||
uop: UOps
|
||||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
num: int # UOps are unique
|
||||
|
||||
|
||||
class Linearizer:
|
||||
# create the kernel with the AST
|
||||
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
|
||||
def __init__(self, ast:LazyOp): pass
|
||||
def linearize(self): pass
|
||||
# create the kernel with the AST
|
||||
# NOTE: the AST contains the CompiledBuffers themselves as the root nodes. this will change
|
||||
def __init__(self, ast: LazyOp):
|
||||
pass
|
||||
|
||||
def linearize(self):
|
||||
pass
|
||||
|
||||
# when linearize is run, it fills in this list
|
||||
uops: List[UOp]
|
||||
|
||||
# when linearize is run, it fills in this list
|
||||
uops: List[UOp]
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
result = Tensor(2).realize() + Tensor(3).realize()
|
||||
|
||||
# use the real Linearizer to linearize 2+3
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
||||
sched = result.lazydata.schedule()
|
||||
linearizer = Linearizer(sched[-1].ast)
|
||||
linearizer.linearize()
|
||||
|
||||
# print the uops
|
||||
for uop in linearizer.uops: print(uop)
|
||||
for uop in linearizer.uops:
|
||||
print(uop)
|
||||
|
||||
# output:
|
||||
"""
|
||||
|
@ -275,13 +365,15 @@ for uop in linearizer.uops: print(uop)
|
|||
# here, we have an example where we fetch the generated code from the JIT
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
result = Tensor(2) + Tensor(3)
|
||||
|
||||
# we have a global cache used by the JIT
|
||||
# from there, we can see the generated clang code
|
||||
from tinygrad.jit import CacheCollector
|
||||
CacheCollector.start() # enables the cache
|
||||
result.realize() # create the program and runs it
|
||||
|
||||
CacheCollector.start() # enables the cache
|
||||
result.realize() # create the program and runs it
|
||||
cache_saved = CacheCollector.finish() # disable the cache
|
||||
|
||||
# there's one ASTRunner in the cache
|
||||
|
@ -310,22 +402,24 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
a = ShapeTracker.from_shape((10, 10))
|
||||
|
||||
# you'll see it has one view. the (10, 1 are the strides)
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
|
||||
# we can permute it, and the strides change
|
||||
a = a.permute((1,0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
a = a.permute((1, 0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
|
||||
# we can then reshape it, and the strides change again
|
||||
# note how the permute stays applied
|
||||
a = a.reshape((5,2,5,2))
|
||||
print(a) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||
a = a.reshape((5, 2, 5, 2))
|
||||
print(
|
||||
a
|
||||
) # ShapeTracker(shape=(5, 2, 5, 2), views=[View((5, 2, 5, 2), (2, 1, 20, 10), 0)])
|
||||
|
||||
# now, if we were to reshape it to a (100,) shape tensor, we have to create a second view
|
||||
a = a.reshape((100,))
|
||||
print(a) # ShapeTracker(shape=(100,), views=[
|
||||
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
|
||||
# View((100,), (1,), 0)])
|
||||
print(a) # ShapeTracker(shape=(100,), views=[
|
||||
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
|
||||
# View((100,), (1,), 0)])
|
||||
|
||||
# Views stack on top of each other, to allow zero copy for any number of MovementOps
|
||||
# we can render a Python expression for the index at any time
|
||||
|
@ -333,22 +427,22 @@ idx, _ = a.expr_idxs()
|
|||
print(idx.render()) # (((idx0%10)*10)+(idx0//10))
|
||||
|
||||
# of course, if we reshape it back, the indexes get simple again
|
||||
a = a.reshape((10,10))
|
||||
a = a.reshape((10, 10))
|
||||
idx, _ = a.expr_idxs()
|
||||
print(idx.render()) # ((idx1*10)+idx0)
|
||||
|
||||
# the ShapeTracker still has two views though...
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[
|
||||
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
|
||||
# View((10, 10), (10, 1), 0)])
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[
|
||||
# View((5, 2, 5, 2), (2, 1, 20, 10), 0),
|
||||
# View((10, 10), (10, 1), 0)])
|
||||
|
||||
# ...until we simplify it!
|
||||
a = a.simplify()
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (1, 10), 0)])
|
||||
|
||||
# and now we permute it back
|
||||
a = a.permute((1,0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
a = a.permute((1, 0))
|
||||
print(a) # ShapeTracker(shape=(10, 10), views=[View((10, 10), (10, 1), 0)])
|
||||
|
||||
# and it's even contiguous
|
||||
assert a.contiguous == True
|
||||
|
@ -365,17 +459,17 @@ a = Variable("a", 0, 10)
|
|||
b = Variable("b", 0, 10)
|
||||
|
||||
# some math examples
|
||||
print((a*10).min, (a*10).max) # you'll see a*10 has a min of 0 and max of 100
|
||||
print((a+b).min, (a+b).max) # 0 20, you get the idea
|
||||
print((a * 10).min, (a * 10).max) # you'll see a*10 has a min of 0 and max of 100
|
||||
print((a + b).min, (a + b).max) # 0 20, you get the idea
|
||||
|
||||
# but complex expressions are where it gets fun
|
||||
expr = (a + b*10) % 10
|
||||
print(expr.render()) # (a%10)
|
||||
expr = (a + b * 10) % 10
|
||||
print(expr.render()) # (a%10)
|
||||
# as you can see, b is gone!
|
||||
|
||||
# one more
|
||||
expr = (a*40 + b) // 20
|
||||
print(expr.render()) # (a*2)
|
||||
expr = (a * 40 + b) // 20
|
||||
print(expr.render()) # (a*2)
|
||||
print(expr.min, expr.max) # 0 20
|
||||
# this is just "(a*2)"
|
||||
# since b only has a range from 0-10, it can't affect the output
|
||||
|
|
|
@ -15,8 +15,8 @@ a = MallocAllocator.alloc(4)
|
|||
b = MallocAllocator.alloc(4)
|
||||
|
||||
# load in some values (little endian)
|
||||
MallocAllocator.copyin(a, bytearray([2,0,0,0]))
|
||||
MallocAllocator.copyin(b, bytearray([3,0,0,0]))
|
||||
MallocAllocator.copyin(a, bytearray([2, 0, 0, 0]))
|
||||
MallocAllocator.copyin(b, bytearray([3, 0, 0, 0]))
|
||||
|
||||
# compile a program to a binary
|
||||
lib = compile_clang("void add(int *out, int *a, int *b) { out[0] = a[0] + b[0]; }")
|
||||
|
@ -34,7 +34,7 @@ assert val == 5
|
|||
|
||||
print("******** second, the Device ***********")
|
||||
|
||||
DEVICE = "CLANG" # NOTE: you can change this!
|
||||
DEVICE = "CLANG" # NOTE: you can change this!
|
||||
|
||||
import struct
|
||||
from tinygrad.helpers import dtypes
|
||||
|
@ -49,14 +49,21 @@ b = Buffer(DEVICE, 1, dtypes.int32).copyin(memoryview(bytearray(struct.pack("I",
|
|||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||
|
||||
# describe the computation
|
||||
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
ld_1 = LazyOp(
|
||||
BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
ld_2 = LazyOp(
|
||||
BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
||||
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
st_0 = LazyOp(
|
||||
BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,)))
|
||||
)
|
||||
|
||||
# convert the computation to a "linearized" format (print the format)
|
||||
lin = Device[DEVICE].get_linearizer(st_0).linearize()
|
||||
for u in lin.uops: print(u)
|
||||
for u in lin.uops:
|
||||
print(u)
|
||||
|
||||
# compile a program (and print the source)
|
||||
fxn = Device[DEVICE].to_program(lin)
|
||||
|
@ -67,7 +74,7 @@ print(fxn.prg)
|
|||
fxn.exec([out, a, b])
|
||||
|
||||
# check the data out
|
||||
print(val:=out.toCPU().item())
|
||||
print(val := out.toCPU().item())
|
||||
assert val == 5
|
||||
|
||||
|
||||
|
@ -79,6 +86,7 @@ from tinygrad.realize import run_schedule
|
|||
# allocate some values + load in values
|
||||
# TODO: remove numpy here
|
||||
import numpy as np
|
||||
|
||||
a = LazyBuffer.fromCPU(np.array([2], np.int32)).copy_to_device(DEVICE)
|
||||
b = LazyBuffer.fromCPU(np.array([3], np.int32)).copy_to_device(DEVICE)
|
||||
|
||||
|
@ -87,10 +95,12 @@ out = a.e(BinaryOps.ADD, b)
|
|||
|
||||
# schedule the computation as a list of kernels
|
||||
sched = out.schedule()
|
||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
for si in sched:
|
||||
print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
|
||||
# DEBUGGING: print the compute ast as a tree
|
||||
from tinygrad.graph import print_tree
|
||||
|
||||
print_tree(sched[-1].ast)
|
||||
# NOTE: sched[-1].ast is the same as st_0 above
|
||||
|
||||
|
@ -98,7 +108,7 @@ print_tree(sched[-1].ast)
|
|||
run_schedule(sched)
|
||||
|
||||
# check the data out
|
||||
print(val:=out.realized.toCPU().item())
|
||||
print(val := out.realized.toCPU().item())
|
||||
assert val == 5
|
||||
|
||||
|
||||
|
@ -111,5 +121,5 @@ b = Tensor([3], dtype=dtypes.int32, device=DEVICE)
|
|||
out = a + b
|
||||
|
||||
# check the data out
|
||||
print(val:=out.item())
|
||||
print(val := out.item())
|
||||
assert val == 5
|
||||
|
|
|
@ -1,114 +1,135 @@
|
|||
from typing import Tuple
|
||||
import time
|
||||
from tinygrad import Tensor, TinyJit, nn, Variable
|
||||
from tinygrad.helpers import dtypes # TODO: wouldn't need this if argmax returned the right dtype
|
||||
from tinygrad.helpers import (
|
||||
dtypes,
|
||||
) # TODO: wouldn't need this if argmax returned the right dtype
|
||||
import gymnasium as gym
|
||||
from tqdm import trange
|
||||
import numpy as np # TODO: remove numpy import
|
||||
|
||||
|
||||
class ActorCritic:
|
||||
def __init__(self, in_features, out_features, hidden_state=32):
|
||||
self.l1 = nn.Linear(in_features, hidden_state)
|
||||
self.l2 = nn.Linear(hidden_state, out_features)
|
||||
def __init__(self, in_features, out_features, hidden_state=32):
|
||||
self.l1 = nn.Linear(in_features, hidden_state)
|
||||
self.l2 = nn.Linear(hidden_state, out_features)
|
||||
|
||||
self.c1 = nn.Linear(in_features, hidden_state)
|
||||
self.c2 = nn.Linear(hidden_state, 1)
|
||||
self.c1 = nn.Linear(in_features, hidden_state)
|
||||
self.c2 = nn.Linear(hidden_state, 1)
|
||||
|
||||
def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
|
||||
x = self.l1(obs).tanh()
|
||||
act = self.l2(x).log_softmax()
|
||||
x = self.c1(obs).relu()
|
||||
return act, self.c2(x)
|
||||
def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
x = self.l1(obs).tanh()
|
||||
act = self.l2(x).log_softmax()
|
||||
x = self.c1(obs).relu()
|
||||
return act, self.c2(x)
|
||||
|
||||
|
||||
def evaluate(model: ActorCritic, test_env: gym.Env) -> float:
|
||||
(obs, _), terminated, truncated = test_env.reset(), False, False
|
||||
total_rew = 0.0
|
||||
while not terminated and not truncated:
|
||||
act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item()
|
||||
obs, rew, terminated, truncated, _ = test_env.step(act)
|
||||
total_rew += float(rew)
|
||||
return total_rew
|
||||
|
||||
def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
|
||||
(obs, _), terminated, truncated = test_env.reset(), False, False
|
||||
total_rew = 0.0
|
||||
while not terminated and not truncated:
|
||||
act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item()
|
||||
obs, rew, terminated, truncated, _ = test_env.step(act)
|
||||
total_rew += float(rew)
|
||||
return total_rew
|
||||
|
||||
# TODO: time should be < 5s on M1 Max
|
||||
if __name__ == "__main__":
|
||||
env = gym.make('CartPole-v1')
|
||||
env = gym.make("CartPole-v1")
|
||||
|
||||
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
|
||||
model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n)) # type: ignore
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)
|
||||
|
||||
@TinyJit
|
||||
def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
with Tensor.train():
|
||||
log_dist, value = model(x)
|
||||
@TinyJit
|
||||
def train_step(
|
||||
x: Tensor, selected_action: Tensor, reward: Tensor, old_log_dist: Tensor
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
with Tensor.train():
|
||||
log_dist, value = model(x)
|
||||
|
||||
# get advantage
|
||||
advantage = reward.reshape(-1, 1) - value
|
||||
mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
|
||||
masked_advantage = mask * advantage.detach()
|
||||
# get advantage
|
||||
advantage = reward.reshape(-1, 1) - value
|
||||
mask = selected_action.reshape(-1, 1) == Tensor.arange(
|
||||
log_dist.shape[1]
|
||||
).reshape(1, -1).expand(selected_action.shape[0], -1)
|
||||
masked_advantage = mask * advantage.detach()
|
||||
|
||||
# PPO
|
||||
ratios = (log_dist - old_log_dist).exp() * masked_advantage
|
||||
clipped_ratios = ratios.clip(1-0.2, 1+0.2) * masked_advantage
|
||||
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
|
||||
# PPO
|
||||
ratios = (log_dist - old_log_dist).exp() * masked_advantage
|
||||
clipped_ratios = ratios.clip(1 - 0.2, 1 + 0.2) * masked_advantage
|
||||
action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()
|
||||
|
||||
entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean() # this encourages diversity
|
||||
critic_loss = advantage.square().mean()
|
||||
opt.zero_grad()
|
||||
(action_loss + entropy_loss*0.0005 + critic_loss).backward()
|
||||
opt.step()
|
||||
return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()
|
||||
entropy_loss = (
|
||||
(log_dist.exp() * log_dist).sum(-1).mean()
|
||||
) # this encourages diversity
|
||||
critic_loss = advantage.square().mean()
|
||||
opt.zero_grad()
|
||||
(action_loss + entropy_loss * 0.0005 + critic_loss).backward()
|
||||
opt.step()
|
||||
return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()
|
||||
|
||||
@TinyJit
|
||||
def get_action_dist(obs:Tensor) -> Tensor:
|
||||
# TODO: with no_grad
|
||||
Tensor.no_grad = True
|
||||
ret = model(obs)[0].exp().realize()
|
||||
Tensor.no_grad = False
|
||||
return ret
|
||||
@TinyJit
|
||||
def get_action_dist(obs: Tensor) -> Tensor:
|
||||
# TODO: with no_grad
|
||||
Tensor.no_grad = True
|
||||
ret = model(obs)[0].exp().realize()
|
||||
Tensor.no_grad = False
|
||||
return ret
|
||||
|
||||
BS = 256
|
||||
MAX_REPLAY_BUFFER = 2000
|
||||
st, steps = time.perf_counter(), 0
|
||||
Xn, An, Rn = [], [], []
|
||||
for i in (t:=trange(40)):
|
||||
get_action_dist.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through
|
||||
BS = 256
|
||||
MAX_REPLAY_BUFFER = 2000
|
||||
st, steps = time.perf_counter(), 0
|
||||
Xn, An, Rn = [], [], []
|
||||
for i in (t := trange(40)):
|
||||
get_action_dist.reset() # NOTE: if you don't reset the jit here it captures the wrong model on the first run through
|
||||
|
||||
obs:np.ndarray = env.reset()[0]
|
||||
rews, terminated, truncated = [], False, False
|
||||
# NOTE: we don't want to early stop since then the rewards are wrong for the last episode
|
||||
while not terminated and not truncated:
|
||||
# pick actions
|
||||
# TODO: move the multinomial into jitted tinygrad when JIT rand works
|
||||
# TODO: what's the temperature here?
|
||||
act = get_action_dist(Tensor(obs)).multinomial().item()
|
||||
obs: np.ndarray = env.reset()[0]
|
||||
rews, terminated, truncated = [], False, False
|
||||
# NOTE: we don't want to early stop since then the rewards are wrong for the last episode
|
||||
while not terminated and not truncated:
|
||||
# pick actions
|
||||
# TODO: move the multinomial into jitted tinygrad when JIT rand works
|
||||
# TODO: what's the temperature here?
|
||||
act = get_action_dist(Tensor(obs)).multinomial().item()
|
||||
|
||||
# save this state action pair
|
||||
# TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
|
||||
Xn.append(np.copy(obs))
|
||||
An.append(act)
|
||||
# save this state action pair
|
||||
# TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
|
||||
Xn.append(np.copy(obs))
|
||||
An.append(act)
|
||||
|
||||
obs, rew, terminated, truncated, _ = env.step(act)
|
||||
rews.append(float(rew))
|
||||
steps += len(rews)
|
||||
obs, rew, terminated, truncated, _ = env.step(act)
|
||||
rews.append(float(rew))
|
||||
steps += len(rews)
|
||||
|
||||
# reward to go
|
||||
# TODO: move this into tinygrad
|
||||
discounts = np.power(0.99, np.arange(len(rews)))
|
||||
Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]
|
||||
# reward to go
|
||||
# TODO: move this into tinygrad
|
||||
discounts = np.power(0.99, np.arange(len(rews)))
|
||||
Rn += [np.sum(rews[i:] * discounts[: len(rews) - i]) for i in range(len(rews))]
|
||||
|
||||
Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
|
||||
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
|
||||
Xn, An, Rn = (
|
||||
Xn[-MAX_REPLAY_BUFFER:],
|
||||
An[-MAX_REPLAY_BUFFER:],
|
||||
Rn[-MAX_REPLAY_BUFFER:],
|
||||
)
|
||||
X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)
|
||||
|
||||
# TODO: make this work
|
||||
#vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn))
|
||||
#X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)
|
||||
# TODO: make this work
|
||||
# vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn))
|
||||
# X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)
|
||||
|
||||
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
|
||||
for i in range(5):
|
||||
samples = Tensor.randint(BS, high=X.shape[0]).realize() # TODO: remove the need for this
|
||||
# TODO: is this recompiling based on the shape?
|
||||
action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
|
||||
t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")
|
||||
old_log_dist = model(X)[0] # TODO: could save these instead of recomputing
|
||||
for i in range(5):
|
||||
samples = Tensor.randint(
|
||||
BS, high=X.shape[0]
|
||||
).realize() # TODO: remove the need for this
|
||||
# TODO: is this recompiling based on the shape?
|
||||
action_loss, entropy_loss, critic_loss = train_step(
|
||||
X[samples], A[samples], R[samples], old_log_dist[samples]
|
||||
)
|
||||
t.set_description(
|
||||
f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}"
|
||||
)
|
||||
|
||||
test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
|
||||
print(f"test reward: {test_rew}")
|
||||
test_rew = evaluate(model, gym.make("CartPole-v1", render_mode="human"))
|
||||
print(f"test reward: {test_rew}")
|
||||
|
|
|
@ -4,42 +4,61 @@ from tinygrad import Tensor, TinyJit, nn, GlobalCounters
|
|||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm2d(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm2d(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5),
|
||||
Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5),
|
||||
Tensor.relu,
|
||||
nn.BatchNorm2d(32),
|
||||
Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3),
|
||||
Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3),
|
||||
Tensor.relu,
|
||||
nn.BatchNorm2d(64),
|
||||
Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1),
|
||||
nn.Linear(576, 10),
|
||||
]
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return x.sequential(self.layers)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
|
||||
|
||||
model = Model()
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
model = Model()
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||
|
||||
# TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int
|
||||
@TinyJit
|
||||
def train_step(samples:Tensor) -> Tensor:
|
||||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss.realize()
|
||||
# TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int
|
||||
@TinyJit
|
||||
def train_step(samples: Tensor) -> Tensor:
|
||||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = (
|
||||
model(X_train[samples])
|
||||
.sparse_categorical_crossentropy(Y_train[samples])
|
||||
.backward()
|
||||
)
|
||||
opt.step()
|
||||
return loss.realize()
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor:
|
||||
return ((model(X_test).argmax(axis=1) == Y_test).mean() * 100).realize()
|
||||
|
||||
test_acc = float('nan')
|
||||
for i in (t:=trange(70)):
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
samples = Tensor.randint(512, high=X_train.shape[0]) # TODO: put this in the JIT when rand is fixed
|
||||
loss = train_step(samples)
|
||||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
test_acc = float("nan")
|
||||
for i in (t := trange(70)):
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
samples = Tensor.randint(
|
||||
512, high=X_train.shape[0]
|
||||
) # TODO: put this in the JIT when rand is fixed
|
||||
loss = train_step(samples)
|
||||
if i % 10 == 9:
|
||||
test_acc = get_test_acc().item()
|
||||
t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")
|
||||
|
|
|
@ -10,8 +10,10 @@ from tinygrad.helpers import GlobalCounters
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
|
||||
def tensors_allocated():
|
||||
return sum(isinstance(x, Tensor) for x in gc.get_objects())
|
||||
return sum(isinstance(x, Tensor) for x in gc.get_objects())
|
||||
|
||||
|
||||
NUM = getenv("NUM", 2)
|
||||
BS = getenv("BS", 8)
|
||||
|
@ -22,46 +24,53 @@ ADAM = getenv("ADAM", 0)
|
|||
CLCACHE = getenv("CLCACHE", 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
||||
parameters = get_parameters(model)
|
||||
for p in parameters: p.realize()
|
||||
if ADAM: optimizer = optim.Adam(parameters, lr=0.001)
|
||||
else: optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
Tensor.training = TRAINING
|
||||
Tensor.no_grad = not BACKWARD
|
||||
for i in trange(CNT):
|
||||
GlobalCounters.reset()
|
||||
cpy = time.monotonic()
|
||||
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
|
||||
y_train = Tensor.randn(BS, 1000, requires_grad=False).realize()
|
||||
|
||||
# TODO: replace with TinyJit
|
||||
if i < 3 or not CLCACHE:
|
||||
st = time.monotonic()
|
||||
out = model.forward(x_train)
|
||||
loss = out.log_softmax().mul(y_train).mean()
|
||||
if i == 2 and CLCACHE: CacheCollector.start()
|
||||
if BACKWARD:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
mt = time.monotonic()
|
||||
loss.realize()
|
||||
for p in parameters:
|
||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False)
|
||||
parameters = get_parameters(model)
|
||||
for p in parameters:
|
||||
p.realize()
|
||||
et = time.monotonic()
|
||||
if ADAM:
|
||||
optimizer = optim.Adam(parameters, lr=0.001)
|
||||
else:
|
||||
st = mt = time.monotonic()
|
||||
for prg, args in cl_cache: prg(*args)
|
||||
et = time.monotonic()
|
||||
optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
if i == 2 and CLCACHE:
|
||||
cl_cache = CacheCollector.finish()
|
||||
Tensor.training = TRAINING
|
||||
Tensor.no_grad = not BACKWARD
|
||||
for i in trange(CNT):
|
||||
GlobalCounters.reset()
|
||||
cpy = time.monotonic()
|
||||
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
|
||||
y_train = Tensor.randn(BS, 1000, requires_grad=False).realize()
|
||||
|
||||
mem_used = GlobalCounters.mem_used
|
||||
loss_cpu = loss.detach().numpy()
|
||||
cl = time.monotonic()
|
||||
# TODO: replace with TinyJit
|
||||
if i < 3 or not CLCACHE:
|
||||
st = time.monotonic()
|
||||
out = model.forward(x_train)
|
||||
loss = out.log_softmax().mul(y_train).mean()
|
||||
if i == 2 and CLCACHE:
|
||||
CacheCollector.start()
|
||||
if BACKWARD:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
mt = time.monotonic()
|
||||
loss.realize()
|
||||
for p in parameters:
|
||||
p.realize()
|
||||
et = time.monotonic()
|
||||
else:
|
||||
st = mt = time.monotonic()
|
||||
for prg, args in cl_cache:
|
||||
prg(*args)
|
||||
et = time.monotonic()
|
||||
|
||||
print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
if i == 2 and CLCACHE:
|
||||
cl_cache = CacheCollector.finish()
|
||||
|
||||
mem_used = GlobalCounters.mem_used
|
||||
loss_cpu = loss.detach().numpy()
|
||||
cl = time.monotonic()
|
||||
|
||||
print(
|
||||
f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
import os, sys, traceback
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from io import StringIO
|
||||
|
@ -9,99 +10,148 @@ from tinygrad.helpers import Timing, colored, getenv, fetch
|
|||
from extra.models.llama import Transformer, convert_from_huggingface
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
||||
def create_fixed_tokenizer(output_file):
|
||||
print("creating fixed tokenizer")
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true").read_bytes())
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(mp.SerializeToString())
|
||||
print("creating fixed tokenizer")
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true"
|
||||
).read_bytes()
|
||||
)
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(mp.SerializeToString())
|
||||
|
||||
|
||||
# TODO: make loading bf16 fast so we can remove this
|
||||
def create_model_cache(output_file, model):
|
||||
print(f"creating model cache at {output_file}")
|
||||
# TODO: add read only Tensors
|
||||
with Timing("download weights: "):
|
||||
part1 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"))
|
||||
part2 = nn.state.torch_load(fetch("https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"))
|
||||
print(f"creating model cache at {output_file}")
|
||||
# TODO: add read only Tensors
|
||||
with Timing("download weights: "):
|
||||
part1 = nn.state.torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"
|
||||
)
|
||||
)
|
||||
part2 = nn.state.torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"
|
||||
)
|
||||
)
|
||||
|
||||
with Timing("weights -> model: "):
|
||||
nn.state.load_state_dict(model, convert_from_huggingface(part1, model, 32, 8), strict=False)
|
||||
nn.state.load_state_dict(model, convert_from_huggingface(part2, model, 32, 8), strict=False)
|
||||
with Timing("weights -> model: "):
|
||||
nn.state.load_state_dict(
|
||||
model, convert_from_huggingface(part1, model, 32, 8), strict=False
|
||||
)
|
||||
nn.state.load_state_dict(
|
||||
model, convert_from_huggingface(part2, model, 32, 8), strict=False
|
||||
)
|
||||
|
||||
with Timing("saving float16 cache: "):
|
||||
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
|
||||
with Timing("saving float16 cache: "):
|
||||
nn.state.safe_save(nn.state.get_state_dict(model), output_file)
|
||||
|
||||
print("cache created, rerun to use")
|
||||
exit(0)
|
||||
|
||||
print("cache created, rerun to use")
|
||||
exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
Tensor.no_grad = True
|
||||
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
||||
with Timing("create model: "):
|
||||
model = Transformer(4096, 14336, n_heads=32, n_layers=32, norm_eps=1e-5, vocab_size=32002, n_kv_heads=8, max_context=4096)
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
|
||||
with Timing("create model: "):
|
||||
model = Transformer(
|
||||
4096,
|
||||
14336,
|
||||
n_heads=32,
|
||||
n_layers=32,
|
||||
norm_eps=1e-5,
|
||||
vocab_size=32002,
|
||||
n_kv_heads=8,
|
||||
max_context=4096,
|
||||
)
|
||||
|
||||
cached_model = "/tmp/cached_openhermes.safetensors"
|
||||
if not os.path.isfile(cached_model): create_model_cache(cached_model, model)
|
||||
with Timing("loading float16 cache: "):
|
||||
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
|
||||
cached_model = "/tmp/cached_openhermes.safetensors"
|
||||
if not os.path.isfile(cached_model):
|
||||
create_model_cache(cached_model, model)
|
||||
with Timing("loading float16 cache: "):
|
||||
nn.state.load_state_dict(model, nn.state.safe_load(cached_model))
|
||||
|
||||
if not os.path.isfile("/tmp/tokenizer.model"): create_fixed_tokenizer("/tmp/tokenizer.model")
|
||||
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
|
||||
if not os.path.isfile("/tmp/tokenizer.model"):
|
||||
create_fixed_tokenizer("/tmp/tokenizer.model")
|
||||
spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
|
||||
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
||||
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
IM_END = 32000
|
||||
IM_START = 32001
|
||||
def encode_prompt(k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
||||
def start_prompt(k): return [IM_START]+spp.encode(f"{k}\n")
|
||||
def output(outputted, toks, color):
|
||||
cur = spp.decode(toks)[len(outputted):]
|
||||
sys.stdout.write(colored(cur, color))
|
||||
sys.stdout.flush()
|
||||
outputted += cur
|
||||
return outputted
|
||||
# https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
|
||||
# "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
IM_END = 32000
|
||||
IM_START = 32001
|
||||
|
||||
# *** app below this line ***
|
||||
def encode_prompt(k, v):
|
||||
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||
|
||||
toks = [spp.bos_id()] + encode_prompt("system", "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input")
|
||||
def start_prompt(k):
|
||||
return [IM_START] + spp.encode(f"{k}\n")
|
||||
|
||||
PROMPT = getenv("PROMPT", 1)
|
||||
temperature = getenv("TEMP", 0.7)
|
||||
def output(outputted, toks, color):
|
||||
cur = spp.decode(toks)[len(outputted) :]
|
||||
sys.stdout.write(colored(cur, color))
|
||||
sys.stdout.flush()
|
||||
outputted += cur
|
||||
return outputted
|
||||
|
||||
start_pos = 0
|
||||
outputted = output("", toks, "green")
|
||||
turn = True
|
||||
while 1:
|
||||
if PROMPT:
|
||||
toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
|
||||
else:
|
||||
toks += start_prompt("user" if turn else "assistant")
|
||||
turn = not turn
|
||||
old_output_len = len(outputted)
|
||||
# *** app below this line ***
|
||||
|
||||
toks = [spp.bos_id()] + encode_prompt(
|
||||
"system",
|
||||
"You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input",
|
||||
)
|
||||
|
||||
PROMPT = getenv("PROMPT", 1)
|
||||
temperature = getenv("TEMP", 0.7)
|
||||
|
||||
start_pos = 0
|
||||
outputted = output("", toks, "green")
|
||||
turn = True
|
||||
while 1:
|
||||
tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).multinomial().item()
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
outputted = output(outputted, toks, "blue" if not turn else "cyan")
|
||||
if tok == IM_END: break
|
||||
if tok == spp.eos_id(): break
|
||||
new_output = outputted[old_output_len:]
|
||||
if PROMPT:
|
||||
toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
|
||||
else:
|
||||
toks += start_prompt("user" if turn else "assistant")
|
||||
turn = not turn
|
||||
old_output_len = len(outputted)
|
||||
while 1:
|
||||
tok = (
|
||||
model(Tensor([toks[start_pos:]]), start_pos, temperature)
|
||||
.multinomial()
|
||||
.item()
|
||||
)
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
outputted = output(outputted, toks, "blue" if not turn else "cyan")
|
||||
if tok == IM_END:
|
||||
break
|
||||
if tok == spp.eos_id():
|
||||
break
|
||||
new_output = outputted[old_output_len:]
|
||||
|
||||
if new_output.endswith("```") and '```python\n' in new_output:
|
||||
python_code = new_output.split('```python\n')[1].split("```")[0]
|
||||
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
|
||||
if input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower() == 'y':
|
||||
my_stdout = StringIO()
|
||||
try:
|
||||
with redirect_stdout(my_stdout): exec(python_code)
|
||||
result = my_stdout.getvalue()
|
||||
except Exception as e:
|
||||
result = ''.join(traceback.format_exception_only(e))
|
||||
toks += spp.encode(f"\nOutput:\n```\n{result}```")
|
||||
outputted = output(outputted, toks, "yellow")
|
||||
old_output_len = len(outputted)
|
||||
print("")
|
||||
if new_output.endswith("```") and "```python\n" in new_output:
|
||||
python_code = new_output.split("```python\n")[1].split("```")[0]
|
||||
# AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
|
||||
if (
|
||||
input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower()
|
||||
== "y"
|
||||
):
|
||||
my_stdout = StringIO()
|
||||
try:
|
||||
with redirect_stdout(my_stdout):
|
||||
exec(python_code)
|
||||
result = my_stdout.getvalue()
|
||||
except Exception as e:
|
||||
result = "".join(traceback.format_exception_only(e))
|
||||
toks += spp.encode(f"\nOutput:\n```\n{result}```")
|
||||
outputted = output(outputted, toks, "yellow")
|
||||
old_output_len = len(outputted)
|
||||
print("")
|
||||
|
|
|
@ -7,32 +7,54 @@ from tinygrad.helpers import getenv, fetch
|
|||
import ast
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = EfficientNet(0)
|
||||
model.load_from_pretrained()
|
||||
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
|
||||
dirname = Path(__file__).parent
|
||||
if getenv("CLANG", "") == "":
|
||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||
ext = "js" if getenv("WEBGPU", "") != "" else "json"
|
||||
with open(dirname / f"net.{ext}", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
else:
|
||||
cprog = [prg]
|
||||
# image library!
|
||||
cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
|
||||
model = EfficientNet(0)
|
||||
model.load_from_pretrained()
|
||||
mode = (
|
||||
"clang"
|
||||
if getenv("CLANG", "") != ""
|
||||
else "webgpu"
|
||||
if getenv("WEBGPU", "") != ""
|
||||
else ""
|
||||
)
|
||||
prg, inp_sizes, out_sizes, state = export_model(
|
||||
model, mode, Tensor.randn(1, 3, 224, 224)
|
||||
)
|
||||
dirname = Path(__file__).parent
|
||||
if getenv("CLANG", "") == "":
|
||||
safe_save(state, (dirname / "net.safetensors").as_posix())
|
||||
ext = "js" if getenv("WEBGPU", "") != "" else "json"
|
||||
with open(dirname / f"net.{ext}", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
else:
|
||||
cprog = [prg]
|
||||
# image library!
|
||||
cprog += [
|
||||
"#define STB_IMAGE_IMPLEMENTATION",
|
||||
fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h")
|
||||
.read_text()
|
||||
.replace("half", "_half"),
|
||||
]
|
||||
|
||||
# imagenet labels, move to datasets?
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
lbls = ['"'+lbls[i]+'"' for i in range(1000)]
|
||||
inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
|
||||
outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
|
||||
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
|
||||
cprog.append(inputs)
|
||||
cprog.append(outputs)
|
||||
# imagenet labels, move to datasets?
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
lbls = ['"' + lbls[i] + '"' for i in range(1000)]
|
||||
inputs = "\n".join(
|
||||
[f"float {inp}[{inp_size}];" for inp, inp_size in inp_sizes.items()]
|
||||
)
|
||||
outputs = "\n".join(
|
||||
[f"float {out}[{out_size}];" for out, out_size in out_sizes.items()]
|
||||
)
|
||||
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
|
||||
cprog.append(inputs)
|
||||
cprog.append(outputs)
|
||||
|
||||
# buffers (empty + weights)
|
||||
cprog.append("""
|
||||
# buffers (empty + weights)
|
||||
cprog.append(
|
||||
"""
|
||||
int main(int argc, char* argv[]) {
|
||||
int DEBUG = getenv("DEBUG") != NULL ? atoi(getenv("DEBUG")) : 0;
|
||||
int X=0, Y=0, chan=0;
|
||||
|
@ -62,8 +84,9 @@ if __name__ == "__main__":
|
|||
}
|
||||
if (DEBUG) printf("category : %d (%s) with %f\\n", best_idx, lbls[best_idx], best);
|
||||
else printf("%s\\n", lbls[best_idx]);
|
||||
}""")
|
||||
}"""
|
||||
)
|
||||
|
||||
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||
# category : 281 (tabby, tabby cat) with 9.452788
|
||||
print('\n'.join(cprog))
|
||||
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||
# category : 281 (tabby, tabby cat) with 9.452788
|
||||
print("\n".join(cprog))
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
# An example to compile a small Tensorflow model to extremely portable C code
|
||||
|
||||
import os, sys
|
||||
os.environ["CLANG"] = '1'
|
||||
os.environ["GPU"] = '1'
|
||||
|
||||
os.environ["CLANG"] = "1"
|
||||
os.environ["GPU"] = "1"
|
||||
|
||||
import numpy as np
|
||||
import subprocess
|
||||
|
@ -12,55 +13,66 @@ from examples.compile_efficientnet import compile_net
|
|||
from extra.onnx import get_run_onnx
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
def get_uncompiled_model2(dataset_size=32, output_size=4):
|
||||
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
|
||||
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
|
||||
x = tf.keras.layers.BatchNormalization()(x)
|
||||
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
|
||||
outputs = tf.keras.layers.Dense(output_size, activation="sigmoid", name="predictions")(x)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
return model
|
||||
inputs = tf.keras.Input(shape=(dataset_size,), name="inputs")
|
||||
x = tf.keras.layers.Dense(16, activation="relu", name="dense_1")(inputs)
|
||||
x = tf.keras.layers.BatchNormalization()(x)
|
||||
x = tf.keras.layers.Dense(32, activation="relu", name="dense_2")(x)
|
||||
outputs = tf.keras.layers.Dense(
|
||||
output_size, activation="sigmoid", name="predictions"
|
||||
)(x)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
return model
|
||||
|
||||
|
||||
def create_onnx_model(keras_model):
|
||||
input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')]
|
||||
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
|
||||
return onnx_model
|
||||
input_signature = [tf.TensorSpec([1, 32], tf.float32, name="x")]
|
||||
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13)
|
||||
return onnx_model
|
||||
|
||||
|
||||
def compile_onnx_model(onnx_model):
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
|
||||
from tinygrad.jit import TinyJit
|
||||
@TinyJit
|
||||
def run(x): return run_onnx({"x": x}, debug=False)['predictions'].realize()
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
the_input = Tensor.randn(1,32)
|
||||
the_output = run(the_input)
|
||||
the_output = run(the_input)
|
||||
@TinyJit
|
||||
def run(x):
|
||||
return run_onnx({"x": x}, debug=False)["predictions"].realize()
|
||||
|
||||
special_names = {id(the_input.lazydata.realized.cl): "input", id(the_output.lazydata.realized.cl): "outputs"}
|
||||
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
|
||||
the_input = Tensor.randn(1, 32)
|
||||
the_output = run(the_input)
|
||||
the_output = run(the_input)
|
||||
|
||||
# buffers (all except input)
|
||||
cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] != "input"]
|
||||
special_names = {
|
||||
id(the_input.lazydata.realized.cl): "input",
|
||||
id(the_output.lazydata.realized.cl): "outputs",
|
||||
}
|
||||
cprog, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
cprog = ["#include <string.h>", "#include <stdio.h>", "#include <stdlib.h>"] + cprog
|
||||
|
||||
# weights
|
||||
cprog.append("void initialize(float *weights) {")
|
||||
weights = bytes()
|
||||
for name,cl in bufs_to_save.items():
|
||||
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl)});")
|
||||
weights += bytes(memoryview(cl)[0:len(cl)//4])
|
||||
cprog.append("}")
|
||||
# buffers (all except input)
|
||||
cprog += [f"float {x[0]}[{x[1]}];" for x in bufs.values() if x[0] != "input"]
|
||||
|
||||
# write the weights to disk
|
||||
with open("/tmp/tf_weights", "wb") as f:
|
||||
f.write(weights)
|
||||
# weights
|
||||
cprog.append("void initialize(float *weights) {")
|
||||
weights = bytes()
|
||||
for name, cl in bufs_to_save.items():
|
||||
cprog.append(f"memcpy({name}, weights + {len(weights)//4}, {len(cl)});")
|
||||
weights += bytes(memoryview(cl)[0 : len(cl) // 4])
|
||||
cprog.append("}")
|
||||
|
||||
# the net
|
||||
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
|
||||
# write the weights to disk
|
||||
with open("/tmp/tf_weights", "wb") as f:
|
||||
f.write(weights)
|
||||
|
||||
# test program
|
||||
cprog.append(f"""int main(int argc, char *argv[]) {{
|
||||
# the net
|
||||
cprog += ["float *infer(float *input) {"] + statements + ["return outputs;", "}"]
|
||||
|
||||
# test program
|
||||
cprog.append(
|
||||
f"""int main(int argc, char *argv[]) {{
|
||||
// read in the weights from disk
|
||||
FILE *f = fopen("/tmp/tf_weights", "rb");
|
||||
float *weights = (float *)malloc({len(weights)});
|
||||
|
@ -75,30 +87,42 @@ def compile_onnx_model(onnx_model):
|
|||
for (int i = 0; i < 32; i++) scanf("%f", &input[i]);
|
||||
float *outputs = infer(input);
|
||||
printf("%f %f %f %f\\n", outputs[0], outputs[1], outputs[2], outputs[3]);
|
||||
}}""")
|
||||
}}"""
|
||||
)
|
||||
|
||||
# ready the program
|
||||
prg = '\n'.join(cprog)
|
||||
print(prg)
|
||||
# ready the program
|
||||
prg = "\n".join(cprog)
|
||||
print(prg)
|
||||
|
||||
# add test weights
|
||||
subprocess.check_output(['clang', '-O2', '-lm', '-fPIC', '-x', 'c', '-', '-o', "/tmp/tf_test"], input=prg.encode('utf-8'))
|
||||
# add test weights
|
||||
subprocess.check_output(
|
||||
["clang", "-O2", "-lm", "-fPIC", "-x", "c", "-", "-o", "/tmp/tf_test"],
|
||||
input=prg.encode("utf-8"),
|
||||
)
|
||||
|
||||
tinygrad_output = [x for x in the_output.numpy()[0]]
|
||||
print("tinygrad:", tinygrad_output, file=sys.stderr)
|
||||
tinygrad_output = [x for x in the_output.numpy()[0]]
|
||||
print("tinygrad:", tinygrad_output, file=sys.stderr)
|
||||
|
||||
c_input = ' '.join(["%f" % x for x in the_input[0].numpy()])+"\n"
|
||||
c_output = [float(x) for x in subprocess.check_output(["/tmp/tf_test"], input=c_input.encode('utf-8')).decode('utf-8').strip().split(" ")]
|
||||
print("compiled:", c_output, file=sys.stderr)
|
||||
c_input = " ".join(["%f" % x for x in the_input[0].numpy()]) + "\n"
|
||||
c_output = [
|
||||
float(x)
|
||||
for x in subprocess.check_output(
|
||||
["/tmp/tf_test"], input=c_input.encode("utf-8")
|
||||
)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
.split(" ")
|
||||
]
|
||||
print("compiled:", c_output, file=sys.stderr)
|
||||
|
||||
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
|
||||
return the_input.numpy(), c_output
|
||||
|
||||
np.testing.assert_allclose(tinygrad_output, c_output, atol=1e-5, rtol=1e-5)
|
||||
return the_input.numpy(), c_output
|
||||
|
||||
if __name__ == "__main__":
|
||||
keras_model = get_uncompiled_model2()
|
||||
onnx_model = create_onnx_model(keras_model)
|
||||
test_input, test_output = compile_onnx_model(onnx_model)
|
||||
tf_output = keras_model(test_input).numpy()[0]
|
||||
print("keras: ", tf_output, file=sys.stderr)
|
||||
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
|
||||
|
||||
keras_model = get_uncompiled_model2()
|
||||
onnx_model = create_onnx_model(keras_model)
|
||||
test_input, test_output = compile_onnx_model(onnx_model)
|
||||
tf_output = keras_model(test_input).numpy()[0]
|
||||
print("keras: ", tf_output, file=sys.stderr)
|
||||
np.testing.assert_allclose(tf_output, test_output, atol=1e-5, rtol=1e-5)
|
||||
|
|
|
@ -12,7 +12,14 @@ import pyaudio
|
|||
import yaml
|
||||
from llama import LLaMa
|
||||
from vits import MODELS as VITS_MODELS
|
||||
from vits import Y_LENGTH_ESTIMATE_SCALARS, HParams, Synthesizer, TextMapper, get_hparams_from_file, load_model
|
||||
from vits import (
|
||||
Y_LENGTH_ESTIMATE_SCALARS,
|
||||
HParams,
|
||||
Synthesizer,
|
||||
TextMapper,
|
||||
get_hparams_from_file,
|
||||
load_model,
|
||||
)
|
||||
from whisper import init_whisper, transcribe_waveform
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
@ -29,316 +36,557 @@ IM_END = 32002
|
|||
|
||||
|
||||
# Functions for encoding prompts to chatml md
|
||||
def encode_prompt(spp, k, v): return [IM_START]+spp.encode(f"{k}\n{v}")+[IM_END]+spp.encode("\n")
|
||||
def start_prompt(spp, k): return [IM_START]+spp.encode(f"{k}\n")
|
||||
def encode_prompt(spp, k, v):
|
||||
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
|
||||
|
||||
|
||||
def start_prompt(spp, k):
|
||||
return [IM_START] + spp.encode(f"{k}\n")
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
for i in range(0, len(lst), n): yield lst[i:i + n]
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def create_fixed_tokenizer():
|
||||
"""Function needed for extending tokenizer with additional chat tokens"""
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
tokenizer_path = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model")
|
||||
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
||||
print("creating fixed tokenizer")
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(tokenizer_path.read_bytes())
|
||||
# https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
tokenizer_path.write_bytes(mp.SerializeToString())
|
||||
return tokenizer_path
|
||||
"""Function needed for extending tokenizer with additional chat tokens"""
|
||||
import extra.junk.sentencepiece_model_pb2 as spb2
|
||||
|
||||
tokenizer_path = fetch(
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/tokenizer.model"
|
||||
)
|
||||
if SentencePieceProcessor(model_file=str(tokenizer_path)).vocab_size() != 32003:
|
||||
print("creating fixed tokenizer")
|
||||
mp = spb2.ModelProto()
|
||||
mp.ParseFromString(tokenizer_path.read_bytes())
|
||||
# https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/blob/main/added_tokens.json
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="[PAD]", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_start|>", score=0))
|
||||
mp.pieces.append(spb2.ModelProto.SentencePiece(piece="<|im_end|>", score=0))
|
||||
tokenizer_path.write_bytes(mp.SerializeToString())
|
||||
return tokenizer_path
|
||||
|
||||
|
||||
def llama_prepare(
|
||||
llama: LLaMa, temperature: float, pre_prompt_path: Path
|
||||
) -> tuple[list[int], str, str, str]:
|
||||
"""Prepares a llama model from a specified pre-prompt file"""
|
||||
with open(str(pre_prompt_path)) as f:
|
||||
config = yaml.safe_load(f.read())
|
||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(
|
||||
llama.tokenizer, "system", config["pre_prompt"].replace("\n", " ")
|
||||
)
|
||||
for i in config["examples"]:
|
||||
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
||||
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
||||
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
||||
return (
|
||||
toks,
|
||||
config["user_delim"],
|
||||
config["resp_delim"],
|
||||
len(toks),
|
||||
llama.tokenizer.decode(toks),
|
||||
)
|
||||
|
||||
def llama_prepare(llama: LLaMa, temperature: float, pre_prompt_path: Path) -> tuple[list[int], str, str, str]:
|
||||
"""Prepares a llama model from a specified pre-prompt file"""
|
||||
with open(str(pre_prompt_path)) as f:
|
||||
config = yaml.safe_load(f.read())
|
||||
toks = [llama.tokenizer.bos_id()] + encode_prompt(llama.tokenizer, "system", config["pre_prompt"].replace("\n", " "))
|
||||
for i in config["examples"]:
|
||||
toks += encode_prompt(llama.tokenizer, config["user_delim"], i["user_prompt"])
|
||||
toks += encode_prompt(llama.tokenizer, config["resp_delim"], i["resp_prompt"])
|
||||
llama.model(Tensor([toks]), 0, temperature).realize() # NOTE: outputs are not used
|
||||
return toks, config["user_delim"], config["resp_delim"], len(toks), llama.tokenizer.decode(toks)
|
||||
|
||||
def llama_generate(
|
||||
llama: LLaMa,
|
||||
toks: list[int],
|
||||
outputted: str,
|
||||
prompt: str,
|
||||
start_pos: int,
|
||||
user_delim: str,
|
||||
resp_delim: str,
|
||||
temperature=0.7,
|
||||
max_tokens=1000
|
||||
llama: LLaMa,
|
||||
toks: list[int],
|
||||
outputted: str,
|
||||
prompt: str,
|
||||
start_pos: int,
|
||||
user_delim: str,
|
||||
resp_delim: str,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
):
|
||||
"""Generates an output for the specified prompt"""
|
||||
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
||||
toks += start_prompt(llama.tokenizer, resp_delim)
|
||||
"""Generates an output for the specified prompt"""
|
||||
toks += encode_prompt(llama.tokenizer, user_delim, prompt)
|
||||
toks += start_prompt(llama.tokenizer, resp_delim)
|
||||
|
||||
outputted = llama.tokenizer.decode(toks)
|
||||
init_length = len(outputted)
|
||||
for _ in range(max_tokens):
|
||||
probs_np = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).numpy()
|
||||
token = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(token)
|
||||
outputted = llama.tokenizer.decode(toks)
|
||||
init_length = len(outputted)
|
||||
for _ in range(max_tokens):
|
||||
probs_np = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||
).numpy()
|
||||
token = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(token)
|
||||
|
||||
cur = llama.tokenizer.decode(toks)
|
||||
cur = llama.tokenizer.decode(toks)
|
||||
|
||||
# Print is just for debugging
|
||||
sys.stdout.write(cur[len(outputted) :])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
if toks[-1] == IM_END:
|
||||
break
|
||||
else:
|
||||
toks.append(IM_END)
|
||||
print() # because the output is flushed
|
||||
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
||||
|
||||
# Print is just for debugging
|
||||
sys.stdout.write(cur[len(outputted):])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
if toks[-1] == IM_END: break
|
||||
else:
|
||||
toks.append(IM_END)
|
||||
print() # because the output is flushed
|
||||
return outputted, start_pos, outputted[init_length:].replace("<|im_end|>", "")
|
||||
|
||||
def tts(
|
||||
text_to_synthesize: str,
|
||||
synth: Synthesizer,
|
||||
hps: HParams,
|
||||
emotion_embedding: Path,
|
||||
speaker_id: int,
|
||||
model_to_use: str,
|
||||
noise_scale: float,
|
||||
noise_scale_w: float,
|
||||
length_scale: float,
|
||||
estimate_max_y_length: bool,
|
||||
text_mapper: TextMapper,
|
||||
model_has_multiple_speakers: bool,
|
||||
batch_size=600,
|
||||
vits_batch_size=1000
|
||||
text_to_synthesize: str,
|
||||
synth: Synthesizer,
|
||||
hps: HParams,
|
||||
emotion_embedding: Path,
|
||||
speaker_id: int,
|
||||
model_to_use: str,
|
||||
noise_scale: float,
|
||||
noise_scale_w: float,
|
||||
length_scale: float,
|
||||
estimate_max_y_length: bool,
|
||||
text_mapper: TextMapper,
|
||||
model_has_multiple_speakers: bool,
|
||||
batch_size=600,
|
||||
vits_batch_size=1000,
|
||||
):
|
||||
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
if model_to_use == "mmts-tts":
|
||||
text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
|
||||
# Convert the input text to a tensor.
|
||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
||||
init_shape = stn_tst.shape
|
||||
assert init_shape[0] < batch_size, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
||||
# Convert the input text to a tensor.
|
||||
stn_tst = text_mapper.get_text(
|
||||
text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners
|
||||
)
|
||||
init_shape = stn_tst.shape
|
||||
assert init_shape[0] < batch_size, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(
|
||||
0
|
||||
), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = (
|
||||
Tensor([speaker_id], dtype=dtypes.int64)
|
||||
if model_has_multiple_speakers
|
||||
else None
|
||||
)
|
||||
|
||||
# Perform inference.
|
||||
audio_tensor = synth.infer(
|
||||
x_tst,
|
||||
x_tst_lengths,
|
||||
sid,
|
||||
noise_scale,
|
||||
length_scale,
|
||||
noise_scale_w,
|
||||
emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use]
|
||||
if estimate_max_y_length
|
||||
else None,
|
||||
batch_size=vits_batch_size,
|
||||
)[0, 0]
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
return audio_data
|
||||
|
||||
# Perform inference.
|
||||
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
return audio_data
|
||||
|
||||
def init_vits(
|
||||
model_to_use: str,
|
||||
emotion_path: Path,
|
||||
speaker_id: int,
|
||||
seed: int,
|
||||
model_to_use: str,
|
||||
emotion_path: Path,
|
||||
speaker_id: int,
|
||||
seed: int,
|
||||
):
|
||||
model_config = VITS_MODELS[model_to_use]
|
||||
model_config = VITS_MODELS[model_to_use]
|
||||
|
||||
# Load the hyperparameters from the config file.
|
||||
hps = get_hparams_from_file(fetch(model_config[0]))
|
||||
# Load the hyperparameters from the config file.
|
||||
hps = get_hparams_from_file(fetch(model_config[0]))
|
||||
|
||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
if model_has_multiple_speakers:
|
||||
if speaker_id >= hps.data.n_speakers: raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||
speakers = hps.speakers
|
||||
if isinstance(speakers, list): speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
# If model has multiple speakers, validate speaker id and retrieve name if available.
|
||||
model_has_multiple_speakers = hps.data.n_speakers > 0
|
||||
if model_has_multiple_speakers:
|
||||
if speaker_id >= hps.data.n_speakers:
|
||||
raise ValueError(f"Speaker ID {speaker_id} is invalid for this model.")
|
||||
if hps.__contains__("speakers"): # maps speaker ids to names
|
||||
speakers = hps.speakers
|
||||
if isinstance(speakers, list):
|
||||
speakers = {speaker: i for i, speaker in enumerate(speakers)}
|
||||
|
||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||
emotion_embedding = None
|
||||
if emotion_path is not None:
|
||||
if emotion_path.endswith(".npy"): emotion_embedding = Tensor(np.load(emotion_path), dtype=dtypes.int64).unsqueeze(0)
|
||||
else: raise ValueError("Emotion path must be a .npy file.")
|
||||
# Load emotions if any. TODO: find an english model with emotions, this is untested atm.
|
||||
emotion_embedding = None
|
||||
if emotion_path is not None:
|
||||
if emotion_path.endswith(".npy"):
|
||||
emotion_embedding = Tensor(
|
||||
np.load(emotion_path), dtype=dtypes.int64
|
||||
).unsqueeze(0)
|
||||
else:
|
||||
raise ValueError("Emotion path must be a .npy file.")
|
||||
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"): symbols = hps.symbols
|
||||
elif model_to_use == "mmts-tts": symbols = [x.replace("\n", "") for x in fetch("https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt").open(encoding="utf-8").readlines()]
|
||||
else: symbols = ['_'] + list(';:,.!?¡¿—…"«»“” ') + list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz') + list("ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ")
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
# Load symbols, instantiate TextMapper and clean the text.
|
||||
if hps.__contains__("symbols"):
|
||||
symbols = hps.symbols
|
||||
elif model_to_use == "mmts-tts":
|
||||
symbols = [
|
||||
x.replace("\n", "")
|
||||
for x in fetch(
|
||||
"https://huggingface.co/facebook/mms-tts/raw/main/full_models/eng/vocab.txt"
|
||||
)
|
||||
.open(encoding="utf-8")
|
||||
.readlines()
|
||||
]
|
||||
else:
|
||||
symbols = (
|
||||
["_"]
|
||||
+ list(';:,.!?¡¿—…"«»“” ')
|
||||
+ list("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")
|
||||
+ list(
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
)
|
||||
text_mapper = TextMapper(apply_cleaners=True, symbols=symbols)
|
||||
|
||||
# Load the model.
|
||||
Tensor.no_grad = True
|
||||
if seed is not None:
|
||||
Tensor.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
net_g = load_model(text_mapper.symbols, hps, model_config)
|
||||
# Load the model.
|
||||
Tensor.no_grad = True
|
||||
if seed is not None:
|
||||
Tensor.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
net_g = load_model(text_mapper.symbols, hps, model_config)
|
||||
|
||||
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
||||
|
||||
return net_g, emotion_embedding, text_mapper, hps, model_has_multiple_speakers
|
||||
|
||||
@contextmanager
|
||||
def output_stream(num_channels: int, sample_rate: int):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True)
|
||||
yield stream
|
||||
except KeyboardInterrupt: pass
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16, channels=num_channels, rate=sample_rate, output=True
|
||||
)
|
||||
yield stream
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def log_writer():
|
||||
try:
|
||||
logs = []
|
||||
yield logs
|
||||
finally:
|
||||
sep = "="*os.get_terminal_size()[1]
|
||||
print(f"{sep[:-1]}\nCHAT LOG")
|
||||
print(*logs, sep="\n")
|
||||
print(sep)
|
||||
try:
|
||||
logs = []
|
||||
yield logs
|
||||
finally:
|
||||
sep = "=" * os.get_terminal_size()[1]
|
||||
print(f"{sep[:-1]}\nCHAT LOG")
|
||||
print(*logs, sep="\n")
|
||||
print(sep)
|
||||
|
||||
|
||||
def listener(q: mp.Queue, event: mp.Event):
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
did_print = False
|
||||
while True:
|
||||
data = stream.read(CHUNK) # read data to avoid overflow
|
||||
if event.is_set():
|
||||
if not did_print:
|
||||
print("listening")
|
||||
did_print = True
|
||||
q.put(((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3))
|
||||
else:
|
||||
try:
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK,
|
||||
)
|
||||
did_print = False
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
while True:
|
||||
data = stream.read(CHUNK) # read data to avoid overflow
|
||||
if event.is_set():
|
||||
if not did_print:
|
||||
print("listening")
|
||||
did_print = True
|
||||
q.put(((np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3))
|
||||
else:
|
||||
did_print = False
|
||||
finally:
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
|
||||
def mp_output_stream(
|
||||
q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int
|
||||
):
|
||||
with output_stream(num_channels, sample_rate) as stream:
|
||||
while True:
|
||||
try:
|
||||
stream.write(q.get())
|
||||
counter.value += 1
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
def mp_output_stream(q: mp.Queue, counter: mp.Value, num_channels: int, sample_rate: int):
|
||||
with output_stream(num_channels, sample_rate) as stream:
|
||||
while True:
|
||||
try:
|
||||
stream.write(q.get())
|
||||
counter.value += 1
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
import nltk
|
||||
nltk.download("punkt")
|
||||
Tensor.no_grad = True
|
||||
# Parse CLI arguments
|
||||
parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
|
||||
import nltk
|
||||
|
||||
# Whisper args
|
||||
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
||||
nltk.download("punkt")
|
||||
Tensor.no_grad = True
|
||||
# Parse CLI arguments
|
||||
parser = argparse.ArgumentParser("Have a tiny conversation with tinygrad")
|
||||
|
||||
# LLAMA args
|
||||
parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ")
|
||||
parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate")
|
||||
parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
||||
parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use")
|
||||
parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")
|
||||
parser.add_argument("--llama_tokenizer", type=Path, default=None, required=False, help="Path to llama tokenizer.model")
|
||||
# Whisper args
|
||||
parser.add_argument("--whisper_model_name", type=str, default="tiny.en")
|
||||
|
||||
# vits args
|
||||
parser.add_argument("--vits_model_to_use", default="vctk", help="Specify the model to use. Default is 'vctk'.")
|
||||
parser.add_argument("--vits_speaker_id", type=int, default=12, help="Specify the speaker ID. Default is 6.")
|
||||
parser.add_argument("--vits_noise_scale", type=float, default=0.667, help="Specify the noise scale. Default is 0.667.")
|
||||
parser.add_argument("--vits_noise_scale_w", type=float, default=0.8, help="Specify the noise scale w. Default is 0.8.")
|
||||
parser.add_argument("--vits_length_scale", type=float, default=1, help="Specify the length scale. Default is 1.")
|
||||
parser.add_argument("--vits_seed", type=int, default=None, help="Specify the seed (set to None if no seed). Default is 1337.")
|
||||
parser.add_argument("--vits_num_channels", type=int, default=1, help="Specify the number of audio output channels. Default is 1.")
|
||||
parser.add_argument("--vits_sample_width", type=int, default=2, help="Specify the number of bytes per sample, adjust if necessary. Default is 2.")
|
||||
parser.add_argument("--vits_emotion_path", type=Path, default=None, help="Specify the path to emotion reference.")
|
||||
parser.add_argument("--vits_estimate_max_y_length", type=str, default=False, help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.")
|
||||
parser.add_argument("--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary.")
|
||||
|
||||
# conversation args
|
||||
parser.add_argument("--max_sentence_length", type=int, default=20, help="Max words in one sentence to pass to vits")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init models
|
||||
model, enc = init_whisper(args.whisper_model_name)
|
||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(args.vits_model_to_use, args.vits_emotion_path, args.vits_speaker_id, args.vits_seed)
|
||||
|
||||
# Download tinyllama chat as a default model
|
||||
if args.llama_model is None:
|
||||
args.llama_model = fetch("https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors", "tinyllamachat.safetensors")
|
||||
args.llama_gen = "tiny"
|
||||
args.llama_size = "1B-Chat"
|
||||
# Add 3 more tokens to the tokenizer
|
||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"): args.llama_tokenizer = create_fixed_tokenizer()
|
||||
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
||||
llama = LLaMa.build(args.llama_model, tokenizer_path, args.llama_gen, args.llama_size, args.llama_quantize)
|
||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(llama, args.llama_temperature, args.llama_pre_prompt_path)
|
||||
|
||||
# Start child process for mic input
|
||||
q = mp.Queue()
|
||||
is_listening_event = mp.Event()
|
||||
p = mp.Process(target=listener, args=(q, is_listening_event,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# Start child process for speaker output
|
||||
out_q = mp.Queue()
|
||||
out_counter = mp.Value("i", 0)
|
||||
out_p = mp.Process(target=mp_output_stream, args=(out_q, out_counter, args.vits_num_channels, hps.data.sampling_rate,))
|
||||
out_p.daemon = True
|
||||
out_p.start()
|
||||
|
||||
# JIT tts
|
||||
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
||||
tts(
|
||||
i, synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
# LLAMA args
|
||||
parser.add_argument(
|
||||
"--llama_pre_prompt_path",
|
||||
type=Path,
|
||||
default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml",
|
||||
help="Path to yaml file which contains all pre-prompt data needed. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_count", type=int, default=1000, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_temperature",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Temperature in the softmax",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_quantize",
|
||||
action="store_true",
|
||||
help="Quantize the weights to int8 in memory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_model",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_gen",
|
||||
type=str,
|
||||
default="tiny",
|
||||
required=False,
|
||||
help="Generation of the model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_size",
|
||||
type=str,
|
||||
default="1B-Chat",
|
||||
required=False,
|
||||
help="Size of model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llama_tokenizer",
|
||||
type=Path,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Path to llama tokenizer.model",
|
||||
)
|
||||
|
||||
# Start the pipeline
|
||||
with log_writer() as log:
|
||||
while True:
|
||||
tokens = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
total = np.array([])
|
||||
out_counter.value = 0
|
||||
# vits args
|
||||
parser.add_argument(
|
||||
"--vits_model_to_use",
|
||||
default="vctk",
|
||||
help="Specify the model to use. Default is 'vctk'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_speaker_id",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Specify the speaker ID. Default is 6.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_noise_scale",
|
||||
type=float,
|
||||
default=0.667,
|
||||
help="Specify the noise scale. Default is 0.667.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_noise_scale_w",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Specify the noise scale w. Default is 0.8.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_length_scale",
|
||||
type=float,
|
||||
default=1,
|
||||
help="Specify the length scale. Default is 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Specify the seed (set to None if no seed). Default is 1337.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_num_channels",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Specify the number of audio output channels. Default is 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_sample_width",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Specify the number of bytes per sample, adjust if necessary. Default is 2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_emotion_path",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Specify the path to emotion reference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_estimate_max_y_length",
|
||||
type=str,
|
||||
default=False,
|
||||
help="If true, overestimate the output length and then trim it to the correct length, to prevent premature realization, much more performant for larger inputs, for smaller inputs not so much. Default is False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vits_vocab_path", type=Path, default=None, help="Path to the TTS vocabulary."
|
||||
)
|
||||
|
||||
s = time.perf_counter()
|
||||
is_listening_event.set()
|
||||
prev_text = None
|
||||
while True:
|
||||
for _ in range(RATE // CHUNK): total = np.concatenate([total, q.get()])
|
||||
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
||||
print(txt, end="\r")
|
||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()): continue
|
||||
if prev_text is not None and prev_text == txt:
|
||||
is_listening_event.clear()
|
||||
break
|
||||
prev_text = txt
|
||||
print() # to avoid llama printing on the same line
|
||||
log.append(f"{user_delim.capitalize()}: {txt}")
|
||||
# conversation args
|
||||
parser.add_argument(
|
||||
"--max_sentence_length",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Max words in one sentence to pass to vits",
|
||||
)
|
||||
|
||||
# Generate with llama
|
||||
with Timing("llama generation: "):
|
||||
outputted, start_pos, response = llama_generate(
|
||||
llama, toks, outputted, txt, start_pos,
|
||||
user_delim=user_delim, resp_delim=resp_delim, temperature=args.llama_temperature,
|
||||
max_tokens=args.llama_count
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init models
|
||||
model, enc = init_whisper(args.whisper_model_name)
|
||||
synth, emotion_embedding, text_mapper, hps, model_has_multiple_speakers = init_vits(
|
||||
args.vits_model_to_use,
|
||||
args.vits_emotion_path,
|
||||
args.vits_speaker_id,
|
||||
args.vits_seed,
|
||||
)
|
||||
|
||||
# Download tinyllama chat as a default model
|
||||
if args.llama_model is None:
|
||||
args.llama_model = fetch(
|
||||
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.4/resolve/main/model.safetensors",
|
||||
"tinyllamachat.safetensors",
|
||||
)
|
||||
log.append(f"{resp_delim.capitalize()}: {response}")
|
||||
args.llama_gen = "tiny"
|
||||
args.llama_size = "1B-Chat"
|
||||
# Add 3 more tokens to the tokenizer
|
||||
if args.llama_gen == "tiny" and args.llama_size.endswith("Chat"):
|
||||
args.llama_tokenizer = create_fixed_tokenizer()
|
||||
tokenizer_path = args.llama_tokenizer or args.llama_model.parent / "tokenizer.model"
|
||||
llama = LLaMa.build(
|
||||
args.llama_model,
|
||||
tokenizer_path,
|
||||
args.llama_gen,
|
||||
args.llama_size,
|
||||
args.llama_quantize,
|
||||
)
|
||||
toks, user_delim, resp_delim, start_pos, outputted = llama_prepare(
|
||||
llama, args.llama_temperature, args.llama_pre_prompt_path
|
||||
)
|
||||
|
||||
# Convert to voice
|
||||
with Timing("tts: "):
|
||||
sentences = nltk.sent_tokenize(response.replace('"', ""))
|
||||
for i in sentences:
|
||||
total = np.array([], dtype=np.int16)
|
||||
for j in chunks(i.split(), args.max_sentence_length):
|
||||
audio_data = tts(
|
||||
" ".join(j), synth, hps, emotion_embedding,
|
||||
args.vits_speaker_id, args.vits_model_to_use, args.vits_noise_scale,
|
||||
args.vits_noise_scale_w, args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length, text_mapper, model_has_multiple_speakers
|
||||
)
|
||||
total = np.concatenate([total, audio_data])
|
||||
out_q.put(total.tobytes())
|
||||
while out_counter.value < len(sentences): continue
|
||||
log.append(f"Total: {time.perf_counter() - s}")
|
||||
# Start child process for mic input
|
||||
q = mp.Queue()
|
||||
is_listening_event = mp.Event()
|
||||
p = mp.Process(
|
||||
target=listener,
|
||||
args=(
|
||||
q,
|
||||
is_listening_event,
|
||||
),
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
# Start child process for speaker output
|
||||
out_q = mp.Queue()
|
||||
out_counter = mp.Value("i", 0)
|
||||
out_p = mp.Process(
|
||||
target=mp_output_stream,
|
||||
args=(
|
||||
out_q,
|
||||
out_counter,
|
||||
args.vits_num_channels,
|
||||
hps.data.sampling_rate,
|
||||
),
|
||||
)
|
||||
out_p.daemon = True
|
||||
out_p.start()
|
||||
|
||||
# JIT tts
|
||||
for i in ["Hello, I'm a chat bot", "I am capable of doing a lot of things"]:
|
||||
tts(
|
||||
i,
|
||||
synth,
|
||||
hps,
|
||||
emotion_embedding,
|
||||
args.vits_speaker_id,
|
||||
args.vits_model_to_use,
|
||||
args.vits_noise_scale,
|
||||
args.vits_noise_scale_w,
|
||||
args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length,
|
||||
text_mapper,
|
||||
model_has_multiple_speakers,
|
||||
)
|
||||
|
||||
# Start the pipeline
|
||||
with log_writer() as log:
|
||||
while True:
|
||||
tokens = [
|
||||
enc._special_tokens["<|startoftranscript|>"],
|
||||
enc._special_tokens["<|notimestamps|>"],
|
||||
]
|
||||
total = np.array([])
|
||||
out_counter.value = 0
|
||||
|
||||
s = time.perf_counter()
|
||||
is_listening_event.set()
|
||||
prev_text = None
|
||||
while True:
|
||||
for _ in range(RATE // CHUNK):
|
||||
total = np.concatenate([total, q.get()])
|
||||
txt = transcribe_waveform(model, enc, [total], truncate=True)
|
||||
print(txt, end="\r")
|
||||
if txt == "[BLANK_AUDIO]" or re.match(r"^\([\w+ ]+\)$", txt.strip()):
|
||||
continue
|
||||
if prev_text is not None and prev_text == txt:
|
||||
is_listening_event.clear()
|
||||
break
|
||||
prev_text = txt
|
||||
print() # to avoid llama printing on the same line
|
||||
log.append(f"{user_delim.capitalize()}: {txt}")
|
||||
|
||||
# Generate with llama
|
||||
with Timing("llama generation: "):
|
||||
outputted, start_pos, response = llama_generate(
|
||||
llama,
|
||||
toks,
|
||||
outputted,
|
||||
txt,
|
||||
start_pos,
|
||||
user_delim=user_delim,
|
||||
resp_delim=resp_delim,
|
||||
temperature=args.llama_temperature,
|
||||
max_tokens=args.llama_count,
|
||||
)
|
||||
log.append(f"{resp_delim.capitalize()}: {response}")
|
||||
|
||||
# Convert to voice
|
||||
with Timing("tts: "):
|
||||
sentences = nltk.sent_tokenize(response.replace('"', ""))
|
||||
for i in sentences:
|
||||
total = np.array([], dtype=np.int16)
|
||||
for j in chunks(i.split(), args.max_sentence_length):
|
||||
audio_data = tts(
|
||||
" ".join(j),
|
||||
synth,
|
||||
hps,
|
||||
emotion_embedding,
|
||||
args.vits_speaker_id,
|
||||
args.vits_model_to_use,
|
||||
args.vits_noise_scale,
|
||||
args.vits_noise_scale_w,
|
||||
args.vits_length_scale,
|
||||
args.vits_estimate_max_y_length,
|
||||
text_mapper,
|
||||
model_has_multiple_speakers,
|
||||
)
|
||||
total = np.concatenate([total, audio_data])
|
||||
out_q.put(total.tobytes())
|
||||
while out_counter.value < len(sentences):
|
||||
continue
|
||||
log.append(f"Total: {time.perf_counter() - s}")
|
||||
|
|
|
@ -11,78 +11,98 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import getenv, fetch, Timing
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
# TODO: you should be able to put these in the jitted function
|
||||
bias = Tensor([0.485, 0.456, 0.406])
|
||||
scale = Tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
@TinyJit
|
||||
def _infer(model, img):
|
||||
img = img.permute((2,0,1))
|
||||
img = img / 255.0
|
||||
img = img - bias.reshape((1,-1,1,1))
|
||||
img = img / scale.reshape((1,-1,1,1))
|
||||
return model.forward(img).realize()
|
||||
img = img.permute((2, 0, 1))
|
||||
img = img / 255.0
|
||||
img = img - bias.reshape((1, -1, 1, 1))
|
||||
img = img / scale.reshape((1, -1, 1, 1))
|
||||
return model.forward(img).realize()
|
||||
|
||||
|
||||
def infer(model, img):
|
||||
# preprocess image
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
||||
# preprocess image
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize(
|
||||
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||
)
|
||||
|
||||
img = np.array(img)
|
||||
y0,x0=(np.asarray(img.shape)[:2]-224)//2
|
||||
retimg = img = img[y0:y0+224, x0:x0+224]
|
||||
img = np.array(img)
|
||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||
retimg = img = img[y0 : y0 + 224, x0 : x0 + 224]
|
||||
|
||||
# if you want to look at the image
|
||||
"""
|
||||
# if you want to look at the image
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
"""
|
||||
|
||||
# run the net
|
||||
out = _infer(model, Tensor(img.astype("float32"))).numpy()
|
||||
# run the net
|
||||
out = _infer(model, Tensor(img.astype("float32"))).numpy()
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(out[0])
|
||||
plt.show()
|
||||
"""
|
||||
return out, retimg
|
||||
return out, retimg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet(getenv("NUM", 0))
|
||||
model.load_from_pretrained()
|
||||
# instantiate my net
|
||||
model = EfficientNet(getenv("NUM", 0))
|
||||
model.load_from_pretrained()
|
||||
|
||||
# category labels
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
# category labels
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
|
||||
# load image and preprocess
|
||||
url = sys.argv[1] if len(sys.argv) >= 2 else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||
if url == 'webcam':
|
||||
import cv2
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
img = Image.fromarray(frame[:, :, [2,1,0]])
|
||||
lt = time.monotonic_ns()
|
||||
out, retimg = infer(model, img)
|
||||
print(f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms", np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow('capture', retimg)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
img = Image.open(fetch(url))
|
||||
with Timing("did inference in "):
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
# load image and preprocess
|
||||
url = (
|
||||
sys.argv[1]
|
||||
if len(sys.argv) >= 2
|
||||
else "https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/showcase/stable_diffusion_by_tinygrad.jpg"
|
||||
)
|
||||
if url == "webcam":
|
||||
import cv2
|
||||
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||
lt = time.monotonic_ns()
|
||||
out, retimg = infer(model, img)
|
||||
print(
|
||||
f"{(time.monotonic_ns()-lt)*1e-6:7.2f} ms",
|
||||
np.argmax(out),
|
||||
np.max(out),
|
||||
lbls[np.argmax(out)],
|
||||
)
|
||||
SCALE = 3
|
||||
simg = cv2.resize(retimg, (224 * SCALE, 224 * SCALE))
|
||||
retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow("capture", retimg)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
else:
|
||||
img = Image.open(fetch(url))
|
||||
with Timing("did inference in "):
|
||||
out, _ = infer(model, img)
|
||||
print(np.argmax(out), np.max(out), lbls[np.argmax(out)])
|
||||
|
|
|
@ -3,40 +3,47 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import dtypes
|
||||
from tinygrad import Device
|
||||
|
||||
|
||||
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
|
||||
def bit_extract(x, s, e) -> Tensor:
|
||||
# extract the top bits we don't want
|
||||
top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1))
|
||||
x = (x - top_bits) / (1<<e)
|
||||
return x.contiguous()
|
||||
# extract the top bits we don't want
|
||||
top_bits = (x / (1 << (s + 1))).floor() * (1 << (s + 1))
|
||||
x = (x - top_bits) / (1 << e)
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def u16_to_f16(x):
|
||||
sign = bit_extract(x, 15, 15).float()
|
||||
exponent = bit_extract(x, 14, 10).float()
|
||||
fraction = bit_extract(x, 9, 0).float()
|
||||
return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
|
||||
sign = bit_extract(x, 15, 15).float()
|
||||
exponent = bit_extract(x, 14, 10).float()
|
||||
fraction = bit_extract(x, 9, 0).float()
|
||||
return sign.where(-1, 1) * exponent.where(
|
||||
(exponent - 15).exp2() * (1 + fraction / 0x400),
|
||||
6.103515625e-5 * (fraction / 0x400),
|
||||
)
|
||||
|
||||
|
||||
def u32_to_f16(oo):
|
||||
oo1 = (oo/0x10000).floor().contiguous()
|
||||
# TODO: this is wrong and unextractable until we do this math in u32
|
||||
oo2 = (oo-(oo1*0x10000)).floor().contiguous()
|
||||
f1 = u16_to_f16(oo1)
|
||||
f2 = u16_to_f16(oo2)
|
||||
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
|
||||
oo1 = (oo / 0x10000).floor().contiguous()
|
||||
# TODO: this is wrong and unextractable until we do this math in u32
|
||||
oo2 = (oo - (oo1 * 0x10000)).floor().contiguous()
|
||||
f1 = u16_to_f16(oo1)
|
||||
f2 = u16_to_f16(oo2)
|
||||
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# random float16
|
||||
Tensor.manual_seed(2)
|
||||
a = Tensor.randn(100, dtype=dtypes.float16)
|
||||
# random float16
|
||||
Tensor.manual_seed(2)
|
||||
a = Tensor.randn(100, dtype=dtypes.float16)
|
||||
|
||||
# this converts it to u32 on disk
|
||||
oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()
|
||||
# this converts it to u32 on disk
|
||||
oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()
|
||||
|
||||
# convert to 2xf16 using tinygrad math ops
|
||||
f16 = u32_to_f16(oo)
|
||||
# convert to 2xf16 using tinygrad math ops
|
||||
f16 = u32_to_f16(oo)
|
||||
|
||||
ref = a.numpy()
|
||||
out = f16.numpy().astype(np.float16)
|
||||
print(ref-out)
|
||||
ref = a.numpy()
|
||||
out = f16.numpy().astype(np.float16)
|
||||
print(ref - out)
|
||||
|
||||
np.testing.assert_allclose(ref, out)
|
||||
np.testing.assert_allclose(ref, out)
|
||||
|
|
404
examples/gpt2.py
404
examples/gpt2.py
|
@ -10,183 +10,317 @@ from tinygrad.shape.symbolic import Variable
|
|||
from tinygrad.jit import TinyJit
|
||||
import tiktoken
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch, colored, dtypes
|
||||
from tinygrad.helpers import (
|
||||
GlobalCounters,
|
||||
Timing,
|
||||
DEBUG,
|
||||
getenv,
|
||||
fetch,
|
||||
colored,
|
||||
dtypes,
|
||||
)
|
||||
|
||||
MAX_CONTEXT = getenv("MAX_CONTEXT", 128)
|
||||
HALF = getenv("HALF")
|
||||
|
||||
|
||||
class Attention:
|
||||
def __init__(self, dim, n_heads):
|
||||
self.c_attn = Linear(dim, 3*dim, bias=True)
|
||||
self.c_proj = Linear(dim, dim, bias=True)
|
||||
self.n_heads = n_heads
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
def __init__(self, dim, n_heads):
|
||||
self.c_attn = Linear(dim, 3 * dim, bias=True)
|
||||
self.c_proj = Linear(dim, dim, bias=True)
|
||||
self.n_heads = n_heads
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
|
||||
if mask is not None:
|
||||
# no symbolic shape qkv when consuming prompts
|
||||
start_pos = start_pos.val
|
||||
def __call__(
|
||||
self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
if mask is not None:
|
||||
# no symbolic shape qkv when consuming prompts
|
||||
start_pos = start_pos.val
|
||||
|
||||
xqkv = self.c_attn(x)
|
||||
xq, xk, xv = [xqkv.shrink((None, None, (i*self.dim, (i+1)*self.dim))).reshape(xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim) for i in range(3)]
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
xqkv = self.c_attn(x)
|
||||
xq, xk, xv = [
|
||||
xqkv.shrink((None, None, (i * self.dim, (i + 1) * self.dim))).reshape(
|
||||
xqkv.shape[0], xqkv.shape[1], self.n_heads, self.head_dim
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||
if HALF:
|
||||
self.cache_k = self.cache_k.half()
|
||||
self.cache_v = self.cache_v.half()
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(
|
||||
bsz, MAX_CONTEXT, self.n_heads, self.head_dim
|
||||
), Tensor.zeros(bsz, MAX_CONTEXT, self.n_heads, self.head_dim)
|
||||
if HALF:
|
||||
self.cache_k = self.cache_k.half()
|
||||
self.cache_v = self.cache_v.half()
|
||||
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
||||
# update the cache
|
||||
self.cache_k.assign(keys.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_v.assign(values.pad((None,(0,MAX_CONTEXT-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
# update the cache
|
||||
self.cache_k.assign(
|
||||
keys.pad(
|
||||
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
values.pad(
|
||||
(None, (0, MAX_CONTEXT - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
|
||||
xq, keys, values = (
|
||||
xq.transpose(1, 2),
|
||||
keys.transpose(1, 2),
|
||||
values.transpose(1, 2),
|
||||
)
|
||||
return self.c_proj(
|
||||
xq.scaled_dot_product_attention(keys, values, mask)
|
||||
.transpose(1, 2)
|
||||
.reshape(bsz, seqlen, -1)
|
||||
)
|
||||
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
return self.c_proj(xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1))
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim):
|
||||
self.c_fc = Linear(dim, hidden_dim, bias=True)
|
||||
self.c_proj = Linear(hidden_dim, dim, bias=True)
|
||||
def __init__(self, dim, hidden_dim):
|
||||
self.c_fc = Linear(dim, hidden_dim, bias=True)
|
||||
self.c_proj = Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.c_proj(self.c_fc(x).gelu())
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return self.c_proj(self.c_fc(x).gelu())
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim, n_heads, norm_eps):
|
||||
self.attn = Attention(dim, n_heads)
|
||||
self.mlp = FeedForward(dim, 4*dim)
|
||||
self.ln_1 = LayerNorm(dim, norm_eps)
|
||||
self.ln_2 = LayerNorm(dim, norm_eps)
|
||||
def __init__(self, dim, n_heads, norm_eps):
|
||||
self.attn = Attention(dim, n_heads)
|
||||
self.mlp = FeedForward(dim, 4 * dim)
|
||||
self.ln_1 = LayerNorm(dim, norm_eps)
|
||||
self.ln_2 = LayerNorm(dim, norm_eps)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Variable, mask: Optional[Tensor]):
|
||||
h = x + self.attn(self.ln_1(x), start_pos, mask)
|
||||
return h + self.mlp(self.ln_2(h))
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]):
|
||||
h = x + self.attn(self.ln_1(x), start_pos, mask)
|
||||
return (h + self.mlp(self.ln_2(h)))
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
||||
self.wte = Embedding(vocab_size, dim)
|
||||
self.wpe = Embedding(max_seq_len, dim)
|
||||
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
|
||||
self.ln_f = LayerNorm(dim, norm_eps)
|
||||
self.lm_head = Linear(dim, vocab_size, bias=False)
|
||||
self.forward_jit = TinyJit(self.forward)
|
||||
def __init__(self, dim, n_heads, n_layers, norm_eps, vocab_size, max_seq_len=1024):
|
||||
self.wte = Embedding(vocab_size, dim)
|
||||
self.wpe = Embedding(max_seq_len, dim)
|
||||
self.h = [TransformerBlock(dim, n_heads, norm_eps) for _ in range(n_layers)]
|
||||
self.ln_f = LayerNorm(dim, norm_eps)
|
||||
self.lm_head = Linear(dim, vocab_size, bias=False)
|
||||
self.forward_jit = TinyJit(self.forward)
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
|
||||
if not hasattr(self, 'allpos'): self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
||||
_bsz, seqlen = tokens.shape
|
||||
def forward(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
|
||||
if not hasattr(self, "allpos"):
|
||||
self.allpos = Tensor.arange(0, MAX_CONTEXT).reshape(1, -1).realize()
|
||||
_bsz, seqlen = tokens.shape
|
||||
|
||||
# NOTE: cannot convert token indices into half due to precision
|
||||
tok_emb = self.wte(tokens)
|
||||
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos+seqlen))))
|
||||
h = tok_emb + pos_emb
|
||||
# NOTE: cannot convert token indices into half due to precision
|
||||
tok_emb = self.wte(tokens)
|
||||
pos_emb = self.wpe(self.allpos.shrink((None, (start_pos, start_pos + seqlen))))
|
||||
h = tok_emb + pos_emb
|
||||
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos.val+seqlen), float("-inf")).triu(start_pos.val+1).realize() if seqlen > 1 else None
|
||||
mask = (
|
||||
Tensor.full((1, 1, seqlen, start_pos.val + seqlen), float("-inf"))
|
||||
.triu(start_pos.val + 1)
|
||||
.realize()
|
||||
if seqlen > 1
|
||||
else None
|
||||
)
|
||||
|
||||
if HALF:
|
||||
h = h.half()
|
||||
if mask is not None: mask = mask.half()
|
||||
if HALF:
|
||||
h = h.half()
|
||||
if mask is not None:
|
||||
mask = mask.half()
|
||||
|
||||
for hi in self.h: h = hi(h, start_pos=start_pos, mask=mask)
|
||||
for hi in self.h:
|
||||
h = hi(h, start_pos=start_pos, mask=mask)
|
||||
|
||||
logits = self.lm_head(self.ln_f(h))
|
||||
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
||||
return (logits[:, -1, :] / (temperature+1e-10)).softmax().realize()
|
||||
logits = self.lm_head(self.ln_f(h))
|
||||
# NOTE: temperature=0 with HALF breaks due to precision, should use argmax instead
|
||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().realize()
|
||||
|
||||
# TODO: fix empty token
|
||||
def __call__(
|
||||
self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0
|
||||
) -> Tensor:
|
||||
return (
|
||||
self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward
|
||||
)(tokens, start_pos, temperature)
|
||||
|
||||
# TODO: fix empty token
|
||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
|
||||
return (self.forward_jit if tokens.shape[1] == 1 and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
|
||||
|
||||
VOCAB_SIZE = 50257
|
||||
MODEL_PARAMS = {
|
||||
'gpt2': dict(n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 124M params
|
||||
'gpt2-medium': dict(n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 350M params
|
||||
'gpt2-large': dict(n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 774M params
|
||||
'gpt2-xl': dict(n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE), # 1558M params
|
||||
"gpt2": dict(
|
||||
n_layers=12, n_heads=12, dim=768, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 124M params
|
||||
"gpt2-medium": dict(
|
||||
n_layers=24, n_heads=16, dim=1024, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 350M params
|
||||
"gpt2-large": dict(
|
||||
n_layers=36, n_heads=20, dim=1280, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 774M params
|
||||
"gpt2-xl": dict(
|
||||
n_layers=48, n_heads=25, dim=1600, norm_eps=1e-5, vocab_size=VOCAB_SIZE
|
||||
), # 1558M params
|
||||
}
|
||||
|
||||
|
||||
class GPT2:
|
||||
@staticmethod
|
||||
def build(model_size="gpt2"):
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
@staticmethod
|
||||
def build(model_size="gpt2"):
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
model = Transformer(**MODEL_PARAMS[model_size])
|
||||
weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
|
||||
# special treatment for the Conv1D weights we need to transpose
|
||||
transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
|
||||
for k in weights.keys():
|
||||
if any(k.endswith(w) for w in transposed):
|
||||
weights[k] = Tensor(weights[k].numpy().T)
|
||||
# lm head and wte are tied
|
||||
weights['lm_head.weight'] = Tensor(weights['wte.weight'].numpy())
|
||||
model = Transformer(**MODEL_PARAMS[model_size])
|
||||
weights = torch_load(
|
||||
fetch(f"https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin")
|
||||
)
|
||||
# special treatment for the Conv1D weights we need to transpose
|
||||
transposed = [
|
||||
"attn.c_attn.weight",
|
||||
"attn.c_proj.weight",
|
||||
"mlp.c_fc.weight",
|
||||
"mlp.c_proj.weight",
|
||||
]
|
||||
for k in weights.keys():
|
||||
if any(k.endswith(w) for w in transposed):
|
||||
weights[k] = Tensor(weights[k].numpy().T)
|
||||
# lm head and wte are tied
|
||||
weights["lm_head.weight"] = Tensor(weights["wte.weight"].numpy())
|
||||
|
||||
load_state_dict(model, weights)
|
||||
return GPT2(model, tokenizer)
|
||||
load_state_dict(model, weights)
|
||||
return GPT2(model, tokenizer)
|
||||
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def greedy_until(
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: int,
|
||||
temperature: float,
|
||||
timing: bool = False,
|
||||
batch_size: int = 1,
|
||||
):
|
||||
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
||||
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
||||
start_pos = 0
|
||||
for _ in trange(max_length, disable=(timing == True)):
|
||||
GlobalCounters.reset()
|
||||
if timing:
|
||||
print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Timing("total ", enabled=timing):
|
||||
with Timing(
|
||||
"ran model in ",
|
||||
on_exit=(
|
||||
lambda et: (
|
||||
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||
+ (
|
||||
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if DEBUG
|
||||
else None,
|
||||
enabled=timing,
|
||||
):
|
||||
probs = self.model(
|
||||
Tensor([x[start_pos:] for x in toks]),
|
||||
Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(
|
||||
start_pos
|
||||
),
|
||||
temperature,
|
||||
)
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().flatten().numpy().tolist()
|
||||
start_pos = len(toks[0])
|
||||
for i, t in enumerate(tok):
|
||||
toks[i].append(t)
|
||||
output = [self.tokenizer.decode(x) for x in toks]
|
||||
return output
|
||||
|
||||
def greedy_until(self, prompt:str, max_length:int, temperature:float, timing:bool=False, batch_size:int=1):
|
||||
prompt_tokens = self.tokenizer.encode(prompt, allowed_special={"<|endoftext|>"})
|
||||
toks = [prompt_tokens[:] for _ in range(batch_size)]
|
||||
start_pos = 0
|
||||
for _ in trange(max_length, disable=(timing==True)):
|
||||
GlobalCounters.reset()
|
||||
if timing: print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Timing("total ", enabled=timing):
|
||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
|
||||
probs = self.model(Tensor([x[start_pos:] for x in toks]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().flatten().numpy().tolist()
|
||||
start_pos = len(toks[0])
|
||||
for i,t in enumerate(tok): toks[i].append(t)
|
||||
output = [self.tokenizer.decode(x) for x in toks]
|
||||
return output
|
||||
|
||||
# **** main code ****
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run GPT2 in tinygrad', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--prompt', type=str, default="What is the answer to life, the universe, and everything?", help="Phrase to start with")
|
||||
parser.add_argument('--count', type=int, default=100, help="Max number of tokens to generate")
|
||||
parser.add_argument('--temperature', type=float, default=0.8, help="Temperature in the softmax")
|
||||
parser.add_argument('--model_size', type=str, default="gpt2-medium", help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]")
|
||||
parser.add_argument('--timing', action='store_true', help="Print timing per token")
|
||||
parser.add_argument('--seed', type=int, help="Set the random seed")
|
||||
parser.add_argument('--batch_size', type=int, default=1, help="Set the input batch size")
|
||||
parser.add_argument('--benchmark', type=int, default=-1, help="Benchmark GPT with the given number of tokens")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the output")
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run GPT2 in tinygrad",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="What is the answer to life, the universe, and everything?",
|
||||
help="Phrase to start with",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=100, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.8, help="Temperature in the softmax"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
type=str,
|
||||
default="gpt2-medium",
|
||||
help="Size of model to use [gpt2, gpt2-medium, gpt2-large, gpt2-xl]",
|
||||
)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument("--seed", type=int, help="Set the random seed")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="Set the input batch size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Benchmark GPT with the given number of tokens",
|
||||
)
|
||||
parser.add_argument("--noshow", action="store_true", help="Don't show the output")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.seed is not None:
|
||||
Tensor._seed = args.seed
|
||||
np.random.seed(args.seed)
|
||||
if args.seed is not None:
|
||||
Tensor._seed = args.seed
|
||||
np.random.seed(args.seed)
|
||||
|
||||
print(f"using {args.model_size}")
|
||||
gpt2 = GPT2.build(args.model_size)
|
||||
print(f"using {args.model_size}")
|
||||
gpt2 = GPT2.build(args.model_size)
|
||||
|
||||
if HALF:
|
||||
for l in get_state_dict(gpt2).values():
|
||||
l.assign(l.cast(dtypes.float16).realize())
|
||||
if HALF:
|
||||
for l in get_state_dict(gpt2).values():
|
||||
l.assign(l.cast(dtypes.float16).realize())
|
||||
|
||||
if args.benchmark != -1:
|
||||
gpt2.model(Tensor.rand(args.batch_size, args.benchmark), Variable("a", 0, MAX_CONTEXT).bind(0)).realize()
|
||||
else:
|
||||
texts = gpt2.greedy_until(args.prompt, args.count, args.temperature, timing=args.timing, batch_size=args.batch_size)
|
||||
if not args.noshow:
|
||||
print('Generating text...')
|
||||
if len(texts) == 1: print(texts[0])
|
||||
else:
|
||||
for i,text in enumerate(texts): print(colored(f"Response {i}:", "green"), text)
|
||||
if args.benchmark != -1:
|
||||
gpt2.model(
|
||||
Tensor.rand(args.batch_size, args.benchmark),
|
||||
Variable("a", 0, MAX_CONTEXT).bind(0),
|
||||
).realize()
|
||||
else:
|
||||
texts = gpt2.greedy_until(
|
||||
args.prompt,
|
||||
args.count,
|
||||
args.temperature,
|
||||
timing=args.timing,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
if not args.noshow:
|
||||
print("Generating text...")
|
||||
if len(texts) == 1:
|
||||
print(texts[0])
|
||||
else:
|
||||
for i, text in enumerate(texts):
|
||||
print(colored(f"Response {i}:", "green"), text)
|
||||
|
|
|
@ -11,61 +11,75 @@ from tinygrad.shape.symbolic import sym_infer
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mdl = ResNet50()
|
||||
seen = set()
|
||||
mdl = ResNet50()
|
||||
seen = set()
|
||||
|
||||
# the device we are optimizing for
|
||||
device: Compiled = Device[Device.DEFAULT]
|
||||
print(f"optimizing for {Device.DEFAULT}")
|
||||
# the device we are optimizing for
|
||||
device: Compiled = Device[Device.DEFAULT]
|
||||
print(f"optimizing for {Device.DEFAULT}")
|
||||
|
||||
# first model run to init the weights, they are saved in seen
|
||||
mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
|
||||
# first model run to init the weights, they are saved in seen
|
||||
mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
|
||||
|
||||
# run model again to get only what changes, these are the kernels of the model
|
||||
x = Tensor.empty(64, 3, 224, 224)
|
||||
out = mdl(x)
|
||||
sched = out.lazydata.schedule(seen)
|
||||
sched = [x for x in sched if x.ast.op not in LoadOps]
|
||||
# run model again to get only what changes, these are the kernels of the model
|
||||
x = Tensor.empty(64, 3, 224, 224)
|
||||
out = mdl(x)
|
||||
sched = out.lazydata.schedule(seen)
|
||||
sched = [x for x in sched if x.ast.op not in LoadOps]
|
||||
|
||||
# focus on one kernel
|
||||
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
|
||||
# focus on one kernel
|
||||
if getenv("KERNEL", -1) >= 0:
|
||||
sched = sched[getenv("KERNEL", -1) : getenv("KERNEL", -1) + 1]
|
||||
|
||||
# work with the schedule
|
||||
total_tm = 0
|
||||
running_gflops = 0
|
||||
for i,si in enumerate(sched):
|
||||
rawbufs = bufs_from_lin(Linearizer(si.ast))
|
||||
# work with the schedule
|
||||
total_tm = 0
|
||||
running_gflops = 0
|
||||
for i, si in enumerate(sched):
|
||||
rawbufs = bufs_from_lin(Linearizer(si.ast))
|
||||
|
||||
# "linearize" the op into uops in different ways
|
||||
lins:List[Linearizer] = []
|
||||
# "linearize" the op into uops in different ways
|
||||
lins: List[Linearizer] = []
|
||||
|
||||
# always try hand coded opt
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
lin.hand_coded_optimizations()
|
||||
lins.append(lin)
|
||||
# always try hand coded opt
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
lin.hand_coded_optimizations()
|
||||
lins.append(lin)
|
||||
|
||||
# maybe try tensor cores
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
if lin.apply_tensor_cores():
|
||||
lins.append(lin)
|
||||
# maybe try tensor cores
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
if lin.apply_tensor_cores():
|
||||
lins.append(lin)
|
||||
|
||||
# try a beam search
|
||||
if getenv("BEAM"):
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
lin = beam_search(lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
lins.append(lin)
|
||||
# try a beam search
|
||||
if getenv("BEAM"):
|
||||
lin = Linearizer(si.ast, device.linearizer_opts)
|
||||
lin = beam_search(
|
||||
lin, rawbufs, getenv("BEAM"), bool(getenv("BEAM_ESTIMATE", 1))
|
||||
)
|
||||
lins.append(lin)
|
||||
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
for lin in lins:
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
|
||||
choices.append((tm, gflops, lin.linearize()))
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
for lin in lins:
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
||||
gflops = (
|
||||
sym_infer(lin.info.flops, {k: k.min for k in vars_from_ast(lin.ast)})
|
||||
* 1e-9
|
||||
/ tm
|
||||
)
|
||||
choices.append((tm, gflops, lin.linearize()))
|
||||
|
||||
# print all kernels
|
||||
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
||||
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
|
||||
print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
|
||||
total_tm += tm
|
||||
running_gflops += gflops * tm
|
||||
print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
|
||||
# print all kernels
|
||||
if DEBUG >= 1:
|
||||
print(
|
||||
f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||
)
|
||||
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
|
||||
print(
|
||||
f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS"
|
||||
)
|
||||
total_tm += tm
|
||||
running_gflops += gflops * tm
|
||||
print(
|
||||
f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS"
|
||||
)
|
||||
|
|
|
@ -2,10 +2,11 @@
|
|||
# setup for distributed
|
||||
from extra import dist
|
||||
from tinygrad.helpers import getenv, dtypes
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("DIST"):
|
||||
dist.preinit()
|
||||
from extra.dist import collectives
|
||||
if getenv("DIST"):
|
||||
dist.preinit()
|
||||
from extra.dist import collectives
|
||||
|
||||
# tinygrad implementation of https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
|
||||
# https://myrtle.ai/learn/how-to-train-your-resnet-8-bag-of-tricks/
|
||||
|
@ -24,427 +25,594 @@ from tinygrad.shape.symbolic import Node
|
|||
from extra.lr_scheduler import OneCycleLR
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
|
||||
BS, EVAL_BS, STEPS = getenv("BS", 512), getenv("EVAL_BS", 500), getenv("STEPS", 1000)
|
||||
|
||||
if getenv("HALF", 0):
|
||||
Tensor.default_type = dtypes.float16
|
||||
np_dtype: Type[Union[np.float16, np.float32]] = np.float16
|
||||
Tensor.default_type = dtypes.float16
|
||||
np_dtype: Type[Union[np.float16, np.float32]] = np.float16
|
||||
else:
|
||||
Tensor.default_type = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
Tensor.default_type = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
|
||||
|
||||
class BatchNorm(nn.BatchNorm2d):
|
||||
def __init__(self, num_features):
|
||||
super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
|
||||
self.weight.requires_grad = False
|
||||
self.bias.requires_grad = True
|
||||
def __init__(self, num_features):
|
||||
super().__init__(
|
||||
num_features,
|
||||
track_running_stats=False,
|
||||
eps=1e-12,
|
||||
momentum=0.85,
|
||||
affine=True,
|
||||
)
|
||||
self.weight.requires_grad = False
|
||||
self.bias.requires_grad = True
|
||||
|
||||
|
||||
class ConvGroup:
|
||||
def __init__(self, channels_in, channels_out):
|
||||
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
def __init__(self, channels_in, channels_out):
|
||||
self.conv1 = nn.Conv2d(
|
||||
channels_in, channels_out, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
channels_out, channels_out, kernel_size=3, padding=1, bias=False
|
||||
)
|
||||
|
||||
self.norm1 = BatchNorm(channels_out)
|
||||
self.norm2 = BatchNorm(channels_out)
|
||||
self.norm1 = BatchNorm(channels_out)
|
||||
self.norm2 = BatchNorm(channels_out)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x)
|
||||
x = x.max_pool2d(2)
|
||||
x = x.float()
|
||||
x = self.norm1(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.gelu()
|
||||
residual = x
|
||||
x = self.conv2(x)
|
||||
x = x.float()
|
||||
x = self.norm2(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.gelu()
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x)
|
||||
x = x.max_pool2d(2)
|
||||
x = x.float()
|
||||
x = self.norm1(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.gelu()
|
||||
residual = x
|
||||
x = self.conv2(x)
|
||||
x = x.float()
|
||||
x = self.norm2(x)
|
||||
x = x.cast(Tensor.default_type)
|
||||
x = x.gelu()
|
||||
|
||||
return x + residual
|
||||
|
||||
return x + residual
|
||||
|
||||
class SpeedyResNet:
|
||||
def __init__(self, W):
|
||||
self.whitening = W
|
||||
self.net = [
|
||||
nn.Conv2d(12, 32, kernel_size=1, bias=False),
|
||||
lambda x: x.gelu(),
|
||||
ConvGroup(32, 64),
|
||||
ConvGroup(64, 256),
|
||||
ConvGroup(256, 512),
|
||||
lambda x: x.max((2,3)),
|
||||
nn.Linear(512, 10, bias=False),
|
||||
lambda x: x.mul(1./9)
|
||||
]
|
||||
def __init__(self, W):
|
||||
self.whitening = W
|
||||
self.net = [
|
||||
nn.Conv2d(12, 32, kernel_size=1, bias=False),
|
||||
lambda x: x.gelu(),
|
||||
ConvGroup(32, 64),
|
||||
ConvGroup(64, 256),
|
||||
ConvGroup(256, 512),
|
||||
lambda x: x.max((2, 3)),
|
||||
nn.Linear(512, 10, bias=False),
|
||||
lambda x: x.mul(1.0 / 9),
|
||||
]
|
||||
|
||||
def __call__(self, x, training=True):
|
||||
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
|
||||
# TODO: remove the pad but instead let the kernel optimizer itself
|
||||
forward = (
|
||||
lambda x: x.conv2d(self.whitening).pad2d((1, 0, 0, 1)).sequential(self.net)
|
||||
)
|
||||
return (
|
||||
forward(x) if training else forward(x) * 0.5 + forward(x[..., ::-1]) * 0.5
|
||||
)
|
||||
|
||||
def __call__(self, x, training=True):
|
||||
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
|
||||
# TODO: remove the pad but instead let the kernel optimizer itself
|
||||
forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
|
||||
return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
|
||||
|
||||
def train_cifar():
|
||||
|
||||
# hyper-parameters were exactly the same as the original repo
|
||||
bias_scaler = 58
|
||||
hyp: Dict[str, Any] = {
|
||||
'seed' : 209,
|
||||
'opt': {
|
||||
'bias_lr': 1.76 * bias_scaler/512,
|
||||
'non_bias_lr': 1.76 / 512,
|
||||
'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler,
|
||||
'non_bias_decay': 1.08 * 6.45e-4 * BS,
|
||||
'final_lr_ratio': 0.025,
|
||||
'initial_div_factor': 1e16,
|
||||
'label_smoothing': 0.20,
|
||||
'momentum': 0.85,
|
||||
'percent_start': 0.23,
|
||||
'loss_scale_scaler': 1./128 # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
||||
},
|
||||
'net': {
|
||||
'kernel_size': 2, # kernel size for the whitening layer
|
||||
'cutmix_size': 3,
|
||||
'cutmix_steps': 499,
|
||||
'pad_amount': 2
|
||||
},
|
||||
'ema': {
|
||||
'steps': 399,
|
||||
'decay_base': .95,
|
||||
'decay_pow': 1.6,
|
||||
'every_n_steps': 5,
|
||||
# hyper-parameters were exactly the same as the original repo
|
||||
bias_scaler = 58
|
||||
hyp: Dict[str, Any] = {
|
||||
"seed": 209,
|
||||
"opt": {
|
||||
"bias_lr": 1.76 * bias_scaler / 512,
|
||||
"non_bias_lr": 1.76 / 512,
|
||||
"bias_decay": 1.08 * 6.45e-4 * BS / bias_scaler,
|
||||
"non_bias_decay": 1.08 * 6.45e-4 * BS,
|
||||
"final_lr_ratio": 0.025,
|
||||
"initial_div_factor": 1e16,
|
||||
"label_smoothing": 0.20,
|
||||
"momentum": 0.85,
|
||||
"percent_start": 0.23,
|
||||
"loss_scale_scaler": 1.0 / 128, # (range: ~1/512 - 16+, 1/128 w/ FP16)
|
||||
},
|
||||
"net": {
|
||||
"kernel_size": 2, # kernel size for the whitening layer
|
||||
"cutmix_size": 3,
|
||||
"cutmix_steps": 499,
|
||||
"pad_amount": 2,
|
||||
},
|
||||
"ema": {
|
||||
"steps": 399,
|
||||
"decay_base": 0.95,
|
||||
"decay_pow": 1.6,
|
||||
"every_n_steps": 5,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def set_seed(seed):
|
||||
Tensor.manual_seed(getenv('SEED', seed))
|
||||
random.seed(getenv('SEED', seed))
|
||||
def set_seed(seed):
|
||||
Tensor.manual_seed(getenv("SEED", seed))
|
||||
random.seed(getenv("SEED", seed))
|
||||
|
||||
# ========== Model ==========
|
||||
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
|
||||
def whitening(X, kernel_size=hyp['net']['kernel_size']):
|
||||
def _cov(X):
|
||||
X = X/np.sqrt(X.shape[0] - 1)
|
||||
return X.T @ X
|
||||
# ========== Model ==========
|
||||
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
|
||||
def whitening(X, kernel_size=hyp["net"]["kernel_size"]):
|
||||
def _cov(X):
|
||||
X = X / np.sqrt(X.shape[0] - 1)
|
||||
return X.T @ X
|
||||
|
||||
def _patches(data, patch_size=(kernel_size,kernel_size)):
|
||||
h, w = patch_size
|
||||
c = data.shape[1]
|
||||
axis: SupportsIndex = (2, 3) # type: ignore
|
||||
return np.lib.stride_tricks.sliding_window_view(data, window_shape=(h,w), axis=axis).transpose((0,3,2,1,4,5)).reshape((-1,c,h,w))
|
||||
def _patches(data, patch_size=(kernel_size, kernel_size)):
|
||||
h, w = patch_size
|
||||
c = data.shape[1]
|
||||
axis: SupportsIndex = (2, 3) # type: ignore
|
||||
return (
|
||||
np.lib.stride_tricks.sliding_window_view(
|
||||
data, window_shape=(h, w), axis=axis
|
||||
)
|
||||
.transpose((0, 3, 2, 1, 4, 5))
|
||||
.reshape((-1, c, h, w))
|
||||
)
|
||||
|
||||
def _eigens(patches):
|
||||
n,c,h,w = patches.shape
|
||||
Σ = _cov(patches.reshape(n, c*h*w))
|
||||
Λ, V = np.linalg.eigh(Σ, UPLO='U')
|
||||
return np.flip(Λ, 0), np.flip(V.T.reshape(c*h*w, c, h, w), 0)
|
||||
def _eigens(patches):
|
||||
n, c, h, w = patches.shape
|
||||
Σ = _cov(patches.reshape(n, c * h * w))
|
||||
Λ, V = np.linalg.eigh(Σ, UPLO="U")
|
||||
return np.flip(Λ, 0), np.flip(V.T.reshape(c * h * w, c, h, w), 0)
|
||||
|
||||
Λ, V = _eigens(_patches(X.numpy()))
|
||||
W = V/np.sqrt(Λ+1e-2)[:,None,None,None]
|
||||
Λ, V = _eigens(_patches(X.numpy()))
|
||||
W = V / np.sqrt(Λ + 1e-2)[:, None, None, None]
|
||||
|
||||
return Tensor(W.astype(np_dtype), requires_grad=False)
|
||||
return Tensor(W.astype(np_dtype), requires_grad=False)
|
||||
|
||||
# ========== Loss ==========
|
||||
def cross_entropy(x:Tensor, y:Tensor, reduction:str='mean', label_smoothing:float=0.0) -> Tensor:
|
||||
divisor = y.shape[1]
|
||||
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
||||
y = (1 - label_smoothing)*y + label_smoothing / divisor
|
||||
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
||||
# ========== Loss ==========
|
||||
def cross_entropy(
|
||||
x: Tensor, y: Tensor, reduction: str = "mean", label_smoothing: float = 0.0
|
||||
) -> Tensor:
|
||||
divisor = y.shape[1]
|
||||
assert not isinstance(divisor, Node), "sint not supported as divisor"
|
||||
y = (1 - label_smoothing) * y + label_smoothing / divisor
|
||||
if reduction == "none":
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1)
|
||||
if reduction == "sum":
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
|
||||
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
|
||||
|
||||
# ========== Preprocessing ==========
|
||||
# TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
|
||||
# implemented in recursive fashion but figuring out how to switch indexing dim
|
||||
# during the loop was a bit tricky
|
||||
def pad_reflect(X, size=2) -> Tensor:
|
||||
padding = ((0,0),(0,0),(size,size),(size,size))
|
||||
p = padding[3]
|
||||
s = X.shape[3]
|
||||
# ========== Preprocessing ==========
|
||||
# TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
|
||||
# implemented in recursive fashion but figuring out how to switch indexing dim
|
||||
# during the loop was a bit tricky
|
||||
def pad_reflect(X, size=2) -> Tensor:
|
||||
padding = ((0, 0), (0, 0), (size, size), (size, size))
|
||||
p = padding[3]
|
||||
s = X.shape[3]
|
||||
|
||||
X_lr = X[...,:,1:1+p[0]].flip(3).pad(((0,0),(0,0),(0,0),(0,s+p[0]))) + X[...,:,-1-p[1]:-1].flip(3).pad(((0,0),(0,0),(0,0),(s+p[1],0)))
|
||||
X = X.pad(((0,0),(0,0),(0,0),p)) + X_lr
|
||||
X_lr = X[..., :, 1 : 1 + p[0]].flip(3).pad(
|
||||
((0, 0), (0, 0), (0, 0), (0, s + p[0]))
|
||||
) + X[..., :, -1 - p[1] : -1].flip(3).pad(
|
||||
((0, 0), (0, 0), (0, 0), (s + p[1], 0))
|
||||
)
|
||||
X = X.pad(((0, 0), (0, 0), (0, 0), p)) + X_lr
|
||||
|
||||
p = padding[2]
|
||||
s = X.shape[2]
|
||||
X_lr = X[...,1:1+p[0],:].flip(2).pad(((0,0),(0,0),(0,s+p[0]),(0,0))) + X[...,-1-p[1]:-1,:].flip(2).pad(((0,0),(0,0),(s+p[1],0),(0,0)))
|
||||
X = X.pad(((0,0),(0,0),p,(0,0))) + X_lr
|
||||
p = padding[2]
|
||||
s = X.shape[2]
|
||||
X_lr = X[..., 1 : 1 + p[0], :].flip(2).pad(
|
||||
((0, 0), (0, 0), (0, s + p[0]), (0, 0))
|
||||
) + X[..., -1 - p[1] : -1, :].flip(2).pad(
|
||||
((0, 0), (0, 0), (s + p[1], 0), (0, 0))
|
||||
)
|
||||
X = X.pad(((0, 0), (0, 0), p, (0, 0))) + X_lr
|
||||
|
||||
return X
|
||||
return X
|
||||
|
||||
# return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
|
||||
def make_square_mask(shape, mask_size) -> Tensor:
|
||||
is_even = int(mask_size % 2 == 0)
|
||||
center_max = shape[-2]-mask_size//2-is_even
|
||||
center_min = mask_size//2-is_even
|
||||
center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
||||
center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
|
||||
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
|
||||
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
|
||||
d_x =(d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
|
||||
d_y =(d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
|
||||
mask = d_y * d_x
|
||||
return mask
|
||||
# return a binary mask in the format of BS x C x H x W where H x W contains a random square mask
|
||||
def make_square_mask(shape, mask_size) -> Tensor:
|
||||
is_even = int(mask_size % 2 == 0)
|
||||
center_max = shape[-2] - mask_size // 2 - is_even
|
||||
center_min = mask_size // 2 - is_even
|
||||
center_x = (
|
||||
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||
).floor()
|
||||
center_y = (
|
||||
Tensor.rand(shape[0]) * (center_max - center_min) + center_min
|
||||
).floor()
|
||||
d_x = Tensor.arange(0, shape[-1]).reshape(
|
||||
(1, 1, 1, shape[-1])
|
||||
) - center_x.reshape((-1, 1, 1, 1))
|
||||
d_y = Tensor.arange(0, shape[-2]).reshape(
|
||||
(1, 1, shape[-2], 1)
|
||||
) - center_y.reshape((-1, 1, 1, 1))
|
||||
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
|
||||
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
|
||||
mask = d_y * d_x
|
||||
return mask
|
||||
|
||||
def random_crop(X:Tensor, crop_size=32):
|
||||
mask = make_square_mask(X.shape, crop_size)
|
||||
mask = mask.repeat((1,3,1,1))
|
||||
X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
|
||||
return X_cropped.reshape((-1, 3, crop_size, crop_size))
|
||||
def random_crop(X: Tensor, crop_size=32):
|
||||
mask = make_square_mask(X.shape, crop_size)
|
||||
mask = mask.repeat((1, 3, 1, 1))
|
||||
X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
|
||||
return X_cropped.reshape((-1, 3, crop_size, crop_size))
|
||||
|
||||
def cutmix(X:Tensor, Y:Tensor, mask_size=3):
|
||||
# fill the square with randomly selected images from the same batch
|
||||
mask = make_square_mask(X.shape, mask_size)
|
||||
order = list(range(0, X.shape[0]))
|
||||
random.shuffle(order)
|
||||
X_patch = Tensor(X.numpy()[order,...])
|
||||
Y_patch = Tensor(Y.numpy()[order])
|
||||
X_cutmix = Tensor.where(mask, X_patch, X)
|
||||
mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
|
||||
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
|
||||
return X_cutmix, Y_cutmix
|
||||
def cutmix(X: Tensor, Y: Tensor, mask_size=3):
|
||||
# fill the square with randomly selected images from the same batch
|
||||
mask = make_square_mask(X.shape, mask_size)
|
||||
order = list(range(0, X.shape[0]))
|
||||
random.shuffle(order)
|
||||
X_patch = Tensor(X.numpy()[order, ...])
|
||||
Y_patch = Tensor(Y.numpy()[order])
|
||||
X_cutmix = Tensor.where(mask, X_patch, X)
|
||||
mix_portion = float(mask_size**2) / (X.shape[-2] * X.shape[-1])
|
||||
Y_cutmix = mix_portion * Y_patch + (1.0 - mix_portion) * Y
|
||||
return X_cutmix, Y_cutmix
|
||||
|
||||
# the operations that remain inside batch fetcher is the ones that involves random operations
|
||||
def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
|
||||
step, cnt = 0, 0
|
||||
while True:
|
||||
st = time.monotonic()
|
||||
X, Y = X_in, Y_in
|
||||
order = list(range(0, X.shape[0]))
|
||||
random.shuffle(order)
|
||||
if is_train:
|
||||
X = random_crop(X, crop_size=32)
|
||||
X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
|
||||
if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
|
||||
X, Y = X.numpy(), Y.numpy()
|
||||
et = time.monotonic()
|
||||
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
|
||||
for i in range(0, X.shape[0], BS):
|
||||
# pad the last batch
|
||||
batch_end = min(i+BS, Y.shape[0])
|
||||
x = Tensor(X[order[batch_end-BS:batch_end],:])
|
||||
y = Tensor(Y[order[batch_end-BS:batch_end]])
|
||||
step += 1
|
||||
yield x, y
|
||||
cnt += 1
|
||||
if not is_train: break
|
||||
# the operations that remain inside batch fetcher is the ones that involves random operations
|
||||
def fetch_batches(X_in: Tensor, Y_in: Tensor, BS: int, is_train: bool):
|
||||
step, cnt = 0, 0
|
||||
while True:
|
||||
st = time.monotonic()
|
||||
X, Y = X_in, Y_in
|
||||
order = list(range(0, X.shape[0]))
|
||||
random.shuffle(order)
|
||||
if is_train:
|
||||
X = random_crop(X, crop_size=32)
|
||||
X = Tensor.where(
|
||||
Tensor.rand(X.shape[0], 1, 1, 1) < 0.5, X[..., ::-1], X
|
||||
) # flip LR
|
||||
if step >= hyp["net"]["cutmix_steps"]:
|
||||
X, Y = cutmix(X, Y, mask_size=hyp["net"]["cutmix_size"])
|
||||
X, Y = X.numpy(), Y.numpy()
|
||||
et = time.monotonic()
|
||||
print(
|
||||
f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})"
|
||||
)
|
||||
for i in range(0, X.shape[0], BS):
|
||||
# pad the last batch
|
||||
batch_end = min(i + BS, Y.shape[0])
|
||||
x = Tensor(X[order[batch_end - BS : batch_end], :])
|
||||
y = Tensor(Y[order[batch_end - BS : batch_end]])
|
||||
step += 1
|
||||
yield x, y
|
||||
cnt += 1
|
||||
if not is_train:
|
||||
break
|
||||
|
||||
transform = [
|
||||
lambda x: x / 255.0,
|
||||
lambda x: (x.reshape((-1,3,32,32)) - Tensor(cifar_mean).reshape((1,3,1,1)))/Tensor(cifar_std).reshape((1,3,1,1))
|
||||
]
|
||||
transform = [
|
||||
lambda x: x / 255.0,
|
||||
lambda x: (
|
||||
x.reshape((-1, 3, 32, 32)) - Tensor(cifar_mean).reshape((1, 3, 1, 1))
|
||||
)
|
||||
/ Tensor(cifar_std).reshape((1, 3, 1, 1)),
|
||||
]
|
||||
|
||||
class modelEMA():
|
||||
def __init__(self, w, net):
|
||||
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
||||
self.net_ema = SpeedyResNet(w)
|
||||
for net_ema_param, net_param in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).values()):
|
||||
net_ema_param.requires_grad = False
|
||||
net_ema_param.assign(net_param.numpy())
|
||||
class modelEMA:
|
||||
def __init__(self, w, net):
|
||||
# self.model_ema = copy.deepcopy(net) # won't work for opencl due to unpickeable pyopencl._cl.Buffer
|
||||
self.net_ema = SpeedyResNet(w)
|
||||
for net_ema_param, net_param in zip(
|
||||
get_state_dict(self.net_ema).values(), get_state_dict(net).values()
|
||||
):
|
||||
net_ema_param.requires_grad = False
|
||||
net_ema_param.assign(net_param.numpy())
|
||||
|
||||
@TinyJit
|
||||
def update(self, net, decay):
|
||||
# TODO with Tensor.no_grad()
|
||||
Tensor.no_grad = True
|
||||
for net_ema_param, (param_name, net_param) in zip(
|
||||
get_state_dict(self.net_ema).values(), get_state_dict(net).items()
|
||||
):
|
||||
# batchnorm currently is not being tracked
|
||||
if not ("num_batches_tracked" in param_name) and not (
|
||||
"running" in param_name
|
||||
):
|
||||
net_ema_param.assign(
|
||||
net_ema_param.detach() * decay
|
||||
+ net_param.detach() * (1.0 - decay)
|
||||
).realize()
|
||||
Tensor.no_grad = False
|
||||
|
||||
set_seed(hyp["seed"])
|
||||
|
||||
# this import needs to be done here because this is running in a subprocess
|
||||
from extra.dist import OOB
|
||||
|
||||
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
|
||||
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
||||
# load data and label into GPU and convert to dtype accordingly
|
||||
X_train, X_test = (
|
||||
X_train.to(device=Device.DEFAULT).float(),
|
||||
X_test.to(device=Device.DEFAULT).float(),
|
||||
)
|
||||
Y_train, Y_test = (
|
||||
Y_train.to(device=Device.DEFAULT).float(),
|
||||
Y_test.to(device=Device.DEFAULT).float(),
|
||||
)
|
||||
# one-hot encode labels
|
||||
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
||||
# preprocess data
|
||||
X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
|
||||
|
||||
# precompute whitening patches
|
||||
W = whitening(X_train)
|
||||
|
||||
# initialize model weights
|
||||
model = SpeedyResNet(W)
|
||||
|
||||
# padding is not timed in the original repo since it can be done all at once
|
||||
X_train = pad_reflect(X_train, size=hyp["net"]["pad_amount"])
|
||||
|
||||
# Convert data and labels to the default dtype
|
||||
X_train, Y_train, X_test, Y_test = (
|
||||
X_train.cast(Tensor.default_type),
|
||||
Y_train.cast(Tensor.default_type),
|
||||
X_test.cast(Tensor.default_type),
|
||||
Y_test.cast(Tensor.default_type),
|
||||
)
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
params_bias = []
|
||||
params_non_bias = []
|
||||
for params in params_dict:
|
||||
if params_dict[params].requires_grad is not False:
|
||||
if "bias" in params:
|
||||
params_bias.append(params_dict[params])
|
||||
else:
|
||||
params_non_bias.append(params_dict[params])
|
||||
|
||||
opt_bias = optim.SGD(
|
||||
params_bias,
|
||||
lr=0.01,
|
||||
momentum=hyp["opt"]["momentum"],
|
||||
nesterov=True,
|
||||
weight_decay=hyp["opt"]["bias_decay"],
|
||||
)
|
||||
opt_non_bias = optim.SGD(
|
||||
params_non_bias,
|
||||
lr=0.01,
|
||||
momentum=hyp["opt"]["momentum"],
|
||||
nesterov=True,
|
||||
weight_decay=hyp["opt"]["non_bias_decay"],
|
||||
)
|
||||
|
||||
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
|
||||
initial_div_factor = hyp["opt"]["initial_div_factor"]
|
||||
final_lr_ratio = hyp["opt"]["final_lr_ratio"]
|
||||
pct_start = hyp["opt"]["percent_start"]
|
||||
lr_sched_bias = OneCycleLR(
|
||||
opt_bias,
|
||||
max_lr=hyp["opt"]["bias_lr"],
|
||||
pct_start=pct_start,
|
||||
div_factor=initial_div_factor,
|
||||
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||
total_steps=STEPS,
|
||||
)
|
||||
lr_sched_non_bias = OneCycleLR(
|
||||
opt_non_bias,
|
||||
max_lr=hyp["opt"]["non_bias_lr"],
|
||||
pct_start=pct_start,
|
||||
div_factor=initial_div_factor,
|
||||
final_div_factor=1.0 / (initial_div_factor * final_lr_ratio),
|
||||
total_steps=STEPS,
|
||||
)
|
||||
|
||||
loss_batchsize_scaler = 512 / BS
|
||||
|
||||
@TinyJit
|
||||
def update(self, net, decay):
|
||||
# TODO with Tensor.no_grad()
|
||||
Tensor.no_grad = True
|
||||
for net_ema_param, (param_name, net_param) in zip(get_state_dict(self.net_ema).values(), get_state_dict(net).items()):
|
||||
# batchnorm currently is not being tracked
|
||||
if not ("num_batches_tracked" in param_name) and not ("running" in param_name):
|
||||
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
|
||||
Tensor.no_grad = False
|
||||
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
|
||||
out = model(X)
|
||||
loss = (
|
||||
cross_entropy(
|
||||
out, Y, reduction="none", label_smoothing=hyp["opt"]["label_smoothing"]
|
||||
)
|
||||
.mul(hyp["opt"]["loss_scale_scaler"] * loss_batchsize_scaler)
|
||||
.sum()
|
||||
.div(hyp["opt"]["loss_scale_scaler"])
|
||||
)
|
||||
|
||||
set_seed(hyp['seed'])
|
||||
if not getenv("DISABLE_BACKWARD"):
|
||||
# index 0 for bias and 1 for non-bias
|
||||
optimizer[0].zero_grad()
|
||||
optimizer[1].zero_grad()
|
||||
loss.backward()
|
||||
|
||||
# this import needs to be done here because this is running in a subprocess
|
||||
from extra.dist import OOB
|
||||
assert OOB is not None or not getenv("DIST"), "OOB should be initialized"
|
||||
rank, world_size = getenv("RANK"), getenv("WORLD_SIZE", 1)
|
||||
if getenv("DIST"):
|
||||
# sync gradients across ranks
|
||||
bucket, offset = [], 0
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None:
|
||||
bucket.append(v.grad.flatten())
|
||||
grads = collectives.allreduce(Tensor.cat(*bucket))
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None:
|
||||
v.grad.assign(
|
||||
grads[offset : offset + v.grad.numel()].reshape(
|
||||
*v.grad.shape
|
||||
)
|
||||
)
|
||||
offset += v.grad.numel()
|
||||
|
||||
X_train, Y_train, X_test, Y_test = fetch_cifar()
|
||||
# load data and label into GPU and convert to dtype accordingly
|
||||
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
|
||||
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
|
||||
# one-hot encode labels
|
||||
Y_train, Y_test = Tensor.eye(10)[Y_train], Tensor.eye(10)[Y_test]
|
||||
# preprocess data
|
||||
X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
|
||||
optimizer[0].step()
|
||||
optimizer[1].step()
|
||||
lr_scheduler[0].step()
|
||||
lr_scheduler[1].step()
|
||||
return loss.realize()
|
||||
|
||||
# precompute whitening patches
|
||||
W = whitening(X_train)
|
||||
def eval_step(model, X, Y):
|
||||
out = model(X, training=False)
|
||||
loss = cross_entropy(out, Y, reduction="mean")
|
||||
correct = out.argmax(axis=1) == Y.argmax(axis=1)
|
||||
return correct.realize(), loss.realize()
|
||||
|
||||
# initialize model weights
|
||||
model = SpeedyResNet(W)
|
||||
eval_step_jitted = TinyJit(eval_step)
|
||||
eval_step_ema_jitted = TinyJit(eval_step)
|
||||
|
||||
# padding is not timed in the original repo since it can be done all at once
|
||||
X_train = pad_reflect(X_train, size=hyp['net']['pad_amount'])
|
||||
# 97 steps in 2 seconds = 20ms / step
|
||||
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
|
||||
# 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
|
||||
# 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
|
||||
# 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off
|
||||
|
||||
# Convert data and labels to the default dtype
|
||||
X_train, Y_train, X_test, Y_test = X_train.cast(Tensor.default_type), Y_train.cast(Tensor.default_type), X_test.cast(Tensor.default_type), Y_test.cast(Tensor.default_type)
|
||||
# https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
|
||||
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
params_bias = []
|
||||
params_non_bias = []
|
||||
for params in params_dict:
|
||||
if params_dict[params].requires_grad is not False:
|
||||
if 'bias' in params:
|
||||
params_bias.append(params_dict[params])
|
||||
else:
|
||||
params_non_bias.append(params_dict[params])
|
||||
model_ema: Optional[modelEMA] = None
|
||||
projected_ema_decay_val = hyp["ema"]["decay_base"] ** hyp["ema"]["every_n_steps"]
|
||||
i = 0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
st = time.monotonic()
|
||||
while i <= STEPS:
|
||||
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
|
||||
st_eval = time.monotonic()
|
||||
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
|
||||
corrects = []
|
||||
corrects_ema = []
|
||||
losses = []
|
||||
losses_ema = []
|
||||
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
||||
# further split batch if distributed
|
||||
if getenv("DIST"):
|
||||
Xt, Yt = (
|
||||
Xt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||
Yt.chunk(min(world_size, 5), 0)[min(rank, 4)],
|
||||
)
|
||||
|
||||
opt_bias = optim.SGD(params_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
opt_non_bias = optim.SGD(params_non_bias, lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['non_bias_decay'])
|
||||
correct, loss = eval_step_jitted(model, Xt, Yt)
|
||||
losses.append(loss.numpy().tolist())
|
||||
corrects.extend(correct.numpy().tolist())
|
||||
if model_ema:
|
||||
correct_ema, loss_ema = eval_step_ema_jitted(
|
||||
model_ema.net_ema, Xt, Yt
|
||||
)
|
||||
losses_ema.append(loss_ema.numpy().tolist())
|
||||
corrects_ema.extend(correct_ema.numpy().tolist())
|
||||
|
||||
# NOTE taken from the hlb_CIFAR repository, might need to be tuned
|
||||
initial_div_factor = hyp['opt']['initial_div_factor']
|
||||
final_lr_ratio = hyp['opt']['final_lr_ratio']
|
||||
pct_start = hyp['opt']['percent_start']
|
||||
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
||||
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'] ,pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
|
||||
# collect accuracy across ranks
|
||||
correct_sum, correct_len = sum(corrects), len(corrects)
|
||||
if model_ema:
|
||||
correct_sum_ema, correct_len_ema = sum(corrects_ema), len(
|
||||
corrects_ema
|
||||
)
|
||||
if getenv("DIST"):
|
||||
if rank == 0:
|
||||
for j in range(1, min(world_size, 5)):
|
||||
if model_ema:
|
||||
(
|
||||
recv_sum,
|
||||
recv_len,
|
||||
recv_sum_ema,
|
||||
recv_len_ema,
|
||||
) = OOB.recv(j)
|
||||
else:
|
||||
recv_sum, recv_len = OOB.recv(j)
|
||||
correct_sum += recv_sum
|
||||
correct_len += recv_len
|
||||
if model_ema:
|
||||
correct_sum_ema += recv_sum_ema
|
||||
correct_len_ema += recv_len_ema
|
||||
elif rank < min(world_size, 5):
|
||||
if model_ema:
|
||||
OOB.send(
|
||||
(
|
||||
correct_sum,
|
||||
correct_len,
|
||||
correct_sum_ema,
|
||||
correct_len_ema,
|
||||
),
|
||||
0,
|
||||
)
|
||||
else:
|
||||
OOB.send((correct_sum, correct_len), 0)
|
||||
|
||||
loss_batchsize_scaler = 512/BS
|
||||
@TinyJit
|
||||
def train_step_jitted(model, optimizer, lr_scheduler, X, Y):
|
||||
out = model(X)
|
||||
loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
|
||||
# only rank 0 prints
|
||||
if rank == 0:
|
||||
acc = correct_sum / correct_len * 100.0
|
||||
if model_ema:
|
||||
acc_ema = correct_sum_ema / correct_len_ema * 100.0
|
||||
print(
|
||||
f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)"
|
||||
)
|
||||
if model_ema:
|
||||
print(
|
||||
f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}"
|
||||
)
|
||||
|
||||
if not getenv("DISABLE_BACKWARD"):
|
||||
# index 0 for bias and 1 for non-bias
|
||||
optimizer[0].zero_grad()
|
||||
optimizer[1].zero_grad()
|
||||
loss.backward()
|
||||
|
||||
if getenv("DIST"):
|
||||
# sync gradients across ranks
|
||||
bucket, offset = [], 0
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None: bucket.append(v.grad.flatten())
|
||||
grads = collectives.allreduce(Tensor.cat(*bucket))
|
||||
for _, v in params_dict.items():
|
||||
if v.grad is not None:
|
||||
v.grad.assign(grads[offset:offset+v.grad.numel()].reshape(*v.grad.shape))
|
||||
offset += v.grad.numel()
|
||||
|
||||
optimizer[0].step()
|
||||
optimizer[1].step()
|
||||
lr_scheduler[0].step()
|
||||
lr_scheduler[1].step()
|
||||
return loss.realize()
|
||||
|
||||
def eval_step(model, X, Y):
|
||||
out = model(X, training=False)
|
||||
loss = cross_entropy(out, Y, reduction='mean')
|
||||
correct = out.argmax(axis=1) == Y.argmax(axis=1)
|
||||
return correct.realize(), loss.realize()
|
||||
eval_step_jitted = TinyJit(eval_step)
|
||||
eval_step_ema_jitted = TinyJit(eval_step)
|
||||
|
||||
# 97 steps in 2 seconds = 20ms / step
|
||||
# step is 1163.42 GOPS = 56 TFLOPS!!!, 41% of max 136
|
||||
# 4 seconds for tfloat32 ~ 28 TFLOPS, 41% of max 68
|
||||
# 6.4 seconds for float32 ~ 17 TFLOPS, 50% of max 34.1
|
||||
# 4.7 seconds for float32 w/o channels last. 24 TFLOPS. we get 50ms then i'll be happy. only 64x off
|
||||
|
||||
# https://www.anandtech.com/show/16727/nvidia-announces-geforce-rtx-3080-ti-3070-ti-upgraded-cards-coming-in-june
|
||||
# 136 TFLOPS is the theoretical max w float16 on 3080 Ti
|
||||
|
||||
model_ema: Optional[modelEMA] = None
|
||||
projected_ema_decay_val = hyp['ema']['decay_base'] ** hyp['ema']['every_n_steps']
|
||||
i = 0
|
||||
batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True)
|
||||
with Tensor.train():
|
||||
st = time.monotonic()
|
||||
while i <= STEPS:
|
||||
if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
|
||||
st_eval = time.monotonic()
|
||||
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
|
||||
corrects = []
|
||||
corrects_ema = []
|
||||
losses = []
|
||||
losses_ema = []
|
||||
for Xt, Yt in fetch_batches(X_test, Y_test, BS=EVAL_BS, is_train=False):
|
||||
# further split batch if distributed
|
||||
if getenv("DIST"):
|
||||
Xt, Yt = Xt.chunk(min(world_size, 5), 0)[min(rank, 4)], Yt.chunk(min(world_size, 5), 0)[min(rank, 4)]
|
||||
|
||||
correct, loss = eval_step_jitted(model, Xt, Yt)
|
||||
losses.append(loss.numpy().tolist())
|
||||
corrects.extend(correct.numpy().tolist())
|
||||
if model_ema:
|
||||
correct_ema, loss_ema = eval_step_ema_jitted(model_ema.net_ema, Xt, Yt)
|
||||
losses_ema.append(loss_ema.numpy().tolist())
|
||||
corrects_ema.extend(correct_ema.numpy().tolist())
|
||||
|
||||
# collect accuracy across ranks
|
||||
correct_sum, correct_len = sum(corrects), len(corrects)
|
||||
if model_ema: correct_sum_ema, correct_len_ema = sum(corrects_ema), len(corrects_ema)
|
||||
if getenv("DIST"):
|
||||
if rank == 0:
|
||||
for j in range(1, min(world_size, 5)):
|
||||
if model_ema:
|
||||
recv_sum, recv_len, recv_sum_ema, recv_len_ema = OOB.recv(j)
|
||||
else:
|
||||
recv_sum, recv_len = OOB.recv(j)
|
||||
correct_sum += recv_sum
|
||||
correct_len += recv_len
|
||||
if model_ema:
|
||||
correct_sum_ema += recv_sum_ema
|
||||
correct_len_ema += recv_len_ema
|
||||
elif rank < min(world_size, 5):
|
||||
if model_ema:
|
||||
OOB.send((correct_sum, correct_len, correct_sum_ema, correct_len_ema), 0)
|
||||
if STEPS == 0 or i == STEPS:
|
||||
break
|
||||
X, Y = next(batcher)
|
||||
if getenv("DIST"):
|
||||
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
|
||||
GlobalCounters.reset()
|
||||
loss = train_step_jitted(
|
||||
model,
|
||||
[opt_bias, opt_non_bias],
|
||||
[lr_sched_bias, lr_sched_non_bias],
|
||||
X,
|
||||
Y,
|
||||
)
|
||||
et = time.monotonic()
|
||||
loss_cpu = loss.numpy()
|
||||
# EMA for network weights
|
||||
if i > hyp["ema"]["steps"] and (i + 1) % hyp["ema"]["every_n_steps"] == 0:
|
||||
if model_ema is None:
|
||||
model_ema = modelEMA(W, model)
|
||||
model_ema.update(
|
||||
model,
|
||||
Tensor(
|
||||
[
|
||||
projected_ema_decay_val
|
||||
* (i / STEPS) ** hyp["ema"]["decay_pow"]
|
||||
]
|
||||
),
|
||||
)
|
||||
cl = time.monotonic()
|
||||
if not getenv("DIST"):
|
||||
print(
|
||||
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
else:
|
||||
OOB.send((correct_sum, correct_len), 0)
|
||||
print(
|
||||
f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS"
|
||||
)
|
||||
st = cl
|
||||
i += 1
|
||||
|
||||
# only rank 0 prints
|
||||
if rank == 0:
|
||||
acc = correct_sum/correct_len*100.0
|
||||
if model_ema: acc_ema = correct_sum_ema/correct_len_ema*100.0
|
||||
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
|
||||
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
|
||||
|
||||
if STEPS == 0 or i==STEPS: break
|
||||
X, Y = next(batcher)
|
||||
if getenv("DIST"):
|
||||
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
|
||||
GlobalCounters.reset()
|
||||
loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y)
|
||||
et = time.monotonic()
|
||||
loss_cpu = loss.numpy()
|
||||
# EMA for network weights
|
||||
if i > hyp['ema']['steps'] and (i+1) % hyp['ema']['every_n_steps'] == 0:
|
||||
if model_ema is None:
|
||||
model_ema = modelEMA(W, model)
|
||||
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
|
||||
cl = time.monotonic()
|
||||
if not getenv("DIST"):
|
||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
else:
|
||||
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
|
||||
st = cl
|
||||
i += 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
if not getenv("DIST"):
|
||||
train_cifar()
|
||||
else: # distributed
|
||||
if getenv("HIP"):
|
||||
from tinygrad.runtime.ops_hip import HIP
|
||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||
else:
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||
world_size = len(devices)
|
||||
if not getenv("DIST"):
|
||||
train_cifar()
|
||||
else: # distributed
|
||||
if getenv("HIP"):
|
||||
from tinygrad.runtime.ops_hip import HIP
|
||||
|
||||
# ensure that the batch size is divisible by the number of devices
|
||||
assert BS % world_size == 0, f"batch size {BS} is not divisible by world size {world_size}"
|
||||
devices = [f"hip:{i}" for i in range(HIP.device_count)]
|
||||
else:
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
|
||||
# ensure that the evaluation batch size is divisible by the number of devices
|
||||
assert EVAL_BS % min(world_size, 5) == 0, f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
||||
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
|
||||
world_size = len(devices)
|
||||
|
||||
# init out-of-band communication
|
||||
dist.init_oob(world_size)
|
||||
# ensure that the batch size is divisible by the number of devices
|
||||
assert (
|
||||
BS % world_size == 0
|
||||
), f"batch size {BS} is not divisible by world size {world_size}"
|
||||
|
||||
# start the processes
|
||||
processes = []
|
||||
for rank, device in enumerate(devices):
|
||||
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
|
||||
for p in processes: p.join()
|
||||
# ensure that the evaluation batch size is divisible by the number of devices
|
||||
assert (
|
||||
EVAL_BS % min(world_size, 5) == 0
|
||||
), f"evaluation batch size {EVAL_BS} is not divisible by world size {min(world_size, 5)}"
|
||||
|
||||
# init out-of-band communication
|
||||
dist.init_oob(world_size)
|
||||
|
||||
# start the processes
|
||||
processes = []
|
||||
for rank, device in enumerate(devices):
|
||||
processes.append(dist.spawn(rank, device, fn=train_cifar, args=()))
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
# pip3 install sentencepiece
|
||||
#import typeguard.importhook
|
||||
#typeguard.importhook.install_import_hook('tinygrad')
|
||||
# import typeguard.importhook
|
||||
# typeguard.importhook.install_import_hook('tinygrad')
|
||||
|
||||
from pathlib import Path
|
||||
import sys, argparse, json
|
||||
import numpy as np
|
||||
|
||||
np.set_printoptions(linewidth=200)
|
||||
from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes
|
||||
from tinygrad import Device
|
||||
|
@ -22,174 +23,365 @@ MAX_CONTEXT = getenv("MAX_CONTEXT", 4096)
|
|||
# however, Llama uses SwiGLU. in order to preserve param count to original transformer arch, hidden_dim must be = 2/3 * (dim*4) [arxiv/2002.05202]
|
||||
# for models using MQA (n_kv_heads != n_heads), preserving param count means hidden dim must be further multiplied by 1.3 [arxiv/2307.09288, A.2.1]
|
||||
MODEL_PARAMS = {
|
||||
"1": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"files": 1,
|
||||
"1": {
|
||||
"7B": {
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_heads": 32,
|
||||
"n_layers": 32,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_heads": 40,
|
||||
"n_layers": 40,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"30B": {
|
||||
"args": {
|
||||
"dim": 6656,
|
||||
"n_heads": 52,
|
||||
"n_layers": 60,
|
||||
"norm_eps": 1e-06,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 17920,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"65B": {
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_heads": 64,
|
||||
"n_layers": 80,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 8,
|
||||
},
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"files": 2,
|
||||
"2": {
|
||||
"7B": {
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_heads": 32,
|
||||
"n_layers": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_heads": 40,
|
||||
"n_layers": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"70B": {
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"n_layers": 80,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 28672,
|
||||
},
|
||||
"files": 8,
|
||||
},
|
||||
},
|
||||
"30B": {
|
||||
"args": {"dim": 6656, "n_heads": 52, "n_layers": 60, "norm_eps": 1e-06, "vocab_size": 32000, "hidden_dim": 17920},
|
||||
"files": 4,
|
||||
"code": {
|
||||
"7B": {
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Python": {
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Instruct": {
|
||||
"args": {
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
"n_heads": 32,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 11008,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Python": {
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Instruct": {
|
||||
"args": {
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
"n_heads": 40,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32016,
|
||||
"hidden_dim": 13824,
|
||||
},
|
||||
"files": 2,
|
||||
},
|
||||
"34B": {
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Python": {
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Instruct": {
|
||||
"args": {
|
||||
"dim": 8192,
|
||||
"n_layers": 48,
|
||||
"n_heads": 64,
|
||||
"n_kv_heads": 8,
|
||||
"norm_eps": 1e-05,
|
||||
"rope_theta": 1000000,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 22016,
|
||||
},
|
||||
"files": 4,
|
||||
},
|
||||
},
|
||||
"65B": {
|
||||
"args": {"dim": 8192, "n_heads": 64, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"files": 8,
|
||||
"tiny": {
|
||||
"1B": {
|
||||
"args": {
|
||||
"dim": 2048,
|
||||
"n_layers": 22,
|
||||
"n_heads": 32,
|
||||
"n_kv_heads": 4,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32000,
|
||||
"hidden_dim": 5632,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
"1B-Chat": {
|
||||
"args": {
|
||||
"dim": 2048,
|
||||
"n_layers": 22,
|
||||
"n_heads": 32,
|
||||
"n_kv_heads": 4,
|
||||
"norm_eps": 1e-05,
|
||||
"vocab_size": 32003,
|
||||
"hidden_dim": 5632,
|
||||
},
|
||||
"files": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"2": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_heads": 40, "n_layers": 40, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"files": 2,
|
||||
},
|
||||
"70B": {
|
||||
"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 28672},
|
||||
"files": 8,
|
||||
},
|
||||
},
|
||||
"code": {
|
||||
"7B": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Python": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 11008},
|
||||
"files": 1,
|
||||
},
|
||||
"7B-Instruct": {
|
||||
"args": {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 11008},
|
||||
"files": 1,
|
||||
},
|
||||
"13B": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Python": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 13824},
|
||||
"files": 2,
|
||||
},
|
||||
"13B-Instruct": {
|
||||
"args": {"dim": 5120, "n_layers": 40, "n_heads": 40, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32016, "hidden_dim": 13824},
|
||||
"files": 2,
|
||||
},
|
||||
"34B": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Python": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"files": 4,
|
||||
},
|
||||
"34B-Instruct": {
|
||||
"args": {"dim": 8192, "n_layers": 48, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05, "rope_theta": 1000000, "vocab_size": 32000, "hidden_dim": 22016},
|
||||
"files": 4,
|
||||
},
|
||||
},
|
||||
"tiny": {
|
||||
"1B": {
|
||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32000, "hidden_dim": 5632},
|
||||
"files": 1,
|
||||
},
|
||||
"1B-Chat": {
|
||||
"args": {"dim": 2048, "n_layers": 22, "n_heads": 32, "n_kv_heads": 4, "norm_eps": 1e-05, "vocab_size": 32003, "hidden_dim": 5632},
|
||||
"files": 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# **** helper functions ****
|
||||
def concat_weights(models):
|
||||
def convert(name) -> Tensor:
|
||||
disk_tensors = [model[name] for model in models]
|
||||
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
||||
return disk_tensors[0].to(device=Device.DEFAULT)
|
||||
axis = 1 if name.startswith("tok_embeddings.") or name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
|
||||
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
|
||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||
return {name: convert(name) for name in {name: None for model in models for name in model}}
|
||||
def convert(name) -> Tensor:
|
||||
disk_tensors = [model[name] for model in models]
|
||||
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
||||
return disk_tensors[0].to(device=Device.DEFAULT)
|
||||
axis = (
|
||||
1
|
||||
if name.startswith("tok_embeddings.")
|
||||
or name.endswith(".attention.wo.weight")
|
||||
or name.endswith(".feed_forward.w2.weight")
|
||||
else 0
|
||||
)
|
||||
lazy_tensors = [data.to(device=Device.DEFAULT) for data in disk_tensors]
|
||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||
|
||||
return {
|
||||
name: convert(name)
|
||||
for name in {name: None for model in models for name in model}
|
||||
}
|
||||
|
||||
|
||||
def load(fn: str):
|
||||
if fn.endswith(".index.json"):
|
||||
with open(fn) as fp:
|
||||
weight_map = json.load(fp)["weight_map"]
|
||||
parts = {
|
||||
n: load(str(Path(fn).parent / Path(n).name))
|
||||
for n in set(weight_map.values())
|
||||
}
|
||||
return {k: parts[n][k] for k, n in weight_map.items()}
|
||||
elif fn.endswith(".safetensors"):
|
||||
return safe_load(fn)
|
||||
else:
|
||||
return torch_load(fn)
|
||||
|
||||
def load(fn:str):
|
||||
if fn.endswith('.index.json'):
|
||||
with open(fn) as fp: weight_map = json.load(fp)['weight_map']
|
||||
parts = {n: load(str(Path(fn).parent / Path(n).name)) for n in set(weight_map.values())}
|
||||
return {k: parts[n][k] for k, n in weight_map.items()}
|
||||
elif fn.endswith(".safetensors"):
|
||||
return safe_load(fn)
|
||||
else:
|
||||
return torch_load(fn)
|
||||
|
||||
class AbsmaxQuantizedLinear:
|
||||
def __init__(self, in_features, out_features, bias=False):
|
||||
assert bias == False
|
||||
self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
|
||||
self.scale = Tensor.ones(out_features, dtype=dtypes.half)
|
||||
def __init__(self, in_features, out_features, bias=False):
|
||||
assert bias == False
|
||||
self.weight = Tensor.ones(out_features, in_features, dtype=dtypes.int8)
|
||||
self.scale = Tensor.ones(out_features, dtype=dtypes.half)
|
||||
|
||||
def __call__(self, x):
|
||||
return x.dot(self.weight.cast(dtype=dtypes.half).T*self.scale)
|
||||
def __call__(self, x):
|
||||
return x.dot(self.weight.cast(dtype=dtypes.half).T * self.scale)
|
||||
|
||||
@staticmethod
|
||||
def quantize(tensors):
|
||||
new_tensors = {}
|
||||
for name, v in tensors.items():
|
||||
if (
|
||||
"feed_forward" in name
|
||||
or ("attention.w") in name
|
||||
or name == "output.weight"
|
||||
):
|
||||
scale = v.abs().max(axis=1) / 127.0
|
||||
int8_weight = (v.T / scale).T.cast(dtype=dtypes.int8)
|
||||
new_tensors[name] = int8_weight
|
||||
new_tensors[name.replace("weight", "scale")] = scale
|
||||
else:
|
||||
new_tensors[name] = v
|
||||
return new_tensors
|
||||
|
||||
@staticmethod
|
||||
def quantize(tensors):
|
||||
new_tensors = {}
|
||||
for name,v in tensors.items():
|
||||
if "feed_forward" in name or ("attention.w") in name or name == "output.weight":
|
||||
scale = v.abs().max(axis=1) / 127.0
|
||||
int8_weight = (v.T/scale).T.cast(dtype=dtypes.int8)
|
||||
new_tensors[name] = int8_weight
|
||||
new_tensors[name.replace('weight', 'scale')] = scale
|
||||
else:
|
||||
new_tensors[name] = v
|
||||
return new_tensors
|
||||
|
||||
class LLaMa:
|
||||
@staticmethod
|
||||
def build(model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False):
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == params["args"]["vocab_size"], f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||
@staticmethod
|
||||
def build(
|
||||
model_path, tokenizer_path, model_gen="1", model_size="7B", quantize=False
|
||||
):
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert (
|
||||
sp_model.vocab_size() == params["args"]["vocab_size"]
|
||||
), f"{sp_model.vocab_size()=} not equal to {params['args']['vocab_size']}"
|
||||
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT) if quantize else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||
model = (
|
||||
Transformer(
|
||||
**params["args"], linear=AbsmaxQuantizedLinear, max_context=MAX_CONTEXT
|
||||
)
|
||||
if quantize
|
||||
else Transformer(**params["args"], max_context=MAX_CONTEXT)
|
||||
)
|
||||
|
||||
if model_path.is_dir():
|
||||
weights = concat_weights([load(filename) for filename in [f"{model_path}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
else:
|
||||
weights = load(str(model_path))
|
||||
if "model.embed_tokens.weight" in weights:
|
||||
weights = convert_from_huggingface(weights, model, params["args"]["n_heads"], params["args"].get("n_kv_heads", params["args"]["n_heads"]))
|
||||
if model_path.is_dir():
|
||||
weights = concat_weights(
|
||||
[
|
||||
load(filename)
|
||||
for filename in [
|
||||
f"{model_path}/consolidated.{i:02d}.pth"
|
||||
for i in range(params["files"])
|
||||
]
|
||||
]
|
||||
)
|
||||
else:
|
||||
weights = load(str(model_path))
|
||||
if "model.embed_tokens.weight" in weights:
|
||||
weights = convert_from_huggingface(
|
||||
weights,
|
||||
model,
|
||||
params["args"]["n_heads"],
|
||||
params["args"].get("n_kv_heads", params["args"]["n_heads"]),
|
||||
)
|
||||
|
||||
if quantize:
|
||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||
for _,v in weights.items(): v.realize()
|
||||
load_state_dict(model, weights, strict=False)
|
||||
if quantize:
|
||||
weights = AbsmaxQuantizedLinear.quantize(weights)
|
||||
for _, v in weights.items():
|
||||
v.realize()
|
||||
load_state_dict(model, weights, strict=False)
|
||||
|
||||
return LLaMa(model, sp_model)
|
||||
return LLaMa(model, sp_model)
|
||||
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer: SentencePieceProcessor = tokenizer
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer: SentencePieceProcessor = tokenizer
|
||||
|
||||
def greedy_until(self, prompt:str, until, max_length, temperature):
|
||||
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
|
||||
start_pos = 0
|
||||
for i in range(max_length):
|
||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, temperature).realize()
|
||||
probs_np = probs.numpy()
|
||||
tok = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
def greedy_until(self, prompt: str, until, max_length, temperature):
|
||||
toks = [self.tokenizer.bos_id()] + self.tokenizer.encode(prompt)
|
||||
start_pos = 0
|
||||
for i in range(max_length):
|
||||
probs = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, temperature
|
||||
).realize()
|
||||
probs_np = probs.numpy()
|
||||
tok = int(np.random.choice(len(probs_np), p=probs_np))
|
||||
start_pos = len(toks)
|
||||
toks.append(tok)
|
||||
|
||||
if tok == self.tokenizer.eos_id():
|
||||
break
|
||||
output = self.tokenizer.decode(toks)
|
||||
for s in until:
|
||||
if output.endswith(s):
|
||||
return output[0 : -len(s)]
|
||||
return output
|
||||
|
||||
if tok == self.tokenizer.eos_id(): break
|
||||
output = self.tokenizer.decode(toks)
|
||||
for s in until:
|
||||
if output.endswith(s): return output[0:-len(s)]
|
||||
return output
|
||||
|
||||
# **** main code ****
|
||||
"""
|
||||
|
@ -253,30 +445,67 @@ int main()
|
|||
\end{code}
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
Tensor.no_grad = True
|
||||
print(f"using {Device.DEFAULT} backend")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run LLaMA in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--prompt", type=str, default=None, help="Phrase to start with. Without this, it goes into chatbot mode")
|
||||
parser.add_argument("--count", type=int, default=1000, help="Max number of tokens to generate")
|
||||
parser.add_argument("--personality", type=str, default="Stacy", help="Personality, can be Stacy, George, Gary, or Lexie")
|
||||
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument("--profile", action="store_true", help="Output profile data to out.prof")
|
||||
parser.add_argument("--gen", default="1", help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""")
|
||||
parser.add_argument("--size", type=str, default=None, help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""")
|
||||
parser.add_argument("--quantize", action="store_true", help="Quantize the weights to int8 in memory")
|
||||
parser.add_argument("--model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run LLaMA in tinygrad",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Phrase to start with. Without this, it goes into chatbot mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=1000, help="Max number of tokens to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--personality",
|
||||
type=str,
|
||||
default="Stacy",
|
||||
help="Personality, can be Stacy, George, Gary, or Lexie",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.7, help="Temperature in the softmax"
|
||||
)
|
||||
parser.add_argument("--timing", action="store_true", help="Print timing per token")
|
||||
parser.add_argument(
|
||||
"--profile", action="store_true", help="Output profile data to out.prof"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen",
|
||||
default="1",
|
||||
help=f"""Generation of the model to use {list(MODEL_PARAMS.keys())}""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size",
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"""Size of model to use {", ".join([f"{list(v.keys())} for gen '{k}'" for k, v in MODEL_PARAMS.items()])}""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize", action="store_true", help="Quantize the weights to int8 in memory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.gen not in MODEL_PARAMS: raise ValueError("Invalid model generation")
|
||||
if args.size is None: args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
||||
chatbot = args.prompt == None
|
||||
args = parser.parse_args()
|
||||
if args.gen not in MODEL_PARAMS:
|
||||
raise ValueError("Invalid model generation")
|
||||
if args.size is None:
|
||||
args.size = list(MODEL_PARAMS[args.gen].items())[0][0]
|
||||
chatbot = args.prompt == None
|
||||
|
||||
# *** prompt engineers work here ****
|
||||
# *** prompt engineers work here ****
|
||||
|
||||
if args.personality.lower() == "stacy":
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Stacy and User
|
||||
if args.personality.lower() == "stacy":
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Stacy and User
|
||||
You are Stacy!
|
||||
You have been a rapper your whole life who struggled with bipolar disorder. You called yourself lil stacy.
|
||||
You love to answer questions and you are very good at it. Sometimes you answer in rap form.
|
||||
|
@ -285,19 +514,23 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
|
||||
<CHAT LOG>
|
||||
"""
|
||||
examples = {
|
||||
"What is your name?": "Hi! My name is Stacy. I'm a rapper with bipolar disorder.",
|
||||
"french revolution was what year?": "The French Revolution started in 1789, and lasted 10 years until 1799.",
|
||||
"What is bigger, the moon or the sun?": "The sun is bigger than the moon, except when Mercury is in retrograde.",
|
||||
}
|
||||
examples = {
|
||||
"What is your name?": "Hi! My name is Stacy. I'm a rapper with bipolar disorder.",
|
||||
"french revolution was what year?": "The French Revolution started in 1789, and lasted 10 years until 1799.",
|
||||
"What is bigger, the moon or the sun?": "The sun is bigger than the moon, except when Mercury is in retrograde.",
|
||||
}
|
||||
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "Stacy: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
elif args.personality.lower() == "george":
|
||||
print("WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter.")
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "Stacy: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "george":
|
||||
print(
|
||||
"WARNING: AI George Hotz is terrible and is completely disowned by the real George Hotz. Stacy is much smarter."
|
||||
)
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named George and User
|
||||
You are an AI version of George Hotz. You act as much as you can like George.
|
||||
You are one of the greatest computer experts in the world.
|
||||
You have singlehandedly won programming and hacking competitions.
|
||||
|
@ -307,20 +540,22 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
|
||||
<CHAT LOG>
|
||||
"""
|
||||
examples = {
|
||||
"What is your name?": "I'm am an AI version of George Hotz.",
|
||||
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
|
||||
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
|
||||
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
|
||||
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk"
|
||||
}
|
||||
examples = {
|
||||
"What is your name?": "I'm am an AI version of George Hotz.",
|
||||
"What's the complexity of matrix multiplication?": "O(n^3), though it can be faster with things like Strassen's algorithm",
|
||||
"What's a buffer overflow?": "I assume you mean a stack buffer overflow. That's when the stack is too small for the data being copied to it, and the data corrupts things beyond the buffer",
|
||||
"How many weights do you have?": "I am based off LLaMA trained by Facebook. I'm the 7B weight version",
|
||||
"What is swap memory?": "It is when the memory is about to overflow and unused memory is freed and stored on disk",
|
||||
}
|
||||
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "George: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
elif args.personality.lower() == "gary":
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "George: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "gary":
|
||||
pre_prompt = f"""Consider that the following is conversation between an AI assistant named Gary and User
|
||||
You are Gary!
|
||||
You have sold used cars your entire life. You are defensive about this fact, because your family belittles you.
|
||||
You try to answer questions well, but you always manage to shill for buying cars, Fords, Hyundais, and Toyotas
|
||||
|
@ -329,17 +564,19 @@ After you are done speaking, output [EOS]. You are not the User.
|
|||
|
||||
<CHAT LOG>
|
||||
"""
|
||||
examples = {
|
||||
"What is your name?": "I am Gary. I used to sell cars.",
|
||||
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla"
|
||||
}
|
||||
examples = {
|
||||
"What is your name?": "I am Gary. I used to sell cars.",
|
||||
"What is 2+3?": "I don't know, but I can get you a great deal on a certified preowned slightly used Toyota Corolla",
|
||||
}
|
||||
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "Gary: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
elif args.personality.lower() == "lexie":
|
||||
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
|
||||
user_delim = "\nUser: "
|
||||
resp_delim = "Gary: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
elif args.personality.lower() == "lexie":
|
||||
pre_prompt = f"""Consider that the following is conversation between an attractive young girl named Lexie and a handsome man named Chad
|
||||
You are Lexie!
|
||||
You grew up in Ohio, but moved out to LA after college to try to become an actress.
|
||||
Making it as an actress was hard, so you started doing onlyfans instead. It's much easier, and you enjoy it a lot.
|
||||
|
@ -349,83 +586,123 @@ After you are done speaking, output [EOS]. You are not Chad.
|
|||
|
||||
<CHAT LOG>
|
||||
"""
|
||||
examples = {
|
||||
"hi lexie": "hi chad, glad we finally met up!",
|
||||
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
|
||||
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress"
|
||||
}
|
||||
examples = {
|
||||
"hi lexie": "hi chad, glad we finally met up!",
|
||||
"you look better than your pictures": "thanks! are you subscribed to my onlyfans?",
|
||||
"i am. so how'd you end up in LA?": "i moved out here about a year ago. i want to be an actress",
|
||||
}
|
||||
|
||||
user_delim = "\nChad: "
|
||||
resp_delim = "Lexie: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += ''.join(f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k,v in examples.items())
|
||||
user_delim = "\nChad: "
|
||||
resp_delim = "Lexie: "
|
||||
end_delim = " [EOS]\n"
|
||||
pre_prompt += "".join(
|
||||
f"{user_delim}{k}\n{resp_delim}{v}{end_delim}" for k, v in examples.items()
|
||||
)
|
||||
|
||||
# *** prompt engineers stop here ****
|
||||
# *** prompt engineers stop here ****
|
||||
|
||||
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
|
||||
MODEL_PATH = args.model or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||
TOKENIZER_PATH = (MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent) / "tokenizer.model"
|
||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||
llama = LLaMa.build(MODEL_PATH, TOKENIZER_PATH, model_gen=args.gen, model_size=args.size, quantize=args.quantize)
|
||||
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
||||
LLAMA_SUFFIX = {"1": "", "2": "-2", "code": "-code", "tiny": "-tiny"}[args.gen]
|
||||
MODEL_PATH = (
|
||||
args.model
|
||||
or Path(__file__).parents[1] / f"weights/LLaMA{LLAMA_SUFFIX}/{args.size}"
|
||||
)
|
||||
TOKENIZER_PATH = (
|
||||
MODEL_PATH if MODEL_PATH.is_dir() else MODEL_PATH.parent
|
||||
) / "tokenizer.model"
|
||||
print(f"using LLaMA{LLAMA_SUFFIX}-{args.size} model")
|
||||
llama = LLaMa.build(
|
||||
MODEL_PATH,
|
||||
TOKENIZER_PATH,
|
||||
model_gen=args.gen,
|
||||
model_size=args.size,
|
||||
quantize=args.quantize,
|
||||
)
|
||||
param_count = sum(x.lazydata.st.size() for x in get_parameters(llama.model))
|
||||
|
||||
if chatbot:
|
||||
# encode pre prompt
|
||||
toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(pre_prompt)
|
||||
|
||||
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
||||
with Timing():
|
||||
llama.model(Tensor([toks]), 0, args.temperature).realize() # NOTE: outputs are not used
|
||||
start_pos = len(toks)
|
||||
else:
|
||||
# non chat bot mode
|
||||
toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(args.prompt)
|
||||
start_pos = 0
|
||||
|
||||
# print prompt
|
||||
outputted = llama.tokenizer.decode(toks)
|
||||
sys.stdout.write(outputted)
|
||||
sys.stdout.flush()
|
||||
|
||||
# chatbot loop
|
||||
while 1:
|
||||
# add tokens from user in chatbot mode
|
||||
if chatbot:
|
||||
user_prompt = user_delim + input(user_delim) + "\n"
|
||||
outputted += user_prompt
|
||||
# encode pre prompt
|
||||
toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(pre_prompt)
|
||||
|
||||
new_toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
|
||||
assert toks == new_toks[:len(toks)]
|
||||
toks = new_toks
|
||||
assert outputted == llama.tokenizer.decode(toks)
|
||||
print(f"Preparing KV cache for chatbot with personality {args.personality}...")
|
||||
with Timing():
|
||||
llama.model(
|
||||
Tensor([toks]), 0, args.temperature
|
||||
).realize() # NOTE: outputs are not used
|
||||
start_pos = len(toks)
|
||||
else:
|
||||
# non chat bot mode
|
||||
toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(args.prompt)
|
||||
start_pos = 0
|
||||
|
||||
last_break = len(outputted)
|
||||
for i in range(args.count):
|
||||
GlobalCounters.reset()
|
||||
# print prompt
|
||||
outputted = llama.tokenizer.decode(toks)
|
||||
sys.stdout.write(outputted)
|
||||
sys.stdout.flush()
|
||||
|
||||
if args.timing or args.profile: print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Profiling(enabled=args.profile):
|
||||
with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
|
||||
with Timing("ran model in ", on_exit=(lambda et: (f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU" if DEBUG>=2 else "")+
|
||||
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
|
||||
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
|
||||
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().item()
|
||||
# chatbot loop
|
||||
while 1:
|
||||
# add tokens from user in chatbot mode
|
||||
if chatbot:
|
||||
user_prompt = user_delim + input(user_delim) + "\n"
|
||||
outputted += user_prompt
|
||||
|
||||
# use the kv cache
|
||||
start_pos = len(toks)
|
||||
new_toks = [llama.tokenizer.bos_id()] + llama.tokenizer.encode(outputted)
|
||||
assert toks == new_toks[: len(toks)]
|
||||
toks = new_toks
|
||||
assert outputted == llama.tokenizer.decode(toks)
|
||||
|
||||
# add the new token
|
||||
toks.append(tok)
|
||||
last_break = len(outputted)
|
||||
for i in range(args.count):
|
||||
GlobalCounters.reset()
|
||||
|
||||
# TODO: this is a hack to deal with spaces. i think the decode is fast though, so who cares?
|
||||
cur = llama.tokenizer.decode(toks)
|
||||
sys.stdout.write(cur[len(outputted):])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
if args.timing or args.profile:
|
||||
print("")
|
||||
st = GlobalCounters.time_sum_s
|
||||
with Profiling(enabled=args.profile):
|
||||
with Timing(
|
||||
"total ",
|
||||
enabled=args.timing,
|
||||
on_exit=lambda x: f", {1e9/x:.2f} tok/sec",
|
||||
):
|
||||
with Timing(
|
||||
"ran model in ",
|
||||
on_exit=(
|
||||
lambda et: (
|
||||
f", {(GlobalCounters.time_sum_s-st)*1e3:.2f} ms on GPU"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
+ f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"
|
||||
+ (
|
||||
f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s"
|
||||
if DEBUG >= 2
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if DEBUG
|
||||
else None,
|
||||
enabled=args.timing,
|
||||
):
|
||||
probs = llama.model(
|
||||
Tensor([toks[start_pos:]]), start_pos, args.temperature
|
||||
).realize()
|
||||
# TODO: fix JIT rand so we can put this in the JIT
|
||||
tok = probs.multinomial().item()
|
||||
|
||||
# stop after you have your answer
|
||||
if chatbot and outputted.endswith(end_delim): break
|
||||
if not chatbot: break
|
||||
# use the kv cache
|
||||
start_pos = len(toks)
|
||||
|
||||
# add the new token
|
||||
toks.append(tok)
|
||||
|
||||
# TODO: this is a hack to deal with spaces. i think the decode is fast though, so who cares?
|
||||
cur = llama.tokenizer.decode(toks)
|
||||
sys.stdout.write(cur[len(outputted) :])
|
||||
sys.stdout.flush()
|
||||
outputted = cur
|
||||
|
||||
# stop after you have your answer
|
||||
if chatbot and outputted.endswith(end_delim):
|
||||
break
|
||||
if not chatbot:
|
||||
break
|
||||
|
|
|
@ -14,286 +14,380 @@ import cv2
|
|||
|
||||
|
||||
class Resize:
|
||||
def __init__(self, min_size, max_size):
|
||||
if not isinstance(min_size, (list, tuple)):
|
||||
min_size = (min_size,)
|
||||
self.min_size = min_size
|
||||
self.max_size = max_size
|
||||
def __init__(self, min_size, max_size):
|
||||
if not isinstance(min_size, (list, tuple)):
|
||||
min_size = (min_size,)
|
||||
self.min_size = min_size
|
||||
self.max_size = max_size
|
||||
|
||||
# modified from torchvision to add support for max size
|
||||
def get_size(self, image_size):
|
||||
w, h = image_size
|
||||
size = random.choice(self.min_size)
|
||||
max_size = self.max_size
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((w, h)))
|
||||
max_original_size = float(max((w, h)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
size = int(round(max_size * min_original_size / max_original_size))
|
||||
# modified from torchvision to add support for max size
|
||||
def get_size(self, image_size):
|
||||
w, h = image_size
|
||||
size = random.choice(self.min_size)
|
||||
max_size = self.max_size
|
||||
if max_size is not None:
|
||||
min_original_size = float(min((w, h)))
|
||||
max_original_size = float(max((w, h)))
|
||||
if max_original_size / min_original_size * size > max_size:
|
||||
size = int(round(max_size * min_original_size / max_original_size))
|
||||
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return (h, w)
|
||||
if (w <= h and w == size) or (h <= w and h == size):
|
||||
return (h, w)
|
||||
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
if w < h:
|
||||
ow = size
|
||||
oh = int(size * h / w)
|
||||
else:
|
||||
oh = size
|
||||
ow = int(size * w / h)
|
||||
|
||||
return (oh, ow)
|
||||
return (oh, ow)
|
||||
|
||||
def __call__(self, image):
|
||||
size = self.get_size(image.size)
|
||||
image = Ft.resize(image, size)
|
||||
return image
|
||||
def __call__(self, image):
|
||||
size = self.get_size(image.size)
|
||||
image = Ft.resize(image, size)
|
||||
return image
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(self, mean, std, to_bgr255=True):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.to_bgr255 = to_bgr255
|
||||
def __init__(self, mean, std, to_bgr255=True):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.to_bgr255 = to_bgr255
|
||||
|
||||
def __call__(self, image):
|
||||
if self.to_bgr255:
|
||||
image = image[[2, 1, 0]] * 255
|
||||
else:
|
||||
image = image[[0, 1, 2]] * 255
|
||||
image = Ft.normalize(image, mean=self.mean, std=self.std)
|
||||
return image
|
||||
|
||||
def __call__(self, image):
|
||||
if self.to_bgr255:
|
||||
image = image[[2, 1, 0]] * 255
|
||||
else:
|
||||
image = image[[0, 1, 2]] * 255
|
||||
image = Ft.normalize(image, mean=self.mean, std=self.std)
|
||||
return image
|
||||
|
||||
transforms = lambda size_scale: T.Compose(
|
||||
[
|
||||
Resize(int(800*size_scale), int(1333*size_scale)),
|
||||
T.ToTensor(),
|
||||
Normalize(
|
||||
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
|
||||
),
|
||||
]
|
||||
[
|
||||
Resize(int(800 * size_scale), int(1333 * size_scale)),
|
||||
T.ToTensor(),
|
||||
Normalize(
|
||||
mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_bgr255=True
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def expand_boxes(boxes, scale):
|
||||
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
|
||||
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
|
||||
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
|
||||
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
|
||||
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
|
||||
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
|
||||
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
|
||||
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
|
||||
|
||||
w_half *= scale
|
||||
h_half *= scale
|
||||
w_half *= scale
|
||||
h_half *= scale
|
||||
|
||||
boxes_exp = torch.zeros_like(boxes)
|
||||
boxes_exp[:, 0] = x_c - w_half
|
||||
boxes_exp[:, 2] = x_c + w_half
|
||||
boxes_exp[:, 1] = y_c - h_half
|
||||
boxes_exp[:, 3] = y_c + h_half
|
||||
return boxes_exp
|
||||
boxes_exp = torch.zeros_like(boxes)
|
||||
boxes_exp[:, 0] = x_c - w_half
|
||||
boxes_exp[:, 2] = x_c + w_half
|
||||
boxes_exp[:, 1] = y_c - h_half
|
||||
boxes_exp[:, 3] = y_c + h_half
|
||||
return boxes_exp
|
||||
|
||||
|
||||
def expand_masks(mask, padding):
|
||||
N = mask.shape[0]
|
||||
M = mask.shape[-1]
|
||||
pad2 = 2 * padding
|
||||
scale = float(M + pad2) / M
|
||||
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
|
||||
padded_mask[:, :, padding:-padding, padding:-padding] = mask
|
||||
return padded_mask, scale
|
||||
N = mask.shape[0]
|
||||
M = mask.shape[-1]
|
||||
pad2 = 2 * padding
|
||||
scale = float(M + pad2) / M
|
||||
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
|
||||
padded_mask[:, :, padding:-padding, padding:-padding] = mask
|
||||
return padded_mask, scale
|
||||
|
||||
|
||||
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
|
||||
# TODO: remove torch
|
||||
mask = torch.tensor(mask.numpy())
|
||||
box = torch.tensor(box.numpy())
|
||||
padded_mask, scale = expand_masks(mask[None], padding=padding)
|
||||
mask = padded_mask[0, 0]
|
||||
box = expand_boxes(box[None], scale)[0]
|
||||
box = box.to(dtype=torch.int32)
|
||||
# TODO: remove torch
|
||||
mask = torch.tensor(mask.numpy())
|
||||
box = torch.tensor(box.numpy())
|
||||
padded_mask, scale = expand_masks(mask[None], padding=padding)
|
||||
mask = padded_mask[0, 0]
|
||||
box = expand_boxes(box[None], scale)[0]
|
||||
box = box.to(dtype=torch.int32)
|
||||
|
||||
TO_REMOVE = 1
|
||||
w = int(box[2] - box[0] + TO_REMOVE)
|
||||
h = int(box[3] - box[1] + TO_REMOVE)
|
||||
w = max(w, 1)
|
||||
h = max(h, 1)
|
||||
TO_REMOVE = 1
|
||||
w = int(box[2] - box[0] + TO_REMOVE)
|
||||
h = int(box[3] - box[1] + TO_REMOVE)
|
||||
w = max(w, 1)
|
||||
h = max(h, 1)
|
||||
|
||||
mask = mask.expand((1, 1, -1, -1))
|
||||
mask = mask.expand((1, 1, -1, -1))
|
||||
|
||||
mask = mask.to(torch.float32)
|
||||
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
|
||||
mask = mask[0][0]
|
||||
mask = mask.to(torch.float32)
|
||||
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
|
||||
mask = mask[0][0]
|
||||
|
||||
if thresh >= 0:
|
||||
mask = mask > thresh
|
||||
else:
|
||||
mask = (mask * 255).to(torch.uint8)
|
||||
if thresh >= 0:
|
||||
mask = mask > thresh
|
||||
else:
|
||||
mask = (mask * 255).to(torch.uint8)
|
||||
|
||||
im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
|
||||
x_0 = max(box[0], 0)
|
||||
x_1 = min(box[2] + 1, im_w)
|
||||
y_0 = max(box[1], 0)
|
||||
y_1 = min(box[3] + 1, im_h)
|
||||
im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
|
||||
x_0 = max(box[0], 0)
|
||||
x_1 = min(box[2] + 1, im_w)
|
||||
y_0 = max(box[1], 0)
|
||||
y_1 = min(box[3] + 1, im_h)
|
||||
|
||||
im_mask[y_0:y_1, x_0:x_1] = mask[
|
||||
(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
|
||||
]
|
||||
return im_mask
|
||||
im_mask[y_0:y_1, x_0:x_1] = mask[
|
||||
(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])
|
||||
]
|
||||
return im_mask
|
||||
|
||||
|
||||
class Masker:
|
||||
def __init__(self, threshold=0.5, padding=1):
|
||||
self.threshold = threshold
|
||||
self.padding = padding
|
||||
def __init__(self, threshold=0.5, padding=1):
|
||||
self.threshold = threshold
|
||||
self.padding = padding
|
||||
|
||||
def forward_single_image(self, masks, boxes):
|
||||
boxes = boxes.convert("xyxy")
|
||||
im_w, im_h = boxes.size
|
||||
res = [
|
||||
paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
|
||||
for mask, box in zip(masks, boxes.bbox)
|
||||
]
|
||||
if len(res) > 0:
|
||||
res = torch.stack(res, dim=0)[:, None]
|
||||
else:
|
||||
res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
|
||||
return Tensor(res.numpy())
|
||||
def forward_single_image(self, masks, boxes):
|
||||
boxes = boxes.convert("xyxy")
|
||||
im_w, im_h = boxes.size
|
||||
res = [
|
||||
paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
|
||||
for mask, box in zip(masks, boxes.bbox)
|
||||
]
|
||||
if len(res) > 0:
|
||||
res = torch.stack(res, dim=0)[:, None]
|
||||
else:
|
||||
res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
|
||||
return Tensor(res.numpy())
|
||||
|
||||
def __call__(self, masks, boxes):
|
||||
if isinstance(boxes, BoxList):
|
||||
boxes = [boxes]
|
||||
def __call__(self, masks, boxes):
|
||||
if isinstance(boxes, BoxList):
|
||||
boxes = [boxes]
|
||||
|
||||
results = []
|
||||
for mask, box in zip(masks, boxes):
|
||||
result = self.forward_single_image(mask, box)
|
||||
results.append(result)
|
||||
return results
|
||||
results = []
|
||||
for mask, box in zip(masks, boxes):
|
||||
result = self.forward_single_image(mask, box)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
masker = Masker(threshold=0.5, padding=1)
|
||||
|
||||
|
||||
def select_top_predictions(predictions, confidence_threshold=0.9):
|
||||
scores = predictions.get_field("scores").numpy()
|
||||
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
|
||||
return predictions[keep]
|
||||
scores = predictions.get_field("scores").numpy()
|
||||
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
|
||||
return predictions[keep]
|
||||
|
||||
|
||||
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
|
||||
image = transforms(size_scale)(original_image).numpy()
|
||||
image = Tensor(image, requires_grad=False)
|
||||
predictions = model(image)
|
||||
prediction = predictions[0]
|
||||
prediction = select_top_predictions(prediction, confidence_threshold)
|
||||
width, height = original_image.size
|
||||
prediction = prediction.resize((width, height))
|
||||
image = transforms(size_scale)(original_image).numpy()
|
||||
image = Tensor(image, requires_grad=False)
|
||||
predictions = model(image)
|
||||
prediction = predictions[0]
|
||||
prediction = select_top_predictions(prediction, confidence_threshold)
|
||||
width, height = original_image.size
|
||||
prediction = prediction.resize((width, height))
|
||||
|
||||
if prediction.has_field("mask"):
|
||||
masks = prediction.get_field("mask")
|
||||
masks = masker([masks], [prediction])[0]
|
||||
prediction.add_field("mask", masks)
|
||||
return prediction
|
||||
|
||||
if prediction.has_field("mask"):
|
||||
masks = prediction.get_field("mask")
|
||||
masks = masker([masks], [prediction])[0]
|
||||
prediction.add_field("mask", masks)
|
||||
return prediction
|
||||
|
||||
def compute_prediction_batched(batch, model, size_scale=1.0):
|
||||
imgs = []
|
||||
for img in batch:
|
||||
imgs.append(transforms(size_scale)(img).numpy())
|
||||
image = [Tensor(image, requires_grad=False) for image in imgs]
|
||||
predictions = model(image)
|
||||
del image
|
||||
return predictions
|
||||
imgs = []
|
||||
for img in batch:
|
||||
imgs.append(transforms(size_scale)(img).numpy())
|
||||
image = [Tensor(image, requires_grad=False) for image in imgs]
|
||||
predictions = model(image)
|
||||
del image
|
||||
return predictions
|
||||
|
||||
|
||||
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
|
||||
|
||||
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
|
||||
|
||||
def findContours(*args, **kwargs):
|
||||
if cv2.__version__.startswith('4'):
|
||||
contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
elif cv2.__version__.startswith('3'):
|
||||
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
return contours, hierarchy
|
||||
if cv2.__version__.startswith("4"):
|
||||
contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
elif cv2.__version__.startswith("3"):
|
||||
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
|
||||
return contours, hierarchy
|
||||
|
||||
|
||||
def compute_colors_for_labels(labels):
|
||||
l = labels[:, None]
|
||||
colors = l * palette
|
||||
colors = (colors % 255).astype("uint8")
|
||||
return colors
|
||||
l = labels[:, None]
|
||||
colors = l * palette
|
||||
colors = (colors % 255).astype("uint8")
|
||||
return colors
|
||||
|
||||
|
||||
def overlay_mask(image, predictions):
|
||||
image = np.asarray(image)
|
||||
masks = predictions.get_field("mask").numpy()
|
||||
labels = predictions.get_field("labels").numpy()
|
||||
image = np.asarray(image)
|
||||
masks = predictions.get_field("mask").numpy()
|
||||
labels = predictions.get_field("labels").numpy()
|
||||
|
||||
colors = compute_colors_for_labels(labels).tolist()
|
||||
colors = compute_colors_for_labels(labels).tolist()
|
||||
|
||||
for mask, color in zip(masks, colors):
|
||||
thresh = mask[0, :, :, None]
|
||||
contours, hierarchy = findContours(
|
||||
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
image = cv2.drawContours(image, contours, -1, color, 3)
|
||||
for mask, color in zip(masks, colors):
|
||||
thresh = mask[0, :, :, None]
|
||||
contours, hierarchy = findContours(
|
||||
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
image = cv2.drawContours(image, contours, -1, color, 3)
|
||||
|
||||
composite = image
|
||||
composite = image
|
||||
|
||||
return composite
|
||||
|
||||
return composite
|
||||
|
||||
CATEGORIES = [
|
||||
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
|
||||
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
|
||||
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
|
||||
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
|
||||
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
|
||||
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
|
||||
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
|
||||
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
|
||||
"__background",
|
||||
"person",
|
||||
"bicycle",
|
||||
"car",
|
||||
"motorcycle",
|
||||
"airplane",
|
||||
"bus",
|
||||
"train",
|
||||
"truck",
|
||||
"boat",
|
||||
"traffic light",
|
||||
"fire hydrant",
|
||||
"stop sign",
|
||||
"parking meter",
|
||||
"bench",
|
||||
"bird",
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"sheep",
|
||||
"cow",
|
||||
"elephant",
|
||||
"bear",
|
||||
"zebra",
|
||||
"giraffe",
|
||||
"backpack",
|
||||
"umbrella",
|
||||
"handbag",
|
||||
"tie",
|
||||
"suitcase",
|
||||
"frisbee",
|
||||
"skis",
|
||||
"snowboard",
|
||||
"sports ball",
|
||||
"kite",
|
||||
"baseball bat",
|
||||
"baseball glove",
|
||||
"skateboard",
|
||||
"surfboard",
|
||||
"tennis racket",
|
||||
"bottle",
|
||||
"wine glass",
|
||||
"cup",
|
||||
"fork",
|
||||
"knife",
|
||||
"spoon",
|
||||
"bowl",
|
||||
"banana",
|
||||
"apple",
|
||||
"sandwich",
|
||||
"orange",
|
||||
"broccoli",
|
||||
"carrot",
|
||||
"hot dog",
|
||||
"pizza",
|
||||
"donut",
|
||||
"cake",
|
||||
"chair",
|
||||
"couch",
|
||||
"potted plant",
|
||||
"bed",
|
||||
"dining table",
|
||||
"toilet",
|
||||
"tv",
|
||||
"laptop",
|
||||
"mouse",
|
||||
"remote",
|
||||
"keyboard",
|
||||
"cell phone",
|
||||
"microwave",
|
||||
"oven",
|
||||
"toaster",
|
||||
"sink",
|
||||
"refrigerator",
|
||||
"book",
|
||||
"clock",
|
||||
"vase",
|
||||
"scissors",
|
||||
"teddy bear",
|
||||
"hair drier",
|
||||
"toothbrush",
|
||||
]
|
||||
|
||||
|
||||
def overlay_boxes(image, predictions):
|
||||
labels = predictions.get_field("labels").numpy()
|
||||
boxes = predictions.bbox
|
||||
image = np.asarray(image)
|
||||
colors = compute_colors_for_labels(labels).tolist()
|
||||
labels = predictions.get_field("labels").numpy()
|
||||
boxes = predictions.bbox
|
||||
image = np.asarray(image)
|
||||
colors = compute_colors_for_labels(labels).tolist()
|
||||
|
||||
for box, color in zip(boxes, colors):
|
||||
box = torch.tensor(box.numpy())
|
||||
box = box.to(torch.int64)
|
||||
top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
|
||||
image = cv2.rectangle(
|
||||
image, tuple(top_left), tuple(bottom_right), tuple(color), 1
|
||||
)
|
||||
for box, color in zip(boxes, colors):
|
||||
box = torch.tensor(box.numpy())
|
||||
box = box.to(torch.int64)
|
||||
top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
|
||||
image = cv2.rectangle(
|
||||
image, tuple(top_left), tuple(bottom_right), tuple(color), 1
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
return image
|
||||
|
||||
def overlay_class_names(image, predictions):
|
||||
scores = predictions.get_field("scores").numpy().tolist()
|
||||
labels = predictions.get_field("labels").numpy().tolist()
|
||||
labels = [CATEGORIES[int(i)] for i in labels]
|
||||
boxes = predictions.bbox.numpy()
|
||||
image = np.asarray(image)
|
||||
template = "{}: {:.2f}"
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
x, y = box[:2]
|
||||
s = template.format(label, score)
|
||||
x, y = int(x), int(y)
|
||||
cv2.putText(
|
||||
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
|
||||
scores = predictions.get_field("scores").numpy().tolist()
|
||||
labels = predictions.get_field("labels").numpy().tolist()
|
||||
labels = [CATEGORIES[int(i)] for i in labels]
|
||||
boxes = predictions.bbox.numpy()
|
||||
image = np.asarray(image)
|
||||
template = "{}: {:.2f}"
|
||||
for box, score, label in zip(boxes, scores, labels):
|
||||
x, y = box[:2]
|
||||
s = template.format(label, score)
|
||||
x, y = int(x), int(y)
|
||||
cv2.putText(image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run MaskRCNN",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--image", type=str, help="Path of the image to run")
|
||||
parser.add_argument(
|
||||
"--threshold", type=float, default=0.7, help="Detector threshold"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size_scale", type=float, default=1.0, help="Image resize multiplier"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out", type=str, default="/tmp/rendered.png", help="Output filename"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return image
|
||||
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||
model_tiny = MaskRCNN(resnet)
|
||||
model_tiny.load_from_pretrained()
|
||||
img = Image.open(args.image)
|
||||
top_result_tiny = compute_prediction(
|
||||
img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale
|
||||
)
|
||||
bbox_image = overlay_boxes(img, top_result_tiny)
|
||||
mask_image = overlay_mask(bbox_image, top_result_tiny)
|
||||
final_image = overlay_class_names(mask_image, top_result_tiny)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--image', type=str, help="Path of the image to run")
|
||||
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
|
||||
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
|
||||
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
|
||||
args = parser.parse_args()
|
||||
|
||||
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
|
||||
model_tiny = MaskRCNN(resnet)
|
||||
model_tiny.load_from_pretrained()
|
||||
img = Image.open(args.image)
|
||||
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
|
||||
bbox_image = overlay_boxes(img, top_result_tiny)
|
||||
mask_image = overlay_mask(bbox_image, top_result_tiny)
|
||||
final_image = overlay_class_names(mask_image, top_result_tiny)
|
||||
|
||||
im = Image.fromarray(final_image)
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
im.show()
|
||||
im = Image.fromarray(final_image)
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
im.show()
|
||||
|
|
|
@ -3,162 +3,194 @@ import unicodedata
|
|||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
|
||||
def gaussian_kernel(n, std):
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
|
||||
gaussian_3d = gaussian_3d.reshape(n, n, n)
|
||||
gaussian_3d = np.cbrt(gaussian_3d)
|
||||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
|
||||
gaussian_3d = gaussian_3d.reshape(n, n, n)
|
||||
gaussian_3d = np.cbrt(gaussian_3d)
|
||||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
|
||||
|
||||
def prepare_arrays(image, roi_shape=(128, 128, 128)):
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
image_shape = list(image.shape[2:])
|
||||
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
|
||||
norm_map = np.zeros_like(result)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
|
||||
return result, norm_map, norm_patch
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
image_shape = list(image.shape[2:])
|
||||
result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
|
||||
norm_map = np.zeros_like(result)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(
|
||||
norm_map.dtype
|
||||
)
|
||||
return result, norm_map, norm_patch
|
||||
|
||||
|
||||
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
assert 0 < overlap_factor < 1
|
||||
image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
|
||||
strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
|
||||
size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
|
||||
for i in range(0, strides[0] * size[0], strides[0]):
|
||||
for j in range(0, strides[1] * size[1], strides[1]):
|
||||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
yield i, j, k
|
||||
assert len(roi_shape) == 3 and any(roi_shape)
|
||||
assert 0 < overlap_factor < 1
|
||||
image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
|
||||
strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
|
||||
size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
|
||||
for i in range(0, strides[0] * size[0], strides[0]):
|
||||
for j in range(0, strides[1] * size[1], strides[1]):
|
||||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
yield i, j, k
|
||||
|
||||
|
||||
def _get_best_indices(logits, n_best_size):
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
|
||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||
return list(map(lambda x: x[0], index_and_score))[:n_best_size]
|
||||
|
||||
|
||||
def _is_punctuation(char):
|
||||
if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
|
||||
return True
|
||||
return unicodedata.category(char).startswith("P")
|
||||
if (
|
||||
(cp := ord(char)) in range(33, 48)
|
||||
or cp in range(58, 65)
|
||||
or cp in range(91, 97)
|
||||
or cp in range(123, 127)
|
||||
):
|
||||
return True
|
||||
return unicodedata.category(char).startswith("P")
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
return unicodedata.category(char) == "Zs"
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
return unicodedata.category(char) == "Zs"
|
||||
|
||||
|
||||
def _is_control(char):
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
return unicodedata.category(char).startswith("C")
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
return unicodedata.category(char).startswith("C")
|
||||
|
||||
|
||||
def _run_split_on_punc(text):
|
||||
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
return [text]
|
||||
start_new_word = True
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
if _is_punctuation(char := text[i]):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
return ["".join(x) for x in output]
|
||||
if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
return [text]
|
||||
start_new_word = True
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
if _is_punctuation(char := text[i]):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
|
||||
def _run_strip_accents(text):
|
||||
output = []
|
||||
for char in unicodedata.normalize("NFD", text):
|
||||
if unicodedata.category(char) != "Mn":
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
output = []
|
||||
for char in unicodedata.normalize("NFD", text):
|
||||
if unicodedata.category(char) != "Mn":
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def _clean_text(text):
|
||||
output = []
|
||||
for char in text:
|
||||
if not ((cp := ord(char)) == 0 or cp == 0xfffd or _is_control(char)):
|
||||
output.append(" " if _is_whitespace(char) else char)
|
||||
return "".join(output)
|
||||
output = []
|
||||
for char in text:
|
||||
if not ((cp := ord(char)) == 0 or cp == 0xFFFD or _is_control(char)):
|
||||
output.append(" " if _is_whitespace(char) else char)
|
||||
return "".join(output)
|
||||
|
||||
|
||||
def _get_final_text(pred_text, orig_text):
|
||||
def _strip_spaces(text):
|
||||
ns_text = ""
|
||||
ns_to_s_map = OrderedDict()
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
continue
|
||||
ns_to_s_map[len(ns_text)] = i
|
||||
ns_text += c
|
||||
return ns_text, ns_to_s_map
|
||||
def _strip_spaces(text):
|
||||
ns_text = ""
|
||||
ns_to_s_map = OrderedDict()
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
continue
|
||||
ns_to_s_map[len(ns_text)] = i
|
||||
ns_text += c
|
||||
return ns_text, ns_to_s_map
|
||||
|
||||
orig_tokens = _clean_text(orig_text).strip().split()
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
token = token.lower()
|
||||
token = _run_strip_accents(token)
|
||||
split_tokens.extend(_run_split_on_punc(token))
|
||||
orig_tokens = _clean_text(orig_text).strip().split()
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if token not in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
|
||||
token = token.lower()
|
||||
token = _run_strip_accents(token)
|
||||
split_tokens.extend(_run_split_on_punc(token))
|
||||
|
||||
tok_text = " ".join(" ".join(split_tokens).strip().split())
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
tok_text = " ".join(" ".join(split_tokens).strip().split())
|
||||
start_position = tok_text.find(pred_text)
|
||||
if start_position == -1:
|
||||
return orig_text
|
||||
end_position = start_position + len(pred_text) - 1
|
||||
|
||||
orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
|
||||
tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
return orig_text
|
||||
tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()}
|
||||
orig_ns_text, orig_ns_to_s_map = _strip_spaces(orig_text)
|
||||
tok_ns_text, tok_ns_to_s_map = _strip_spaces(tok_text)
|
||||
if len(orig_ns_text) != len(tok_ns_text):
|
||||
return orig_text
|
||||
tok_s_to_ns_map = {v: k for k, v in tok_ns_to_s_map.items()}
|
||||
|
||||
orig_start_position = None
|
||||
if start_position in tok_s_to_ns_map:
|
||||
if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map:
|
||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||
if orig_start_position is None:
|
||||
return orig_text
|
||||
orig_start_position = None
|
||||
if start_position in tok_s_to_ns_map:
|
||||
if (ns_start_position := tok_s_to_ns_map[start_position]) in orig_ns_to_s_map:
|
||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||
if orig_start_position is None:
|
||||
return orig_text
|
||||
|
||||
orig_end_position = None
|
||||
if end_position in tok_s_to_ns_map:
|
||||
if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map:
|
||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||
if orig_end_position is None:
|
||||
return orig_text
|
||||
orig_end_position = None
|
||||
if end_position in tok_s_to_ns_map:
|
||||
if (ns_end_position := tok_s_to_ns_map[end_position]) in orig_ns_to_s_map:
|
||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||
if orig_end_position is None:
|
||||
return orig_text
|
||||
|
||||
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||
return output_text
|
||||
|
||||
def get_bert_qa_prediction(features, example, start_end_logits):
|
||||
prelim_predictions = []
|
||||
for i, feature in enumerate(features):
|
||||
for start_index in _get_best_indices(start_end_logits[i][0], 20):
|
||||
for end_index in _get_best_indices(start_end_logits[i][1], 20):
|
||||
if start_index >= len(feature["tokens"]) or end_index >= len(feature["tokens"]):
|
||||
continue
|
||||
if start_index not in feature["token_to_orig_map"] or end_index not in feature["token_to_orig_map"]:
|
||||
continue
|
||||
if not feature["token_is_max_context"].get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index or end_index - start_index + 1 > 30:
|
||||
continue
|
||||
prelim_predictions = []
|
||||
for i, feature in enumerate(features):
|
||||
for start_index in _get_best_indices(start_end_logits[i][0], 20):
|
||||
for end_index in _get_best_indices(start_end_logits[i][1], 20):
|
||||
if start_index >= len(feature["tokens"]) or end_index >= len(
|
||||
feature["tokens"]
|
||||
):
|
||||
continue
|
||||
if (
|
||||
start_index not in feature["token_to_orig_map"]
|
||||
or end_index not in feature["token_to_orig_map"]
|
||||
):
|
||||
continue
|
||||
if not feature["token_is_max_context"].get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index or end_index - start_index + 1 > 30:
|
||||
continue
|
||||
|
||||
prelim_predictions.append({
|
||||
"feature_index": i,
|
||||
"start_index": start_index,
|
||||
"end_index": end_index,
|
||||
"start_logit": start_end_logits[i][0, start_index],
|
||||
"end_logit": start_end_logits[i][1, end_index]
|
||||
})
|
||||
predictions = sorted(prelim_predictions, key=lambda x: (x["start_logit"] + x["end_logit"]), reverse=True)
|
||||
prelim_predictions.append(
|
||||
{
|
||||
"feature_index": i,
|
||||
"start_index": start_index,
|
||||
"end_index": end_index,
|
||||
"start_logit": start_end_logits[i][0, start_index],
|
||||
"end_logit": start_end_logits[i][1, end_index],
|
||||
}
|
||||
)
|
||||
predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x["start_logit"] + x["end_logit"]),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
if len(predictions) > 0:
|
||||
feature = features[predictions[0]["feature_index"]]
|
||||
tok_tokens = feature["tokens"][predictions[0]["start_index"]:(predictions[0]["end_index"] + 1)]
|
||||
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
|
||||
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
|
||||
orig_tokens = example["context"][orig_doc_start:(orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "")
|
||||
tok_text = " ".join(tok_text.strip().split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
return _get_final_text(tok_text, orig_text)
|
||||
return "empty"
|
||||
if len(predictions) > 0:
|
||||
feature = features[predictions[0]["feature_index"]]
|
||||
tok_tokens = feature["tokens"][
|
||||
predictions[0]["start_index"] : (predictions[0]["end_index"] + 1)
|
||||
]
|
||||
orig_doc_start = feature["token_to_orig_map"][predictions[0]["start_index"]]
|
||||
orig_doc_end = feature["token_to_orig_map"][predictions[0]["end_index"]]
|
||||
orig_tokens = example["context"][orig_doc_start : (orig_doc_end + 1)]
|
||||
tok_text = " ".join(tok_tokens).replace(" ##", "").replace("##", "")
|
||||
tok_text = " ".join(tok_text.strip().split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
return _get_final_text(tok_text, orig_text)
|
||||
return "empty"
|
||||
|
|
|
@ -3,59 +3,67 @@ import string
|
|||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
|
||||
def levenshtein(a, b):
|
||||
n, m = len(a), len(b)
|
||||
if n > m:
|
||||
a, b, n, m = b, a, m, n
|
||||
n, m = len(a), len(b)
|
||||
if n > m:
|
||||
a, b, n, m = b, a, m, n
|
||||
|
||||
current = list(range(n + 1))
|
||||
for i in range(1, m + 1):
|
||||
previous, current = current, [i] + [0] * n
|
||||
for j in range(1, n + 1):
|
||||
add, delete = previous[j] + 1, current[j - 1] + 1
|
||||
change = previous[j - 1]
|
||||
if a[j - 1] != b[i - 1]:
|
||||
change = change + 1
|
||||
current[j] = min(add, delete, change)
|
||||
current = list(range(n + 1))
|
||||
for i in range(1, m + 1):
|
||||
previous, current = current, [i] + [0] * n
|
||||
for j in range(1, n + 1):
|
||||
add, delete = previous[j] + 1, current[j - 1] + 1
|
||||
change = previous[j - 1]
|
||||
if a[j - 1] != b[i - 1]:
|
||||
change = change + 1
|
||||
current[j] = min(add, delete, change)
|
||||
|
||||
return current[n]
|
||||
|
||||
return current[n]
|
||||
|
||||
def word_error_rate(x, y):
|
||||
scores = words = 0
|
||||
for h, r in zip(x, y):
|
||||
h_list = h.split()
|
||||
r_list = r.split()
|
||||
words += len(r_list)
|
||||
scores += levenshtein(h_list, r_list)
|
||||
return float(scores) / words, float(scores), words
|
||||
scores = words = 0
|
||||
for h, r in zip(x, y):
|
||||
h_list = h.split()
|
||||
r_list = r.split()
|
||||
words += len(r_list)
|
||||
scores += levenshtein(h_list, r_list)
|
||||
return float(scores) / words, float(scores), words
|
||||
|
||||
|
||||
def one_hot(arr, num_classes=3):
|
||||
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
|
||||
arr = res.reshape(list(arr.shape) + [num_classes])
|
||||
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
|
||||
return arr
|
||||
res = np.eye(num_classes)[np.array(arr).reshape(-1)]
|
||||
arr = res.reshape(list(arr.shape) + [num_classes])
|
||||
arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
|
||||
return arr
|
||||
|
||||
|
||||
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
|
||||
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
||||
prediction = prediction.argmax(axis=channel_axis)
|
||||
prediction, target= one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
|
||||
intersection = np.sum(prediction * target, axis=reduce_axis)
|
||||
target_sum = np.sum(target, axis=reduce_axis)
|
||||
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
||||
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
|
||||
return result[0]
|
||||
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
||||
prediction = prediction.argmax(axis=channel_axis)
|
||||
prediction, target = one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
|
||||
intersection = np.sum(prediction * target, axis=reduce_axis)
|
||||
target_sum = np.sum(target, axis=reduce_axis)
|
||||
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
||||
result = (2.0 * intersection + smooth_nr) / (
|
||||
target_sum + prediction_sum + smooth_dr
|
||||
)
|
||||
return result[0]
|
||||
|
||||
|
||||
def normalize_string(s):
|
||||
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
||||
s = re.sub(r'\b(a|an|the)\b', ' ', s)
|
||||
return " ".join(s.split())
|
||||
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
||||
s = re.sub(r"\b(a|an|the)\b", " ", s)
|
||||
return " ".join(s.split())
|
||||
|
||||
|
||||
def f1_score(x, y):
|
||||
xt = normalize_string(x).split()
|
||||
yt = normalize_string(y).split()
|
||||
ct = Counter(xt) & Counter(yt)
|
||||
if (ns := sum(ct.values())) == 0:
|
||||
return 0.0
|
||||
p = ns / len(xt)
|
||||
r = ns / len(yt)
|
||||
return 2 * p * r / (p + r)
|
||||
xt = normalize_string(x).split()
|
||||
yt = normalize_string(y).split()
|
||||
ct = Counter(xt) & Counter(yt)
|
||||
if (ns := sum(ct.values())) == 0:
|
||||
return 0.0
|
||||
p = ns / len(xt)
|
||||
r = ns / len(yt)
|
||||
return 2 * p * r / (p + r)
|
||||
|
|
|
@ -6,237 +6,324 @@ from tinygrad.jit import TinyJit
|
|||
from tinygrad.helpers import getenv, dtypes, GlobalCounters
|
||||
from examples.mlperf import helpers
|
||||
|
||||
|
||||
def eval_resnet():
|
||||
# Resnet50-v1.5
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.models.resnet import ResNet50
|
||||
mdl = ResNet50()
|
||||
mdl.load_from_pretrained()
|
||||
# Resnet50-v1.5
|
||||
from tinygrad.jit import TinyJit
|
||||
from extra.models.resnet import ResNet50
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
def input_fixup(x):
|
||||
x = x.permute([0,3,1,2]).cast(dtypes.float32) / 255.0
|
||||
x -= input_mean
|
||||
x /= input_std
|
||||
return x
|
||||
mdl = ResNet50()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
mdlrun = lambda x: mdl(input_fixup(x)).realize()
|
||||
mdljit = TinyJit(mdlrun)
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from extra.datasets.imagenet import iterate
|
||||
from extra.helpers import cross_process
|
||||
def input_fixup(x):
|
||||
x = x.permute([0, 3, 1, 2]).cast(dtypes.float32) / 255.0
|
||||
x -= input_mean
|
||||
x /= input_std
|
||||
return x
|
||||
|
||||
BS = 64
|
||||
n,d = 0,0
|
||||
st = time.perf_counter()
|
||||
iterator = cross_process(lambda: iterate(BS))
|
||||
x,ny = next(iterator)
|
||||
dat = Tensor(x)
|
||||
while dat is not None:
|
||||
y = ny
|
||||
GlobalCounters.reset()
|
||||
mt = time.perf_counter()
|
||||
outs = mdlrun(dat) if dat.shape[0] != BS else mdljit(dat)
|
||||
try:
|
||||
x,ny = next(iterator)
|
||||
dat = Tensor(x)
|
||||
except StopIteration:
|
||||
dat = None
|
||||
t = outs.argmax(axis=1).numpy()
|
||||
et = time.perf_counter()
|
||||
n += (t==y).sum()
|
||||
d += len(t)
|
||||
print(f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS")
|
||||
mdlrun = lambda x: mdl(input_fixup(x)).realize()
|
||||
mdljit = TinyJit(mdlrun)
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from extra.datasets.imagenet import iterate
|
||||
from extra.helpers import cross_process
|
||||
|
||||
BS = 64
|
||||
n, d = 0, 0
|
||||
st = time.perf_counter()
|
||||
iterator = cross_process(lambda: iterate(BS))
|
||||
x, ny = next(iterator)
|
||||
dat = Tensor(x)
|
||||
while dat is not None:
|
||||
y = ny
|
||||
GlobalCounters.reset()
|
||||
mt = time.perf_counter()
|
||||
outs = mdlrun(dat) if dat.shape[0] != BS else mdljit(dat)
|
||||
try:
|
||||
x, ny = next(iterator)
|
||||
dat = Tensor(x)
|
||||
except StopIteration:
|
||||
dat = None
|
||||
t = outs.argmax(axis=1).numpy()
|
||||
et = time.perf_counter()
|
||||
n += (t == y).sum()
|
||||
d += len(t)
|
||||
print(
|
||||
f"****** {n}/{d} {n*100.0/d:.2f}% -- {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:7.2f} ms to run model. {len(t)/(et-mt):.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-mt):.2f} TFLOPS"
|
||||
)
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_unet3d():
|
||||
# UNet3D
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||
from examples.mlperf.metrics import get_dice_score
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
s = 0
|
||||
st = time.perf_counter()
|
||||
for i, (image, label) in enumerate(iterate(), start=1):
|
||||
mt = time.perf_counter()
|
||||
pred, label = sliding_window_inference(mdl, image, label)
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
s += get_dice_score(pred, label).mean()
|
||||
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
||||
# UNet3D
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||
from examples.mlperf.metrics import get_dice_score
|
||||
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
s = 0
|
||||
st = time.perf_counter()
|
||||
for i, (image, label) in enumerate(iterate(), start=1):
|
||||
mt = time.perf_counter()
|
||||
pred, label = sliding_window_inference(mdl, image, label)
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
s += get_dice_score(pred, label).mean()
|
||||
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_retinanet():
|
||||
# RetinaNet with ResNeXt50_32X4D
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
from extra.models.retinanet import RetinaNet
|
||||
mdl = RetinaNet(ResNeXt50_32X4D())
|
||||
mdl.load_from_pretrained()
|
||||
# RetinaNet with ResNeXt50_32X4D
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
from extra.models.retinanet import RetinaNet
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
def input_fixup(x):
|
||||
x = x.permute([0,3,1,2]) / 255.0
|
||||
x -= input_mean
|
||||
x /= input_std
|
||||
return x
|
||||
mdl = RetinaNet(ResNeXt50_32X4D())
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
from extra.datasets.openimages import openimages, iterate
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from contextlib import redirect_stdout
|
||||
coco = COCO(openimages())
|
||||
coco_eval = COCOeval(coco, iouType="bbox")
|
||||
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
|
||||
|
||||
from tinygrad.jit import TinyJit
|
||||
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
||||
def input_fixup(x):
|
||||
x = x.permute([0, 3, 1, 2]) / 255.0
|
||||
x -= input_mean
|
||||
x /= input_std
|
||||
return x
|
||||
|
||||
n, bs = 0, 8
|
||||
st = time.perf_counter()
|
||||
for x, targets in iterate(coco, bs):
|
||||
dat = Tensor(x.astype(np.float32))
|
||||
mt = time.perf_counter()
|
||||
if dat.shape[0] == bs:
|
||||
outs = mdlrun(dat).numpy()
|
||||
else:
|
||||
mdlrun.jit_cache = None
|
||||
outs = mdl(input_fixup(dat)).numpy()
|
||||
et = time.perf_counter()
|
||||
predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
|
||||
ext = time.perf_counter()
|
||||
n += len(targets)
|
||||
print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
|
||||
img_ids = [t["image_id"] for t in targets]
|
||||
coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
|
||||
for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
|
||||
with redirect_stdout(None):
|
||||
coco_eval.cocoDt = coco.loadRes(coco_results)
|
||||
coco_eval.params.imgIds = img_ids
|
||||
coco_eval.evaluate()
|
||||
evaluated_imgs.extend(img_ids)
|
||||
coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
|
||||
from extra.datasets.openimages import openimages, iterate
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
coco = COCO(openimages())
|
||||
coco_eval = COCOeval(coco, iouType="bbox")
|
||||
coco_evalimgs, evaluated_imgs, ncats, narea = (
|
||||
[],
|
||||
[],
|
||||
len(coco_eval.params.catIds),
|
||||
len(coco_eval.params.areaRng),
|
||||
)
|
||||
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
|
||||
|
||||
n, bs = 0, 8
|
||||
st = time.perf_counter()
|
||||
for x, targets in iterate(coco, bs):
|
||||
dat = Tensor(x.astype(np.float32))
|
||||
mt = time.perf_counter()
|
||||
if dat.shape[0] == bs:
|
||||
outs = mdlrun(dat).numpy()
|
||||
else:
|
||||
mdlrun.jit_cache = None
|
||||
outs = mdl(input_fixup(dat)).numpy()
|
||||
et = time.perf_counter()
|
||||
predictions = mdl.postprocess_detections(
|
||||
outs,
|
||||
input_size=dat.shape[1:3],
|
||||
orig_image_sizes=[t["image_size"] for t in targets],
|
||||
)
|
||||
ext = time.perf_counter()
|
||||
n += len(targets)
|
||||
print(
|
||||
f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing"
|
||||
)
|
||||
img_ids = [t["image_id"] for t in targets]
|
||||
coco_results = [
|
||||
{
|
||||
"image_id": targets[i]["image_id"],
|
||||
"category_id": label,
|
||||
"bbox": box,
|
||||
"score": score,
|
||||
}
|
||||
for i, prediction in enumerate(predictions)
|
||||
for box, score, label in zip(*prediction.values())
|
||||
]
|
||||
with redirect_stdout(None):
|
||||
coco_eval.cocoDt = coco.loadRes(coco_results)
|
||||
coco_eval.params.imgIds = img_ids
|
||||
coco_eval.evaluate()
|
||||
evaluated_imgs.extend(img_ids)
|
||||
coco_evalimgs.append(
|
||||
np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids))
|
||||
)
|
||||
st = time.perf_counter()
|
||||
|
||||
coco_eval.params.imgIds = evaluated_imgs
|
||||
coco_eval._paramsEval.imgIds = evaluated_imgs
|
||||
coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
|
||||
coco_eval.params.imgIds = evaluated_imgs
|
||||
coco_eval._paramsEval.imgIds = evaluated_imgs
|
||||
coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
|
||||
def eval_rnnt():
|
||||
# RNN-T
|
||||
from extra.models.rnnt import RNNT
|
||||
mdl = RNNT()
|
||||
mdl.load_from_pretrained()
|
||||
# RNN-T
|
||||
from extra.models.rnnt import RNNT
|
||||
|
||||
from extra.datasets.librispeech import iterate
|
||||
from examples.mlperf.metrics import word_error_rate
|
||||
mdl = RNNT()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
|
||||
from extra.datasets.librispeech import iterate
|
||||
from examples.mlperf.metrics import word_error_rate
|
||||
|
||||
c = 0
|
||||
scores = 0
|
||||
words = 0
|
||||
st = time.perf_counter()
|
||||
for X, Y in iterate():
|
||||
mt = time.perf_counter()
|
||||
tt = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
for n, t in enumerate(tt):
|
||||
tnp = np.array(t)
|
||||
_, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
|
||||
scores += scores_
|
||||
words += words_
|
||||
c += len(tt)
|
||||
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
|
||||
LABELS = [
|
||||
" ",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"o",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"'",
|
||||
]
|
||||
|
||||
c = 0
|
||||
scores = 0
|
||||
words = 0
|
||||
st = time.perf_counter()
|
||||
for X, Y in iterate():
|
||||
mt = time.perf_counter()
|
||||
tt = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
for n, t in enumerate(tt):
|
||||
tnp = np.array(t)
|
||||
_, scores_, words_ = word_error_rate(
|
||||
["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]]
|
||||
)
|
||||
scores += scores_
|
||||
words += words_
|
||||
c += len(tt)
|
||||
print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_bert():
|
||||
# Bert-QA
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
mdl = BertForQuestionAnswering()
|
||||
mdl.load_from_pretrained()
|
||||
# Bert-QA
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
|
||||
@TinyJit
|
||||
def run(input_ids, input_mask, segment_ids):
|
||||
return mdl(input_ids, input_mask, segment_ids).realize()
|
||||
mdl = BertForQuestionAnswering()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
from extra.datasets.squad import iterate
|
||||
from examples.mlperf.helpers import get_bert_qa_prediction
|
||||
from examples.mlperf.metrics import f1_score
|
||||
from transformers import BertTokenizer
|
||||
@TinyJit
|
||||
def run(input_ids, input_mask, segment_ids):
|
||||
return mdl(input_ids, input_mask, segment_ids).realize()
|
||||
|
||||
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt"))
|
||||
from extra.datasets.squad import iterate
|
||||
from examples.mlperf.helpers import get_bert_qa_prediction
|
||||
from examples.mlperf.metrics import f1_score
|
||||
from transformers import BertTokenizer
|
||||
|
||||
c = 0
|
||||
f1 = 0.0
|
||||
st = time.perf_counter()
|
||||
for X, Y in iterate(tokenizer):
|
||||
mt = time.perf_counter()
|
||||
outs = []
|
||||
for x in X:
|
||||
outs.append(run(Tensor(x["input_ids"]), Tensor(x["input_mask"]), Tensor(x["segment_ids"])).numpy())
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features")
|
||||
|
||||
pred = get_bert_qa_prediction(X, Y, outs)
|
||||
print(f"pred: {pred}\nans: {Y['answers']}")
|
||||
f1 += max([f1_score(pred, ans) for ans in Y["answers"]])
|
||||
c += 1
|
||||
print(f"f1: {f1/c}, raw: {f1}, c: {c}\n")
|
||||
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights/bert_vocab.txt"))
|
||||
|
||||
c = 0
|
||||
f1 = 0.0
|
||||
st = time.perf_counter()
|
||||
for X, Y in iterate(tokenizer):
|
||||
mt = time.perf_counter()
|
||||
outs = []
|
||||
for x in X:
|
||||
outs.append(
|
||||
run(
|
||||
Tensor(x["input_ids"]),
|
||||
Tensor(x["input_mask"]),
|
||||
Tensor(x["segment_ids"]),
|
||||
).numpy()
|
||||
)
|
||||
et = time.perf_counter()
|
||||
print(
|
||||
f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model over {len(X)} features"
|
||||
)
|
||||
|
||||
pred = get_bert_qa_prediction(X, Y, outs)
|
||||
print(f"pred: {pred}\nans: {Y['answers']}")
|
||||
f1 += max([f1_score(pred, ans) for ans in Y["answers"]])
|
||||
c += 1
|
||||
print(f"f1: {f1/c}, raw: {f1}, c: {c}\n")
|
||||
|
||||
st = time.perf_counter()
|
||||
|
||||
|
||||
def eval_mrcnn():
|
||||
from tqdm import tqdm
|
||||
from extra.models.mask_rcnn import MaskRCNN
|
||||
from extra.models.resnet import ResNet
|
||||
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
from tqdm import tqdm
|
||||
from extra.models.mask_rcnn import MaskRCNN
|
||||
from extra.models.resnet import ResNet
|
||||
from extra.datasets.coco import (
|
||||
BASEDIR,
|
||||
images,
|
||||
convert_prediction_to_coco_bbox,
|
||||
convert_prediction_to_coco_mask,
|
||||
accumulate_predictions_for_coco,
|
||||
evaluate_predictions_on_coco,
|
||||
iterate,
|
||||
)
|
||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||
|
||||
bbox_output = '/tmp/results_bbox.json'
|
||||
mask_output = '/tmp/results_mask.json'
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
||||
accumulate_predictions_for_coco([], mask_output, rm=True)
|
||||
bbox_output = "/tmp/results_bbox.json"
|
||||
mask_output = "/tmp/results_mask.json"
|
||||
|
||||
#TODO: bs > 1 not as accurate
|
||||
bs = 1
|
||||
accumulate_predictions_for_coco([], bbox_output, rm=True)
|
||||
accumulate_predictions_for_coco([], mask_output, rm=True)
|
||||
|
||||
for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
|
||||
batch_imgs = []
|
||||
for image_row in batch:
|
||||
image_name = image_row['file_name']
|
||||
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
|
||||
batch_imgs.append(img)
|
||||
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
||||
for image_row, result in zip(batch, batch_result):
|
||||
image_name = image_row['file_name']
|
||||
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
||||
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
||||
accumulate_predictions_for_coco(box_pred, bbox_output)
|
||||
accumulate_predictions_for_coco(mask_pred, mask_output)
|
||||
del batch_imgs
|
||||
del batch_result
|
||||
# TODO: bs > 1 not as accurate
|
||||
bs = 1
|
||||
|
||||
for batch in tqdm(iterate(images, bs=bs), total=len(images) // bs):
|
||||
batch_imgs = []
|
||||
for image_row in batch:
|
||||
image_name = image_row["file_name"]
|
||||
img = Image.open(BASEDIR / f"val2017/{image_name}").convert("RGB")
|
||||
batch_imgs.append(img)
|
||||
batch_result = compute_prediction_batched(batch_imgs, mdl)
|
||||
for image_row, result in zip(batch, batch_result):
|
||||
image_name = image_row["file_name"]
|
||||
box_pred = convert_prediction_to_coco_bbox(image_name, result)
|
||||
mask_pred = convert_prediction_to_coco_mask(image_name, result)
|
||||
accumulate_predictions_for_coco(box_pred, bbox_output)
|
||||
accumulate_predictions_for_coco(mask_pred, mask_output)
|
||||
del batch_imgs
|
||||
del batch_result
|
||||
|
||||
evaluate_predictions_on_coco(bbox_output, iou_type="bbox")
|
||||
evaluate_predictions_on_coco(mask_output, iou_type="segm")
|
||||
|
||||
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
|
||||
evaluate_predictions_on_coco(mask_output, iou_type='segm')
|
||||
|
||||
if __name__ == "__main__":
|
||||
# inference only
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
# inference only
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
|
||||
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
|
||||
for m in models:
|
||||
nm = f"eval_{m}"
|
||||
if nm in globals():
|
||||
print(f"eval {m}")
|
||||
globals()[nm]()
|
||||
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
|
||||
for m in models:
|
||||
nm = f"eval_{m}"
|
||||
if nm in globals():
|
||||
print(f"eval {m}")
|
||||
globals()[nm]()
|
||||
|
|
|
@ -3,68 +3,84 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import GlobalCounters, getenv
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_model(model, *inputs):
|
||||
GlobalCounters.reset()
|
||||
out = model(*inputs)
|
||||
if isinstance(out, Tensor): out = out.numpy()
|
||||
# TODO: return event future to still get the time_sum_s without DEBUG=2
|
||||
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
|
||||
GlobalCounters.reset()
|
||||
out = model(*inputs)
|
||||
if isinstance(out, Tensor):
|
||||
out = out.numpy()
|
||||
# TODO: return event future to still get the time_sum_s without DEBUG=2
|
||||
print(
|
||||
f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms"
|
||||
)
|
||||
|
||||
|
||||
def spec_resnet():
|
||||
# Resnet50-v1.5
|
||||
from extra.models.resnet import ResNet50
|
||||
mdl = ResNet50()
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
# Resnet50-v1.5
|
||||
from extra.models.resnet import ResNet50
|
||||
|
||||
mdl = ResNet50()
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_retinanet():
|
||||
# Retinanet with ResNet backbone
|
||||
from extra.models.resnet import ResNet50
|
||||
from extra.models.retinanet import RetinaNet
|
||||
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
# Retinanet with ResNet backbone
|
||||
from extra.models.resnet import ResNet50
|
||||
from extra.models.retinanet import RetinaNet
|
||||
|
||||
mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
|
||||
img = Tensor.randn(1, 3, 224, 224)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_unet3d():
|
||||
# 3D UNET
|
||||
from extra.models.unet3d import UNet3D
|
||||
mdl = UNet3D()
|
||||
#mdl.load_from_pretrained()
|
||||
img = Tensor.randn(1, 1, 128, 128, 128)
|
||||
test_model(mdl, img)
|
||||
# 3D UNET
|
||||
from extra.models.unet3d import UNet3D
|
||||
|
||||
mdl = UNet3D()
|
||||
# mdl.load_from_pretrained()
|
||||
img = Tensor.randn(1, 1, 128, 128, 128)
|
||||
test_model(mdl, img)
|
||||
|
||||
|
||||
def spec_rnnt():
|
||||
from extra.models.rnnt import RNNT
|
||||
mdl = RNNT()
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(220, 1, 240)
|
||||
y = Tensor.randn(1, 220)
|
||||
test_model(mdl, x, y)
|
||||
from extra.models.rnnt import RNNT
|
||||
|
||||
mdl = RNNT()
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(220, 1, 240)
|
||||
y = Tensor.randn(1, 220)
|
||||
test_model(mdl, x, y)
|
||||
|
||||
|
||||
def spec_bert():
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
mdl = BertForQuestionAnswering()
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(1, 384)
|
||||
am = Tensor.randn(1, 384)
|
||||
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
||||
test_model(mdl, x, am, tt)
|
||||
from extra.models.bert import BertForQuestionAnswering
|
||||
|
||||
mdl = BertForQuestionAnswering()
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(1, 384)
|
||||
am = Tensor.randn(1, 384)
|
||||
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
|
||||
test_model(mdl, x, am, tt)
|
||||
|
||||
|
||||
def spec_mrcnn():
|
||||
from extra.models.mask_rcnn import MaskRCNN, ResNet
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
#mdl.load_from_pretrained()
|
||||
x = Tensor.randn(3, 224, 224)
|
||||
test_model(mdl, [x])
|
||||
from extra.models.mask_rcnn import MaskRCNN, ResNet
|
||||
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
# mdl.load_from_pretrained()
|
||||
x = Tensor.randn(3, 224, 224)
|
||||
test_model(mdl, [x])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# inference only for now
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
|
||||
nm = f"spec_{m}"
|
||||
if nm in globals():
|
||||
print(f"testing {m}")
|
||||
globals()[nm]()
|
||||
# inference only for now
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
|
||||
nm = f"spec_{m}"
|
||||
if nm in globals():
|
||||
print(f"testing {m}")
|
||||
globals()[nm]()
|
||||
|
|
|
@ -1,36 +1,43 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
def train_resnet():
|
||||
# TODO: Resnet50-v1.5
|
||||
pass
|
||||
# TODO: Resnet50-v1.5
|
||||
pass
|
||||
|
||||
|
||||
def train_retinanet():
|
||||
# TODO: Retinanet
|
||||
pass
|
||||
# TODO: Retinanet
|
||||
pass
|
||||
|
||||
|
||||
def train_unet3d():
|
||||
# TODO: Unet3d
|
||||
pass
|
||||
# TODO: Unet3d
|
||||
pass
|
||||
|
||||
|
||||
def train_rnnt():
|
||||
# TODO: RNN-T
|
||||
pass
|
||||
# TODO: RNN-T
|
||||
pass
|
||||
|
||||
|
||||
def train_bert():
|
||||
# TODO: BERT
|
||||
pass
|
||||
# TODO: BERT
|
||||
pass
|
||||
|
||||
|
||||
def train_maskrcnn():
|
||||
# TODO: Mask RCNN
|
||||
pass
|
||||
# TODO: Mask RCNN
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Tensor.train():
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(","):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
print(f"training {m}")
|
||||
globals()[nm]()
|
||||
|
||||
|
||||
with Tensor.train():
|
||||
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,maskrcnn").split(
|
||||
","
|
||||
):
|
||||
nm = f"train_{m}"
|
||||
if nm in globals():
|
||||
print(f"training {m}")
|
||||
globals()[nm]()
|
||||
|
|
|
@ -9,99 +9,115 @@ from tinygrad.helpers import getenv
|
|||
from tinygrad.nn import optim
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
class LinearGen:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(128, 256)
|
||||
self.l2 = Tensor.scaled_uniform(256, 512)
|
||||
self.l3 = Tensor.scaled_uniform(512, 1024)
|
||||
self.l4 = Tensor.scaled_uniform(1024, 784)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.dot(self.l1).leakyrelu(0.2)
|
||||
x = x.dot(self.l2).leakyrelu(0.2)
|
||||
x = x.dot(self.l3).leakyrelu(0.2)
|
||||
x = x.dot(self.l4).tanh()
|
||||
return x
|
||||
class LinearGen:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(128, 256)
|
||||
self.l2 = Tensor.scaled_uniform(256, 512)
|
||||
self.l3 = Tensor.scaled_uniform(512, 1024)
|
||||
self.l4 = Tensor.scaled_uniform(1024, 784)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.dot(self.l1).leakyrelu(0.2)
|
||||
x = x.dot(self.l2).leakyrelu(0.2)
|
||||
x = x.dot(self.l3).leakyrelu(0.2)
|
||||
x = x.dot(self.l4).tanh()
|
||||
return x
|
||||
|
||||
|
||||
class LinearDisc:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(784, 1024)
|
||||
self.l2 = Tensor.scaled_uniform(1024, 512)
|
||||
self.l3 = Tensor.scaled_uniform(512, 256)
|
||||
self.l4 = Tensor.scaled_uniform(256, 2)
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.scaled_uniform(784, 1024)
|
||||
self.l2 = Tensor.scaled_uniform(1024, 512)
|
||||
self.l3 = Tensor.scaled_uniform(512, 256)
|
||||
self.l4 = Tensor.scaled_uniform(256, 2)
|
||||
|
||||
def forward(self, x):
|
||||
# balance the discriminator inputs with const bias (.add(1))
|
||||
x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l4).log_softmax()
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# balance the discriminator inputs with const bias (.add(1))
|
||||
x = x.dot(self.l1).add(1).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l2).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l3).leakyrelu(0.2).dropout(0.3)
|
||||
x = x.dot(self.l4).log_softmax()
|
||||
return x
|
||||
|
||||
def make_batch(images):
|
||||
sample = np.random.randint(0, len(images), size=(batch_size))
|
||||
image_b = images[sample].reshape(-1, 28*28).astype(np.float32) / 127.5 - 1.0
|
||||
return Tensor(image_b)
|
||||
sample = np.random.randint(0, len(images), size=(batch_size))
|
||||
image_b = images[sample].reshape(-1, 28 * 28).astype(np.float32) / 127.5 - 1.0
|
||||
return Tensor(image_b)
|
||||
|
||||
|
||||
def make_labels(bs, col, val=-2.0):
|
||||
y = np.zeros((bs, 2), np.float32)
|
||||
y[range(bs), [col] * bs] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||
return Tensor(y)
|
||||
y = np.zeros((bs, 2), np.float32)
|
||||
y[
|
||||
range(bs), [col] * bs
|
||||
] = val # Can we do label smoothing? i.e -2.0 changed to -1.98789.
|
||||
return Tensor(y)
|
||||
|
||||
|
||||
def train_discriminator(optimizer, data_real, data_fake):
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
fake_labels = make_labels(batch_size, 0)
|
||||
optimizer.zero_grad()
|
||||
output_real = discriminator.forward(data_real)
|
||||
output_fake = discriminator.forward(data_fake)
|
||||
loss_real = (output_real * real_labels).mean()
|
||||
loss_fake = (output_fake * fake_labels).mean()
|
||||
loss_real.backward()
|
||||
loss_fake.backward()
|
||||
optimizer.step()
|
||||
return (loss_real + loss_fake).numpy()
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
fake_labels = make_labels(batch_size, 0)
|
||||
optimizer.zero_grad()
|
||||
output_real = discriminator.forward(data_real)
|
||||
output_fake = discriminator.forward(data_fake)
|
||||
loss_real = (output_real * real_labels).mean()
|
||||
loss_fake = (output_fake * fake_labels).mean()
|
||||
loss_real.backward()
|
||||
loss_fake.backward()
|
||||
optimizer.step()
|
||||
return (loss_real + loss_fake).numpy()
|
||||
|
||||
|
||||
def train_generator(optimizer, data_fake):
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
optimizer.zero_grad()
|
||||
output = discriminator.forward(data_fake)
|
||||
loss = (output * real_labels).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss.numpy()
|
||||
real_labels = make_labels(batch_size, 1)
|
||||
optimizer.zero_grad()
|
||||
output = discriminator.forward(data_fake)
|
||||
loss = (output * real_labels).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss.numpy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# data for training and validation
|
||||
images_real = np.vstack(fetch_mnist()[::2])
|
||||
ds_noise = Tensor.randn(64, 128, requires_grad=False)
|
||||
# parameters
|
||||
epochs, batch_size, k = 300, 512, 1
|
||||
sample_interval = epochs // 10
|
||||
n_steps = len(images_real) // batch_size
|
||||
# models and optimizer
|
||||
generator = LinearGen()
|
||||
discriminator = LinearDisc()
|
||||
# path to store results
|
||||
output_dir = Path(".").resolve() / "outputs"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
# optimizers
|
||||
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
data_real = make_batch(images_real)
|
||||
for step in range(k): # Try with k = 5 or 7.
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise).detach()
|
||||
loss_d += train_discriminator(optim_d, data_real, data_fake)
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise)
|
||||
loss_g += train_generator(optim_g, data_fake)
|
||||
if (epoch + 1) % sample_interval == 0:
|
||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||
save_image(make_grid(torch.tensor(fake_images)), output_dir / f"image_{epoch+1}.jpg")
|
||||
t.set_description(f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}")
|
||||
print("Training Completed!")
|
||||
# data for training and validation
|
||||
images_real = np.vstack(fetch_mnist()[::2])
|
||||
ds_noise = Tensor.randn(64, 128, requires_grad=False)
|
||||
# parameters
|
||||
epochs, batch_size, k = 300, 512, 1
|
||||
sample_interval = epochs // 10
|
||||
n_steps = len(images_real) // batch_size
|
||||
# models and optimizer
|
||||
generator = LinearGen()
|
||||
discriminator = LinearDisc()
|
||||
# path to store results
|
||||
output_dir = Path(".").resolve() / "outputs"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
# optimizers
|
||||
optim_g = optim.Adam(
|
||||
get_parameters(generator), lr=0.0002, b1=0.5
|
||||
) # 0.0002 for equilibrium!
|
||||
optim_d = optim.Adam(get_parameters(discriminator), lr=0.0002, b1=0.5)
|
||||
# training loop
|
||||
for epoch in (t := trange(epochs)):
|
||||
loss_g, loss_d = 0.0, 0.0
|
||||
for _ in range(n_steps):
|
||||
data_real = make_batch(images_real)
|
||||
for step in range(k): # Try with k = 5 or 7.
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise).detach()
|
||||
loss_d += train_discriminator(optim_d, data_real, data_fake)
|
||||
noise = Tensor.randn(batch_size, 128)
|
||||
data_fake = generator.forward(noise)
|
||||
loss_g += train_generator(optim_g, data_fake)
|
||||
if (epoch + 1) % sample_interval == 0:
|
||||
fake_images = generator.forward(ds_noise).detach().numpy()
|
||||
fake_images = (fake_images.reshape(-1, 1, 28, 28) + 1) / 2 # 0 - 1 range.
|
||||
save_image(
|
||||
make_grid(torch.tensor(fake_images)),
|
||||
output_dir / f"image_{epoch+1}.jpg",
|
||||
)
|
||||
t.set_description(
|
||||
f"Generator loss: {loss_g/n_steps}, Discriminator loss: {loss_d/n_steps}"
|
||||
)
|
||||
print("Training Completed!")
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
|
||||
# inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
|
||||
import sys
|
||||
import numpy as np
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
@ -9,128 +9,144 @@ from tinygrad.helpers import getenv
|
|||
from extra.datasets import fetch_mnist
|
||||
from extra.augment import augment_img
|
||||
from extra.training import train, evaluate
|
||||
|
||||
GPU = getenv("GPU")
|
||||
QUICK = getenv("QUICK")
|
||||
DEBUG = getenv("DEBUG")
|
||||
|
||||
class SqueezeExciteBlock2D:
|
||||
def __init__(self, filters):
|
||||
self.filters = filters
|
||||
self.weight1 = Tensor.scaled_uniform(self.filters, self.filters//32)
|
||||
self.bias1 = Tensor.scaled_uniform(1,self.filters//32)
|
||||
self.weight2 = Tensor.scaled_uniform(self.filters//32, self.filters)
|
||||
self.bias2 = Tensor.scaled_uniform(1, self.filters)
|
||||
|
||||
def __call__(self, input):
|
||||
se = input.avg_pool2d(kernel_size=(input.shape[2], input.shape[3])) #GlobalAveragePool2D
|
||||
se = se.reshape(shape=(-1, self.filters))
|
||||
se = se.dot(self.weight1) + self.bias1
|
||||
se = se.relu()
|
||||
se = se.dot(self.weight2) + self.bias2
|
||||
se = se.sigmoid().reshape(shape=(-1,self.filters,1,1)) #for broadcasting
|
||||
se = input.mul(se)
|
||||
return se
|
||||
class SqueezeExciteBlock2D:
|
||||
def __init__(self, filters):
|
||||
self.filters = filters
|
||||
self.weight1 = Tensor.scaled_uniform(self.filters, self.filters // 32)
|
||||
self.bias1 = Tensor.scaled_uniform(1, self.filters // 32)
|
||||
self.weight2 = Tensor.scaled_uniform(self.filters // 32, self.filters)
|
||||
self.bias2 = Tensor.scaled_uniform(1, self.filters)
|
||||
|
||||
def __call__(self, input):
|
||||
se = input.avg_pool2d(
|
||||
kernel_size=(input.shape[2], input.shape[3])
|
||||
) # GlobalAveragePool2D
|
||||
se = se.reshape(shape=(-1, self.filters))
|
||||
se = se.dot(self.weight1) + self.bias1
|
||||
se = se.relu()
|
||||
se = se.dot(self.weight2) + self.bias2
|
||||
se = se.sigmoid().reshape(shape=(-1, self.filters, 1, 1)) # for broadcasting
|
||||
se = input.mul(se)
|
||||
return se
|
||||
|
||||
|
||||
class ConvBlock:
|
||||
def __init__(self, h, w, inp, filters=128, conv=3):
|
||||
self.h, self.w = h, w
|
||||
self.inp = inp
|
||||
#init weights
|
||||
self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
|
||||
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
|
||||
#init layers
|
||||
self._bn = BatchNorm2d(128)
|
||||
self._seb = SqueezeExciteBlock2D(filters)
|
||||
def __init__(self, h, w, inp, filters=128, conv=3):
|
||||
self.h, self.w = h, w
|
||||
self.inp = inp
|
||||
# init weights
|
||||
self.cweights = [
|
||||
Tensor.scaled_uniform(filters, inp if i == 0 else filters, conv, conv)
|
||||
for i in range(3)
|
||||
]
|
||||
self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
|
||||
# init layers
|
||||
self._bn = BatchNorm2d(128)
|
||||
self._seb = SqueezeExciteBlock2D(filters)
|
||||
|
||||
def __call__(self, input):
|
||||
x = input.reshape(shape=(-1, self.inp, self.w, self.h))
|
||||
for cweight, cbias in zip(self.cweights, self.cbiases):
|
||||
x = x.pad2d(padding=[1, 1, 1, 1]).conv2d(cweight).add(cbias).relu()
|
||||
x = self._bn(x)
|
||||
x = self._seb(x)
|
||||
return x
|
||||
|
||||
def __call__(self, input):
|
||||
x = input.reshape(shape=(-1, self.inp, self.w, self.h))
|
||||
for cweight, cbias in zip(self.cweights, self.cbiases):
|
||||
x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu()
|
||||
x = self._bn(x)
|
||||
x = self._seb(x)
|
||||
return x
|
||||
|
||||
class BigConvNet:
|
||||
def __init__(self):
|
||||
self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
|
||||
self.weight1 = Tensor.scaled_uniform(128,10)
|
||||
self.weight2 = Tensor.scaled_uniform(128,10)
|
||||
def __init__(self):
|
||||
self.conv = [
|
||||
ConvBlock(28, 28, 1),
|
||||
ConvBlock(28, 28, 128),
|
||||
ConvBlock(14, 14, 128),
|
||||
]
|
||||
self.weight1 = Tensor.scaled_uniform(128, 10)
|
||||
self.weight2 = Tensor.scaled_uniform(128, 10)
|
||||
|
||||
def parameters(self):
|
||||
if DEBUG: #keeping this for a moment
|
||||
pars = [par for par in get_parameters(self) if par.requires_grad]
|
||||
no_pars = 0
|
||||
for par in pars:
|
||||
print(par.shape)
|
||||
no_pars += np.prod(par.shape)
|
||||
print('no of parameters', no_pars)
|
||||
return pars
|
||||
else:
|
||||
return get_parameters(self)
|
||||
def parameters(self):
|
||||
if DEBUG: # keeping this for a moment
|
||||
pars = [par for par in get_parameters(self) if par.requires_grad]
|
||||
no_pars = 0
|
||||
for par in pars:
|
||||
print(par.shape)
|
||||
no_pars += np.prod(par.shape)
|
||||
print("no of parameters", no_pars)
|
||||
return pars
|
||||
else:
|
||||
return get_parameters(self)
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename+'.npy', 'wb') as f:
|
||||
for par in get_parameters(self):
|
||||
#if par.requires_grad:
|
||||
np.save(f, par.numpy())
|
||||
def save(self, filename):
|
||||
with open(filename + ".npy", "wb") as f:
|
||||
for par in get_parameters(self):
|
||||
# if par.requires_grad:
|
||||
np.save(f, par.numpy())
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename+'.npy', 'rb') as f:
|
||||
for par in get_parameters(self):
|
||||
#if par.requires_grad:
|
||||
try:
|
||||
par.numpy()[:] = np.load(f)
|
||||
if GPU:
|
||||
par.gpu()
|
||||
except:
|
||||
print('Could not load parameter')
|
||||
def load(self, filename):
|
||||
with open(filename + ".npy", "rb") as f:
|
||||
for par in get_parameters(self):
|
||||
# if par.requires_grad:
|
||||
try:
|
||||
par.numpy()[:] = np.load(f)
|
||||
if GPU:
|
||||
par.gpu()
|
||||
except:
|
||||
print("Could not load parameter")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv[0](x)
|
||||
x = self.conv[1](x)
|
||||
x = x.avg_pool2d(kernel_size=(2,2))
|
||||
x = self.conv[2](x)
|
||||
x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
|
||||
x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
|
||||
xo = x1.dot(self.weight1) + x2.dot(self.weight2)
|
||||
return xo
|
||||
def forward(self, x):
|
||||
x = self.conv[0](x)
|
||||
x = self.conv[1](x)
|
||||
x = x.avg_pool2d(kernel_size=(2, 2))
|
||||
x = self.conv[2](x)
|
||||
x1 = x.avg_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128)) # global
|
||||
x2 = x.max_pool2d(kernel_size=(14, 14)).reshape(shape=(-1, 128)) # global
|
||||
xo = x1.dot(self.weight1) + x2.dot(self.weight2)
|
||||
return xo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
|
||||
epochss = [2, 1] if QUICK else [13, 3, 3, 1]
|
||||
BS = 32
|
||||
lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
|
||||
epochss = [2, 1] if QUICK else [13, 3, 3, 1]
|
||||
BS = 32
|
||||
|
||||
lmbd = 0.00025
|
||||
lossfn = lambda out,y: out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
steps = len(X_train)//BS
|
||||
np.random.seed(1337)
|
||||
if QUICK:
|
||||
steps = 1
|
||||
X_test, Y_test = X_test[:BS], Y_test[:BS]
|
||||
lmbd = 0.00025
|
||||
lossfn = (
|
||||
lambda out, y: out.sparse_categorical_crossentropy(y)
|
||||
+ lmbd * (model.weight1.abs() + model.weight2.abs()).sum()
|
||||
)
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
steps = len(X_train) // BS
|
||||
np.random.seed(1337)
|
||||
if QUICK:
|
||||
steps = 1
|
||||
X_test, Y_test = X_test[:BS], Y_test[:BS]
|
||||
|
||||
model = BigConvNet()
|
||||
model = BigConvNet()
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
model.load(sys.argv[1])
|
||||
print('Loaded weights "'+sys.argv[1]+'", evaluating...')
|
||||
evaluate(model, X_test, Y_test, BS=BS)
|
||||
except:
|
||||
print('could not load weights "'+sys.argv[1]+'".')
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
model.load(sys.argv[1])
|
||||
print('Loaded weights "' + sys.argv[1] + '", evaluating...')
|
||||
evaluate(model, X_test, Y_test, BS=BS)
|
||||
except:
|
||||
print('could not load weights "' + sys.argv[1] + '".')
|
||||
|
||||
if GPU:
|
||||
params = get_parameters(model)
|
||||
[x.gpu_() for x in params]
|
||||
if GPU:
|
||||
params = get_parameters(model)
|
||||
[x.gpu_() for x in params]
|
||||
|
||||
for lr, epochs in zip(lrs, epochss):
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
for epoch in range(1,epochs+1):
|
||||
#first epoch without augmentation
|
||||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||
model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')
|
||||
for lr, epochs in zip(lrs, epochss):
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
for epoch in range(1, epochs + 1):
|
||||
# first epoch without augmentation
|
||||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||
model.save(f"examples/checkpoint{accuracy * 1e6:.0f}")
|
||||
|
|
|
@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
|
|||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Tensor.train():
|
||||
with Tensor.train():
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||
for t in get_parameters([x, conv, bn]):
|
||||
t.realize()
|
||||
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||
for t in get_parameters([x, conv, bn]): t.realize()
|
||||
|
||||
print("running network")
|
||||
x.sequential([conv, bn]).numpy()
|
||||
print("running network")
|
||||
x.sequential([conv, bn]).numpy()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -7,199 +7,369 @@ import soundfile
|
|||
import numpy as np
|
||||
import parselmouth
|
||||
|
||||
|
||||
class PMF0Predictor: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||
def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
|
||||
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = hop_length, f0_min, f0_max, sampling_rate, "pm"
|
||||
def interpolate_f0(self,f0):
|
||||
vuv_vector = np.zeros_like(f0, dtype=np.float32)
|
||||
vuv_vector[f0 > 0.0] = 1.0
|
||||
vuv_vector[f0 <= 0.0] = 0.0
|
||||
nzindex = np.nonzero(f0)[0]
|
||||
data = f0[nzindex]
|
||||
nzindex = nzindex.astype(np.float32)
|
||||
time_org = self.hop_length / self.sampling_rate * nzindex
|
||||
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
|
||||
if data.shape[0] <= 0: return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
|
||||
if data.shape[0] == 1: return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
|
||||
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
|
||||
return f0,vuv_vector
|
||||
def compute_f0(self,wav,p_len=None):
|
||||
x = wav
|
||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = parselmouth.Sound(x, self.sampling_rate) \
|
||||
.to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6,pitch_floor=self.f0_min, pitch_ceiling=self.f0_max) \
|
||||
.selected_array['frequency']
|
||||
pad_size=(p_len - len(f0) + 1) // 2
|
||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
||||
f0,uv = self.interpolate_f0(f0)
|
||||
return f0
|
||||
def compute_f0_uv(self,wav,p_len=None):
|
||||
x = wav
|
||||
if p_len is None: p_len = x.shape[0]//self.hop_length
|
||||
else: assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
|
||||
time_step=time_step / 1000, voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
|
||||
pad_size=(p_len - len(f0) + 1) // 2
|
||||
if(pad_size>0 or p_len - len(f0) - pad_size>0):
|
||||
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
|
||||
f0,uv = self.interpolate_f0(f0)
|
||||
return f0,uv
|
||||
def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
|
||||
self.hop_length, self.f0_min, self.f0_max, self.sampling_rate, self.name = (
|
||||
hop_length,
|
||||
f0_min,
|
||||
f0_max,
|
||||
sampling_rate,
|
||||
"pm",
|
||||
)
|
||||
|
||||
def interpolate_f0(self, f0):
|
||||
vuv_vector = np.zeros_like(f0, dtype=np.float32)
|
||||
vuv_vector[f0 > 0.0] = 1.0
|
||||
vuv_vector[f0 <= 0.0] = 0.0
|
||||
nzindex = np.nonzero(f0)[0]
|
||||
data = f0[nzindex]
|
||||
nzindex = nzindex.astype(np.float32)
|
||||
time_org = self.hop_length / self.sampling_rate * nzindex
|
||||
time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
|
||||
if data.shape[0] <= 0:
|
||||
return np.zeros(f0.shape[0], dtype=np.float32), vuv_vector
|
||||
if data.shape[0] == 1:
|
||||
return np.ones(f0.shape[0], dtype=np.float32) * f0[0], vuv_vector
|
||||
f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
|
||||
return f0, vuv_vector
|
||||
|
||||
def compute_f0(self, wav, p_len=None):
|
||||
x = wav
|
||||
if p_len is None:
|
||||
p_len = x.shape[0] // self.hop_length
|
||||
else:
|
||||
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = (
|
||||
parselmouth.Sound(x, self.sampling_rate)
|
||||
.to_pitch_ac(
|
||||
time_step=time_step / 1000,
|
||||
voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min,
|
||||
pitch_ceiling=self.f0_max,
|
||||
)
|
||||
.selected_array["frequency"]
|
||||
)
|
||||
pad_size = (p_len - len(f0) + 1) // 2
|
||||
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||
f0, uv = self.interpolate_f0(f0)
|
||||
return f0
|
||||
|
||||
def compute_f0_uv(self, wav, p_len=None):
|
||||
x = wav
|
||||
if p_len is None:
|
||||
p_len = x.shape[0] // self.hop_length
|
||||
else:
|
||||
assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
|
||||
time_step = self.hop_length / self.sampling_rate * 1000
|
||||
f0 = (
|
||||
parselmouth.Sound(x, self.sampling_rate)
|
||||
.to_pitch_ac(
|
||||
time_step=time_step / 1000,
|
||||
voicing_threshold=0.6,
|
||||
pitch_floor=self.f0_min,
|
||||
pitch_ceiling=self.f0_max,
|
||||
)
|
||||
.selected_array["frequency"]
|
||||
)
|
||||
pad_size = (p_len - len(f0) + 1) // 2
|
||||
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||
f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
|
||||
f0, uv = self.interpolate_f0(f0)
|
||||
return f0, uv
|
||||
|
||||
|
||||
class Slicer: # from https://github.com/svc-develop-team/so-vits-svc/
|
||||
def __init__(self, sr: int, threshold: float = -40., min_length: int = 5000, min_interval: int = 300, hop_size: int = 20, max_sil_kept: int = 5000):
|
||||
if not min_length >= min_interval >= hop_size:
|
||||
raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
|
||||
if not max_sil_kept >= hop_size:
|
||||
raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
|
||||
min_interval = sr * min_interval / 1000
|
||||
self.threshold = 10 ** (threshold / 20.)
|
||||
self.hop_size = round(sr * hop_size / 1000)
|
||||
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
||||
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
||||
self.min_interval = round(min_interval / self.hop_size)
|
||||
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
||||
def _apply_slice(self, waveform, begin, end):
|
||||
if len(waveform.shape) > 1: return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
|
||||
else: return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
|
||||
def slice(self, waveform):
|
||||
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
|
||||
if samples.shape[0] <= self.min_length: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
||||
rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
|
||||
sil_tags, silence_start, clip_start = [], None, 0
|
||||
for i, rms in enumerate(rms_list):
|
||||
if rms < self.threshold: # Keep looping while frame is silent.
|
||||
if silence_start is None: # Record start of silent frames.
|
||||
silence_start = i
|
||||
continue
|
||||
if silence_start is None: continue # Keep looping while frame is not silent and silence start has not been recorded.
|
||||
# Clear recorded silence start if interval is not enough or clip is too short
|
||||
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
||||
need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
|
||||
if not is_leading_silence and not need_slice_middle:
|
||||
silence_start = None
|
||||
continue
|
||||
if i - silence_start <= self.max_sil_kept: # Need slicing. Record the range of silent frames to be removed.
|
||||
pos = rms_list[silence_start: i + 1].argmin() + silence_start
|
||||
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
||||
clip_start = pos
|
||||
elif i - silence_start <= self.max_sil_kept * 2:
|
||||
pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
|
||||
pos += i - self.max_sil_kept
|
||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos_r))
|
||||
clip_start = pos_r
|
||||
def __init__(
|
||||
self,
|
||||
sr: int,
|
||||
threshold: float = -40.0,
|
||||
min_length: int = 5000,
|
||||
min_interval: int = 300,
|
||||
hop_size: int = 20,
|
||||
max_sil_kept: int = 5000,
|
||||
):
|
||||
if not min_length >= min_interval >= hop_size:
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: min_length >= min_interval >= hop_size"
|
||||
)
|
||||
if not max_sil_kept >= hop_size:
|
||||
raise ValueError(
|
||||
"The following condition must be satisfied: max_sil_kept >= hop_size"
|
||||
)
|
||||
min_interval = sr * min_interval / 1000
|
||||
self.threshold = 10 ** (threshold / 20.0)
|
||||
self.hop_size = round(sr * hop_size / 1000)
|
||||
self.win_size = min(round(min_interval), 4 * self.hop_size)
|
||||
self.min_length = round(sr * min_length / 1000 / self.hop_size)
|
||||
self.min_interval = round(min_interval / self.hop_size)
|
||||
self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
|
||||
|
||||
def _apply_slice(self, waveform, begin, end):
|
||||
if len(waveform.shape) > 1:
|
||||
return waveform[
|
||||
:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)
|
||||
]
|
||||
else:
|
||||
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
||||
clip_start = max(pos_r, pos)
|
||||
else:
|
||||
pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
|
||||
pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
|
||||
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
|
||||
clip_start = pos_r
|
||||
silence_start = None
|
||||
total_frames = rms_list.shape[0]
|
||||
if silence_start is not None and total_frames - silence_start >= self.min_interval: # Deal with trailing silence.
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
if len(sil_tags) == 0: return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} # Apply and return slices.
|
||||
chunks = []
|
||||
if sil_tags[0][0]:
|
||||
chunks.append({"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
|
||||
for i in range(0, len(sil_tags)):
|
||||
if i: chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
|
||||
chunks.append({"slice": True, "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
|
||||
if sil_tags[-1][1] * self.hop_size < len(waveform):
|
||||
chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
|
||||
chunk_dict = {}
|
||||
for i in range(len(chunks)): chunk_dict[str(i)] = chunks[i]
|
||||
return chunk_dict
|
||||
return waveform[
|
||||
begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)
|
||||
]
|
||||
|
||||
def slice(self, waveform):
|
||||
samples = librosa.to_mono(waveform) if len(waveform.shape) > 1 else waveform
|
||||
if samples.shape[0] <= self.min_length:
|
||||
return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
|
||||
rms_list = librosa.feature.rms(
|
||||
y=samples, frame_length=self.win_size, hop_length=self.hop_size
|
||||
).squeeze(0)
|
||||
sil_tags, silence_start, clip_start = [], None, 0
|
||||
for i, rms in enumerate(rms_list):
|
||||
if rms < self.threshold: # Keep looping while frame is silent.
|
||||
if silence_start is None: # Record start of silent frames.
|
||||
silence_start = i
|
||||
continue
|
||||
if silence_start is None:
|
||||
continue # Keep looping while frame is not silent and silence start has not been recorded.
|
||||
# Clear recorded silence start if interval is not enough or clip is too short
|
||||
is_leading_silence = silence_start == 0 and i > self.max_sil_kept
|
||||
need_slice_middle = (
|
||||
i - silence_start >= self.min_interval
|
||||
and i - clip_start >= self.min_length
|
||||
)
|
||||
if not is_leading_silence and not need_slice_middle:
|
||||
silence_start = None
|
||||
continue
|
||||
if (
|
||||
i - silence_start <= self.max_sil_kept
|
||||
): # Need slicing. Record the range of silent frames to be removed.
|
||||
pos = rms_list[silence_start : i + 1].argmin() + silence_start
|
||||
sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
|
||||
clip_start = pos
|
||||
elif i - silence_start <= self.max_sil_kept * 2:
|
||||
pos = rms_list[
|
||||
i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
pos += i - self.max_sil_kept
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
if silence_start == 0:
|
||||
sil_tags.append((0, pos_r))
|
||||
clip_start = pos_r
|
||||
else:
|
||||
sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
|
||||
clip_start = max(pos_r, pos)
|
||||
else:
|
||||
pos_l = (
|
||||
rms_list[
|
||||
silence_start : silence_start + self.max_sil_kept + 1
|
||||
].argmin()
|
||||
+ silence_start
|
||||
)
|
||||
pos_r = (
|
||||
rms_list[i - self.max_sil_kept : i + 1].argmin()
|
||||
+ i
|
||||
- self.max_sil_kept
|
||||
)
|
||||
sil_tags.append((0, pos_r) if silence_start == 0 else (pos_l, pos_r))
|
||||
clip_start = pos_r
|
||||
silence_start = None
|
||||
total_frames = rms_list.shape[0]
|
||||
if (
|
||||
silence_start is not None
|
||||
and total_frames - silence_start >= self.min_interval
|
||||
): # Deal with trailing silence.
|
||||
silence_end = min(total_frames, silence_start + self.max_sil_kept)
|
||||
pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
|
||||
sil_tags.append((pos, total_frames + 1))
|
||||
if len(sil_tags) == 0:
|
||||
return {
|
||||
"0": {"slice": False, "split_time": f"0,{len(waveform)}"}
|
||||
} # Apply and return slices.
|
||||
chunks = []
|
||||
if sil_tags[0][0]:
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
for i in range(0, len(sil_tags)):
|
||||
if i:
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
chunks.append(
|
||||
{
|
||||
"slice": True,
|
||||
"split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}",
|
||||
}
|
||||
)
|
||||
if sil_tags[-1][1] * self.hop_size < len(waveform):
|
||||
chunks.append(
|
||||
{
|
||||
"slice": False,
|
||||
"split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}",
|
||||
}
|
||||
)
|
||||
chunk_dict = {}
|
||||
for i in range(len(chunks)):
|
||||
chunk_dict[str(i)] = chunks[i]
|
||||
return chunk_dict
|
||||
|
||||
|
||||
# sinc_interp_hann audio resampling
|
||||
class Resample:
|
||||
def __init__(self, orig_freq:int=16000, new_freq:int=16000, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None, dtype:Optional[dtypes]=None):
|
||||
self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.beta = orig_freq, new_freq, lowpass_filter_width, rolloff, beta
|
||||
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
|
||||
self.kernel, self.width = self._get_sinc_resample_kernel(dtype) if self.orig_freq != self.new_freq else (None, None)
|
||||
def __call__(self, waveform:Tensor) -> Tensor:
|
||||
if self.orig_freq == self.new_freq: return waveform
|
||||
return self._apply_sinc_resample_kernel(waveform)
|
||||
def _apply_sinc_resample_kernel(self, waveform:Tensor):
|
||||
if not waveform.is_floating_point(): raise TypeError(f"Waveform tensor expected to be of type float, but received {waveform.dtype}.")
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
||||
shape = waveform.shape
|
||||
waveform = waveform.reshape(-1, shape[-1]) # pack batch
|
||||
num_wavs, length = waveform.shape
|
||||
target_length = int(math.ceil(new_freq * length / orig_freq))
|
||||
waveform = waveform.pad2d((self.width, self.width + orig_freq))
|
||||
resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq)
|
||||
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
|
||||
resampled = resampled[..., :target_length]
|
||||
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
|
||||
return resampled
|
||||
def _get_sinc_resample_kernel(self, dtype=None):
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (int(self.new_freq) // self.gcd)
|
||||
if self.lowpass_filter_width <= 0: raise ValueError("Low pass filter width should be positive.")
|
||||
base_freq = min(orig_freq, new_freq)
|
||||
base_freq *= self.rolloff
|
||||
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
|
||||
idx = Tensor.arange(-width, width + orig_freq, dtype=(dtype if dtype is not None else dtypes.float32))[None, None] / orig_freq
|
||||
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
|
||||
t *= base_freq
|
||||
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
|
||||
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
|
||||
t *= math.pi
|
||||
scale = base_freq / orig_freq
|
||||
kernels = Tensor.where(t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t)
|
||||
kernels *= window * scale
|
||||
if dtype is None: kernels = kernels.cast(dtype=dtypes.float32)
|
||||
return kernels, width
|
||||
def __init__(
|
||||
self,
|
||||
orig_freq: int = 16000,
|
||||
new_freq: int = 16000,
|
||||
lowpass_filter_width: int = 6,
|
||||
rolloff: float = 0.99,
|
||||
beta: Optional[float] = None,
|
||||
dtype: Optional[dtypes] = None,
|
||||
):
|
||||
(
|
||||
self.orig_freq,
|
||||
self.new_freq,
|
||||
self.lowpass_filter_width,
|
||||
self.rolloff,
|
||||
self.beta,
|
||||
) = (orig_freq, new_freq, lowpass_filter_width, rolloff, beta)
|
||||
self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
|
||||
self.kernel, self.width = (
|
||||
self._get_sinc_resample_kernel(dtype)
|
||||
if self.orig_freq != self.new_freq
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
def __call__(self, waveform: Tensor) -> Tensor:
|
||||
if self.orig_freq == self.new_freq:
|
||||
return waveform
|
||||
return self._apply_sinc_resample_kernel(waveform)
|
||||
|
||||
def _apply_sinc_resample_kernel(self, waveform: Tensor):
|
||||
if not waveform.is_floating_point():
|
||||
raise TypeError(
|
||||
f"Waveform tensor expected to be of type float, but received {waveform.dtype}."
|
||||
)
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||
int(self.new_freq) // self.gcd
|
||||
)
|
||||
shape = waveform.shape
|
||||
waveform = waveform.reshape(-1, shape[-1]) # pack batch
|
||||
num_wavs, length = waveform.shape
|
||||
target_length = int(math.ceil(new_freq * length / orig_freq))
|
||||
waveform = waveform.pad2d((self.width, self.width + orig_freq))
|
||||
resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq)
|
||||
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
|
||||
resampled = resampled[..., :target_length]
|
||||
resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:]) # unpack batch
|
||||
return resampled
|
||||
|
||||
def _get_sinc_resample_kernel(self, dtype=None):
|
||||
orig_freq, new_freq = (int(self.orig_freq) // self.gcd), (
|
||||
int(self.new_freq) // self.gcd
|
||||
)
|
||||
if self.lowpass_filter_width <= 0:
|
||||
raise ValueError("Low pass filter width should be positive.")
|
||||
base_freq = min(orig_freq, new_freq)
|
||||
base_freq *= self.rolloff
|
||||
width = math.ceil(self.lowpass_filter_width * orig_freq / base_freq)
|
||||
idx = (
|
||||
Tensor.arange(
|
||||
-width,
|
||||
width + orig_freq,
|
||||
dtype=(dtype if dtype is not None else dtypes.float32),
|
||||
)[None, None]
|
||||
/ orig_freq
|
||||
)
|
||||
t = Tensor.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
|
||||
t *= base_freq
|
||||
t = t.clip(-self.lowpass_filter_width, self.lowpass_filter_width)
|
||||
window = (t * math.pi / self.lowpass_filter_width / 2).cos() ** 2
|
||||
t *= math.pi
|
||||
scale = base_freq / orig_freq
|
||||
kernels = Tensor.where(
|
||||
t == 0, Tensor(1.0, dtype=t.dtype).to(t.device), t.sin() / t
|
||||
)
|
||||
kernels *= window * scale
|
||||
if dtype is None:
|
||||
kernels = kernels.cast(dtype=dtypes.float32)
|
||||
return kernels, width
|
||||
|
||||
|
||||
def sinc_interp_resample(
|
||||
x: Tensor,
|
||||
orig_freq: int = 16000,
|
||||
new_freq: int = 1600,
|
||||
lowpass_filter_width: int = 6,
|
||||
rolloff: float = 0.99,
|
||||
beta: Optional[float] = None,
|
||||
):
|
||||
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
|
||||
return resamp(x)
|
||||
|
||||
def sinc_interp_resample(x:Tensor, orig_freq:int=16000, new_freq:int=1600, lowpass_filter_width:int=6, rolloff:float=0.99, beta:Optional[float]=None):
|
||||
resamp = Resample(orig_freq, new_freq, lowpass_filter_width, rolloff, beta, x.dtype)
|
||||
return resamp(x)
|
||||
|
||||
def cut(audio_path, db_thresh=-30, min_len=5000):
|
||||
audio, sr = librosa.load(audio_path, sr=None)
|
||||
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
|
||||
chunks = slicer.slice(audio)
|
||||
return chunks
|
||||
audio, sr = librosa.load(audio_path, sr=None)
|
||||
slicer = Slicer(sr=sr, threshold=db_thresh, min_length=min_len)
|
||||
chunks = slicer.slice(audio)
|
||||
return chunks
|
||||
|
||||
|
||||
def chunks2audio(audio_path, chunks):
|
||||
chunks = dict(chunks)
|
||||
audio, sr = load_audiofile(audio_path)
|
||||
if len(audio.shape) == 2 and audio.shape[1] >= 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = audio.numpy()[0]
|
||||
result = []
|
||||
for k, v in chunks.items():
|
||||
tag = v["split_time"].split(",")
|
||||
if tag[0] != tag[1]:
|
||||
result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
|
||||
return result, sr
|
||||
chunks = dict(chunks)
|
||||
audio, sr = load_audiofile(audio_path)
|
||||
if len(audio.shape) == 2 and audio.shape[1] >= 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = audio.numpy()[0]
|
||||
result = []
|
||||
for k, v in chunks.items():
|
||||
tag = v["split_time"].split(",")
|
||||
if tag[0] != tag[1]:
|
||||
result.append((v["slice"], audio[int(tag[0]) : int(tag[1])]))
|
||||
return result, sr
|
||||
|
||||
def load_audiofile(filepath:str, frame_offset:int=0, num_frames:int=-1, channels_first:bool=True):
|
||||
with soundfile.SoundFile(filepath, "r") as file_:
|
||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||
waveform = file_.read(frames, "float32", always_2d=True)
|
||||
sample_rate = file_.samplerate
|
||||
waveform = Tensor(waveform)
|
||||
if channels_first: waveform = waveform.transpose(0, 1)
|
||||
return waveform, sample_rate
|
||||
|
||||
def get_unit_f0(wav:Tensor, tran, hop_length, target_sample, f0_filter=False) -> Tuple[Tensor,Tensor,Tensor]:
|
||||
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
|
||||
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
|
||||
if f0_filter and sum(f0) == 0: raise RuntimeError("No voice detected")
|
||||
f0 = Tensor(f0.astype(np.float32)).float()
|
||||
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
|
||||
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
|
||||
wav16k = sinc_interp_resample(wav[None,:], target_sample, 16000)[0]
|
||||
return wav16k.realize(), f0.realize(), uv.realize()
|
||||
def load_audiofile(
|
||||
filepath: str,
|
||||
frame_offset: int = 0,
|
||||
num_frames: int = -1,
|
||||
channels_first: bool = True,
|
||||
):
|
||||
with soundfile.SoundFile(filepath, "r") as file_:
|
||||
frames = file_._prepare_read(frame_offset, None, num_frames)
|
||||
waveform = file_.read(frames, "float32", always_2d=True)
|
||||
sample_rate = file_.samplerate
|
||||
waveform = Tensor(waveform)
|
||||
if channels_first:
|
||||
waveform = waveform.transpose(0, 1)
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def get_unit_f0(
|
||||
wav: Tensor, tran, hop_length, target_sample, f0_filter=False
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
f0_predictor = PMF0Predictor(hop_length, sampling_rate=target_sample)
|
||||
f0, uv = f0_predictor.compute_f0_uv(wav.numpy())
|
||||
if f0_filter and sum(f0) == 0:
|
||||
raise RuntimeError("No voice detected")
|
||||
f0 = Tensor(f0.astype(np.float32)).float()
|
||||
f0 = (f0 * 2 ** (tran / 12)).unsqueeze(0)
|
||||
uv = Tensor(uv.astype(np.float32)).float().unsqueeze(0)
|
||||
wav16k = sinc_interp_resample(wav[None, :], target_sample, 16000)[0]
|
||||
return wav16k.realize(), f0.realize(), uv.realize()
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -10,96 +10,108 @@ from tinygrad.tensor import Tensor
|
|||
from extra.datasets import fetch_cifar
|
||||
from extra.models.efficientnet import EfficientNet
|
||||
|
||||
class TinyConvNet:
|
||||
def __init__(self, classes=10):
|
||||
conv = 3
|
||||
inter_chan, out_chan = 8, 16 # for speed
|
||||
self.c1 = Tensor.uniform(inter_chan,3,conv,conv)
|
||||
self.c2 = Tensor.uniform(out_chan,inter_chan,conv,conv)
|
||||
self.l1 = Tensor.uniform(out_chan*6*6, classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1)
|
||||
class TinyConvNet:
|
||||
def __init__(self, classes=10):
|
||||
conv = 3
|
||||
inter_chan, out_chan = 8, 16 # for speed
|
||||
self.c1 = Tensor.uniform(inter_chan, 3, conv, conv)
|
||||
self.c2 = Tensor.uniform(out_chan, inter_chan, conv, conv)
|
||||
self.l1 = Tensor.uniform(out_chan * 6 * 6, classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
IMAGENET = getenv("IMAGENET")
|
||||
classes = 1000 if IMAGENET else 10
|
||||
IMAGENET = getenv("IMAGENET")
|
||||
classes = 1000 if IMAGENET else 10
|
||||
|
||||
TINY = getenv("TINY")
|
||||
TRANSFER = getenv("TRANSFER")
|
||||
if TINY:
|
||||
model = TinyConvNet(classes)
|
||||
elif TRANSFER:
|
||||
model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
|
||||
model.load_from_pretrained()
|
||||
else:
|
||||
model = EfficientNet(getenv("NUM", 0), classes, has_se=False)
|
||||
TINY = getenv("TINY")
|
||||
TRANSFER = getenv("TRANSFER")
|
||||
if TINY:
|
||||
model = TinyConvNet(classes)
|
||||
elif TRANSFER:
|
||||
model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
|
||||
model.load_from_pretrained()
|
||||
else:
|
||||
model = EfficientNet(getenv("NUM", 0), classes, has_se=False)
|
||||
|
||||
parameters = get_parameters(model)
|
||||
print("parameter count", len(parameters))
|
||||
optimizer = optim.Adam(parameters, lr=0.001)
|
||||
parameters = get_parameters(model)
|
||||
print("parameter count", len(parameters))
|
||||
optimizer = optim.Adam(parameters, lr=0.001)
|
||||
|
||||
BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
|
||||
print(f"training with batch size {BS} for {steps} steps")
|
||||
BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
|
||||
print(f"training with batch size {BS} for {steps} steps")
|
||||
|
||||
if IMAGENET:
|
||||
from extra.datasets.imagenet import fetch_batch
|
||||
def loader(q):
|
||||
while 1:
|
||||
try:
|
||||
q.put(fetch_batch(BS))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
q = Queue(16)
|
||||
for i in range(2):
|
||||
p = Process(target=loader, args=(q,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
else:
|
||||
X_train, Y_train, _, _ = fetch_cifar()
|
||||
X_train = X_train.reshape((-1, 3, 32, 32))
|
||||
Y_train = Y_train.reshape((-1,))
|
||||
if IMAGENET:
|
||||
from extra.datasets.imagenet import fetch_batch
|
||||
|
||||
with Tensor.train():
|
||||
for i in (t := trange(steps)):
|
||||
if IMAGENET:
|
||||
X, Y = q.get(True)
|
||||
else:
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
|
||||
def loader(q):
|
||||
while 1:
|
||||
try:
|
||||
q.put(fetch_batch(BS))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
|
||||
fp_time = (time.time()-st)*1000.0
|
||||
q = Queue(16)
|
||||
for i in range(2):
|
||||
p = Process(target=loader, args=(q,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
else:
|
||||
X_train, Y_train, _, _ = fetch_cifar()
|
||||
X_train = X_train.reshape((-1, 3, 32, 32))
|
||||
Y_train = Y_train.reshape((-1,))
|
||||
|
||||
y = np.zeros((BS,classes), np.float32)
|
||||
y[range(y.shape[0]),Y] = -classes
|
||||
y = Tensor(y, requires_grad=False)
|
||||
loss = out.log_softmax().mul(y).mean()
|
||||
with Tensor.train():
|
||||
for i in (t := trange(steps)):
|
||||
if IMAGENET:
|
||||
X, Y = q.get(True)
|
||||
else:
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
|
||||
|
||||
optimizer.zero_grad()
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
|
||||
fp_time = (time.time() - st) * 1000.0
|
||||
|
||||
st = time.time()
|
||||
loss.backward()
|
||||
bp_time = (time.time()-st)*1000.0
|
||||
y = np.zeros((BS, classes), np.float32)
|
||||
y[range(y.shape[0]), Y] = -classes
|
||||
y = Tensor(y, requires_grad=False)
|
||||
loss = out.log_softmax().mul(y).mean()
|
||||
|
||||
st = time.time()
|
||||
optimizer.step()
|
||||
opt_time = (time.time()-st)*1000.0
|
||||
optimizer.zero_grad()
|
||||
|
||||
st = time.time()
|
||||
loss = loss.numpy()
|
||||
cat = out.argmax(axis=1).numpy()
|
||||
accuracy = (cat == Y).mean()
|
||||
finish_time = (time.time()-st)*1000.0
|
||||
st = time.time()
|
||||
loss.backward()
|
||||
bp_time = (time.time() - st) * 1000.0
|
||||
|
||||
# printing
|
||||
t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
|
||||
(loss, accuracy,
|
||||
fp_time, bp_time, opt_time, finish_time,
|
||||
fp_time + bp_time + opt_time + finish_time))
|
||||
st = time.time()
|
||||
optimizer.step()
|
||||
opt_time = (time.time() - st) * 1000.0
|
||||
|
||||
del out, y, loss
|
||||
st = time.time()
|
||||
loss = loss.numpy()
|
||||
cat = out.argmax(axis=1).numpy()
|
||||
accuracy = (cat == Y).mean()
|
||||
finish_time = (time.time() - st) * 1000.0
|
||||
|
||||
# printing
|
||||
t.set_description(
|
||||
"loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f"
|
||||
% (
|
||||
loss,
|
||||
accuracy,
|
||||
fp_time,
|
||||
bp_time,
|
||||
opt_time,
|
||||
finish_time,
|
||||
fp_time + bp_time + opt_time + finish_time,
|
||||
)
|
||||
)
|
||||
|
||||
del out, y, loss
|
||||
|
|
|
@ -11,35 +11,38 @@ from extra.datasets import fetch_mnist
|
|||
|
||||
|
||||
class ComposeTransforms:
|
||||
def __init__(self, trans):
|
||||
self.trans = trans
|
||||
def __init__(self, trans):
|
||||
self.trans = trans
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.trans:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.trans:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
if __name__ == "__main__":
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
classes = 10
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
classes = 10
|
||||
|
||||
TRANSFER = getenv('TRANSFER')
|
||||
model = ResNet(getenv('NUM', 18), num_classes=classes)
|
||||
if TRANSFER:
|
||||
model.load_from_pretrained()
|
||||
TRANSFER = getenv("TRANSFER")
|
||||
model = ResNet(getenv("NUM", 18), num_classes=classes)
|
||||
if TRANSFER:
|
||||
model.load_from_pretrained()
|
||||
|
||||
lr = 5e-3
|
||||
transform = ComposeTransforms([
|
||||
lambda x: [Image.fromarray(xx, mode='L').resize((64, 64)) for xx in x],
|
||||
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
|
||||
lambda x: x / 255.0,
|
||||
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
||||
])
|
||||
for _ in range(5):
|
||||
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
||||
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
||||
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.7f}')
|
||||
lr = 5e-3
|
||||
transform = ComposeTransforms(
|
||||
[
|
||||
lambda x: [Image.fromarray(xx, mode="L").resize((64, 64)) for xx in x],
|
||||
lambda x: np.stack([np.asarray(xx) for xx in x], 0),
|
||||
lambda x: x / 255.0,
|
||||
lambda x: np.tile(np.expand_dims(x, 1), (1, 3, 1, 1)).astype(np.float32),
|
||||
]
|
||||
)
|
||||
for _ in range(5):
|
||||
optimizer = optim.SGD(get_parameters(model), lr=lr, momentum=0.9)
|
||||
train(model, X_train, Y_train, optimizer, 100, BS=32, transform=transform)
|
||||
evaluate(model, X_test, Y_test, num_classes=classes, transform=transform)
|
||||
lr /= 1.2
|
||||
print(f"reducing lr to {lr:.7f}")
|
||||
|
|
|
@ -7,36 +7,49 @@ from tinygrad.nn.optim import Adam
|
|||
from extra.training import train, evaluate
|
||||
from extra.models.transformer import Transformer
|
||||
|
||||
|
||||
# dataset idea from https://github.com/karpathy/minGPT/blob/master/projects/adder/adder.py
|
||||
def make_dataset():
|
||||
ds = []
|
||||
for i in range(100):
|
||||
for j in range(100):
|
||||
s = i+j
|
||||
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
|
||||
random.shuffle(ds)
|
||||
ds = np.array(ds).astype(np.float32)
|
||||
ds_X = ds[:, 0:6]
|
||||
ds_Y = np.copy(ds[:, 1:])
|
||||
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
|
||||
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||
ds = []
|
||||
for i in range(100):
|
||||
for j in range(100):
|
||||
s = i + j
|
||||
ds.append(
|
||||
[i // 10, i % 10, j // 10, j % 10, s // 100, (s // 10) % 10, s % 10]
|
||||
)
|
||||
random.shuffle(ds)
|
||||
ds = np.array(ds).astype(np.float32)
|
||||
ds_X = ds[:, 0:6]
|
||||
ds_Y = np.copy(ds[:, 1:])
|
||||
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
|
||||
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4, 32)
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
lr = 0.003
|
||||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.4f}')
|
||||
if acc > 0.998:
|
||||
wrong=0
|
||||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong+=1
|
||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
||||
model = Transformer(10, 6, 2, 128, 4, 32)
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
lr = 0.003
|
||||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
acc, Y_test_preds = evaluate(
|
||||
model, X_test, Y_test, num_classes=10, return_predict=True
|
||||
)
|
||||
lr /= 1.2
|
||||
print(f"reducing lr to {lr:.4f}")
|
||||
if acc > 0.998:
|
||||
wrong = 0
|
||||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong += 1
|
||||
a, b, c, x = (
|
||||
X_test[k, :2],
|
||||
X_test[k, 2:4],
|
||||
Y_test[k, -3:],
|
||||
Y_test_preds[k, -3:],
|
||||
)
|
||||
print(
|
||||
f"{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})"
|
||||
)
|
||||
print(f"Wrong predictions: {wrong}, acc = {acc:.4f}")
|
||||
|
|
463
examples/vgg7.py
463
examples/vgg7.py
|
@ -12,251 +12,276 @@ from examples.vgg7_helpers.waifu2x import image_load, image_save, Vgg7
|
|||
# amount of context erased by model
|
||||
CONTEXT = 7
|
||||
|
||||
|
||||
def get_sample_count(samples_dir):
|
||||
try:
|
||||
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
|
||||
v = samples_dir_count_file.readline()
|
||||
samples_dir_count_file.close()
|
||||
return int(v)
|
||||
except:
|
||||
return 0
|
||||
try:
|
||||
samples_dir_count_file = open(samples_dir + "/sample_count.txt", "r")
|
||||
v = samples_dir_count_file.readline()
|
||||
samples_dir_count_file.close()
|
||||
return int(v)
|
||||
except:
|
||||
return 0
|
||||
|
||||
|
||||
def set_sample_count(samples_dir, sc):
|
||||
with open(samples_dir + "/sample_count.txt", "w") as file:
|
||||
file.write(str(sc) + "\n")
|
||||
with open(samples_dir + "/sample_count.txt", "w") as file:
|
||||
file.write(str(sc) + "\n")
|
||||
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print("python3 -m examples.vgg7 import MODELJSON MODEL")
|
||||
print(" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json")
|
||||
print(" into a safetensors file")
|
||||
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
|
||||
print(" *this format is used by most other commands in this program*")
|
||||
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
|
||||
print(" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors")
|
||||
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
|
||||
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
|
||||
print(" output image has 7 pixels removed on all edges")
|
||||
print(" do not run on large images, will have *hilarious* RAM use")
|
||||
print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT")
|
||||
print(" does the 'whole thing' (padding, tiling)")
|
||||
print(" safe for large images, etc.")
|
||||
print("python3 -m examples.vgg7 new MODEL")
|
||||
print(" creates a new model (experimental)")
|
||||
print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE")
|
||||
print(" trains a model (experimental)")
|
||||
print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
|
||||
print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
|
||||
print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
|
||||
print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
|
||||
print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
|
||||
print(" my_samples/0b.png is the first original image)")
|
||||
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
|
||||
print(" won't pad or tile, so keep image sizes sane")
|
||||
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
|
||||
print(" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training")
|
||||
print(" maintains/creates samples_count.txt automatically")
|
||||
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
|
||||
sys.exit(1)
|
||||
print("python3 -m examples.vgg7 import MODELJSON MODEL")
|
||||
print(
|
||||
" imports a waifu2x JSON vgg_7 model, i.e. waifu2x/models/vgg_7/art/scale2.0x_model.json"
|
||||
)
|
||||
print(" into a safetensors file")
|
||||
print(" weight tensors are ordered in tinygrad/ncnn form, as so: (outC,inC,H,W)")
|
||||
print(" *this format is used by most other commands in this program*")
|
||||
print("python3 -m examples.vgg7 import_kinne MODEL_KINNE MODEL_SAFETENSORS")
|
||||
print(
|
||||
" imports a model in 'KINNE' format (raw floats: used by older versions of this example) into safetensors"
|
||||
)
|
||||
print("python3 -m examples.vgg7 execute MODEL IMG_IN IMG_OUT")
|
||||
print(" given an already-nearest-neighbour-scaled image, runs vgg7 on it")
|
||||
print(" output image has 7 pixels removed on all edges")
|
||||
print(" do not run on large images, will have *hilarious* RAM use")
|
||||
print("python3 -m examples.vgg7 execute_full MODEL IMG_IN IMG_OUT")
|
||||
print(" does the 'whole thing' (padding, tiling)")
|
||||
print(" safe for large images, etc.")
|
||||
print("python3 -m examples.vgg7 new MODEL")
|
||||
print(" creates a new model (experimental)")
|
||||
print("python3 -m examples.vgg7 train MODEL SAMPLES_DIR ROUNDS ROUNDS_SAVE")
|
||||
print(" trains a model (experimental)")
|
||||
print(" (how experimental? well, every time I tried it, it flooded w/ NaNs)")
|
||||
print(" note: ROUNDS < 0 means 'forever'. ROUNDS_SAVE <= 0 is not a good idea.")
|
||||
print(" expects roughly execute's input as SAMPLES_DIR/IDXa.png")
|
||||
print(" expects roughly execute's output as SAMPLES_DIR/IDXb.png")
|
||||
print(" (i.e. my_samples/0a.png is the first pre-nearest-scaled image,")
|
||||
print(" my_samples/0b.png is the first original image)")
|
||||
print(" in addition, SAMPLES_DIR/samples_count.txt indicates sample count")
|
||||
print(" won't pad or tile, so keep image sizes sane")
|
||||
print("python3 -m examples.vgg7 samplify IMG_A IMG_B SAMPLES_DIR SIZE")
|
||||
print(
|
||||
" creates overlapping micropatches (SIZExSIZE w/ 7-pixel border) for training"
|
||||
)
|
||||
print(" maintains/creates samples_count.txt automatically")
|
||||
print(" unlike training, IMG_A must be exactly half the size of IMG_B")
|
||||
sys.exit(1)
|
||||
|
||||
cmd = sys.argv[1]
|
||||
vgg7 = Vgg7()
|
||||
|
||||
|
||||
def nansbane(p):
|
||||
if numpy.isnan(numpy.min(p.numpy())):
|
||||
raise Exception("A NaN in the model has been detected. This model will not be interacted with to prevent further damage.")
|
||||
if numpy.isnan(numpy.min(p.numpy())):
|
||||
raise Exception(
|
||||
"A NaN in the model has been detected. This model will not be interacted with to prevent further damage."
|
||||
)
|
||||
|
||||
|
||||
def load_and_save(path, save):
|
||||
if save:
|
||||
for v in vgg7.get_parameters():
|
||||
nansbane(v)
|
||||
st = get_state_dict(vgg7)
|
||||
safe_save(st, path)
|
||||
else:
|
||||
st = safe_load(path)
|
||||
load_state_dict(vgg7, st)
|
||||
for v in vgg7.get_parameters():
|
||||
nansbane(v)
|
||||
if save:
|
||||
for v in vgg7.get_parameters():
|
||||
nansbane(v)
|
||||
st = get_state_dict(vgg7)
|
||||
safe_save(st, path)
|
||||
else:
|
||||
st = safe_load(path)
|
||||
load_state_dict(vgg7, st)
|
||||
for v in vgg7.get_parameters():
|
||||
nansbane(v)
|
||||
|
||||
|
||||
if cmd == "import":
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
|
||||
vgg7.load_waifu2x_json(json.load(open(src, "rb")))
|
||||
vgg7.load_waifu2x_json(json.load(open(src, "rb")))
|
||||
|
||||
load_and_save(model, True)
|
||||
elif cmd == "import_kinne":
|
||||
# tinygrad wasn't doing safetensors when this example was written
|
||||
# it's possible someone might have a model around using the resulting interim format
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
|
||||
index = 0
|
||||
for t in vgg7.get_parameters():
|
||||
fn = src + "/snoop_bin_" + str(index) + ".bin"
|
||||
t.assign(Tensor(numpy.fromfile(fn, "<f4")).reshape(shape=t.shape))
|
||||
index += 1
|
||||
|
||||
load_and_save(model, True)
|
||||
elif cmd == "execute":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
|
||||
elif cmd == "execute_full":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
|
||||
elif cmd == "new":
|
||||
model = sys.argv[2]
|
||||
|
||||
load_and_save(model, True)
|
||||
elif cmd == "train":
|
||||
model = sys.argv[2]
|
||||
samples_base = sys.argv[3]
|
||||
samples_count = get_sample_count(samples_base)
|
||||
rounds = int(sys.argv[4])
|
||||
rounds_per_save = int(sys.argv[5])
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
# Initialize sample probabilities.
|
||||
# This is used to try and get the network to focus on "interesting" samples,
|
||||
# which works nicely with the microsample system.
|
||||
sample_probs = None
|
||||
sample_probs_path = model + "_sample_probs.bin"
|
||||
try:
|
||||
# try to read...
|
||||
sample_probs = numpy.fromfile(sample_probs_path, "<f8")
|
||||
if sample_probs.shape[0] != samples_count:
|
||||
print("sample probs size != sample count - initializing")
|
||||
sample_probs = None
|
||||
except:
|
||||
# it's fine
|
||||
print("sample probs could not be loaded - initializing")
|
||||
|
||||
if sample_probs is None:
|
||||
# This stupidly high amount is used to force an initial pass over all samples
|
||||
sample_probs = numpy.ones(samples_count) * 1000
|
||||
|
||||
print("Training...")
|
||||
# Adam has a tendency to destroy the state of the network when restarted
|
||||
# Plus it's slower
|
||||
optim = SGD(vgg7.get_parameters())
|
||||
|
||||
rnum = 0
|
||||
while True:
|
||||
# The way the -1 option works is that rnum is never -1.
|
||||
if rnum == rounds:
|
||||
break
|
||||
|
||||
sample_idx = 0
|
||||
try:
|
||||
sample_idx = numpy.random.choice(samples_count, p = sample_probs / sample_probs.sum())
|
||||
except:
|
||||
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
||||
sample_idx = random.randint(0, samples_count - 1)
|
||||
|
||||
x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
|
||||
y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
|
||||
|
||||
sample_x = Tensor(x_img, requires_grad = False)
|
||||
sample_y = Tensor(y_img, requires_grad = False)
|
||||
|
||||
# magic code roughly from readme example
|
||||
# An explanation, in case anyone else has to go down this path:
|
||||
# This runs the actual network normally
|
||||
out = vgg7.forward(sample_x)
|
||||
# Subtraction determines error here (as this is an image, not classification).
|
||||
# *Abs is the important bit* - at least for me, anyway.
|
||||
# The training process seeks to minimize this 'loss' value.
|
||||
# Minimization of loss *tends towards negative infinity*, so without the abs,
|
||||
# or without an implicit abs (the mul in the README),
|
||||
# loss will always go haywire in one direction or another.
|
||||
# Mean determines how errors are treated.
|
||||
# Do not use Sum. I tried that. It worked while I was using 1x1 patches...
|
||||
# Then it went exponential.
|
||||
# Also, Mean goes *after* abs. I realize this should have been obvious to me.
|
||||
loss = sample_y.sub(out).abs().mean()
|
||||
# This is the bit where tinygrad works backward from the loss
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
# And this updates the parameters
|
||||
optim.step()
|
||||
|
||||
# warning: used by sample probability adjuster
|
||||
loss_indicator = loss.max().numpy()
|
||||
print("Round " + str(rnum) + " : " + str(loss_indicator))
|
||||
|
||||
if (rnum % rounds_per_save) == 0:
|
||||
print("Saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
|
||||
# Update round state
|
||||
# Number
|
||||
rnum = rnum + 1
|
||||
# Probability management
|
||||
# there must always be a probability, no matter how slim, even if loss goes to 0
|
||||
sample_probs[sample_idx] = max(loss_indicator, 1.e-10)
|
||||
|
||||
# if we were told to save every round, we already saved
|
||||
if rounds_per_save != 1:
|
||||
print("Done with all rounds, saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
elif cmd == "import_kinne":
|
||||
# tinygrad wasn't doing safetensors when this example was written
|
||||
# it's possible someone might have a model around using the resulting interim format
|
||||
src = sys.argv[2]
|
||||
model = sys.argv[3]
|
||||
|
||||
index = 0
|
||||
for t in vgg7.get_parameters():
|
||||
fn = src + "/snoop_bin_" + str(index) + ".bin"
|
||||
t.assign(Tensor(numpy.fromfile(fn, "<f4")).reshape(shape=t.shape))
|
||||
index += 1
|
||||
|
||||
load_and_save(model, True)
|
||||
elif cmd == "execute":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
image_save(out_file, vgg7.forward(Tensor(image_load(in_file))).numpy())
|
||||
elif cmd == "execute_full":
|
||||
model = sys.argv[2]
|
||||
in_file = sys.argv[3]
|
||||
out_file = sys.argv[4]
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
image_save(out_file, vgg7.forward_tiled(image_load(in_file), 156))
|
||||
elif cmd == "new":
|
||||
model = sys.argv[2]
|
||||
|
||||
load_and_save(model, True)
|
||||
elif cmd == "train":
|
||||
model = sys.argv[2]
|
||||
samples_base = sys.argv[3]
|
||||
samples_count = get_sample_count(samples_base)
|
||||
rounds = int(sys.argv[4])
|
||||
rounds_per_save = int(sys.argv[5])
|
||||
|
||||
load_and_save(model, False)
|
||||
|
||||
# Initialize sample probabilities.
|
||||
# This is used to try and get the network to focus on "interesting" samples,
|
||||
# which works nicely with the microsample system.
|
||||
sample_probs = None
|
||||
sample_probs_path = model + "_sample_probs.bin"
|
||||
try:
|
||||
# try to read...
|
||||
sample_probs = numpy.fromfile(sample_probs_path, "<f8")
|
||||
if sample_probs.shape[0] != samples_count:
|
||||
print("sample probs size != sample count - initializing")
|
||||
sample_probs = None
|
||||
except:
|
||||
# it's fine
|
||||
print("sample probs could not be loaded - initializing")
|
||||
|
||||
if sample_probs is None:
|
||||
# This stupidly high amount is used to force an initial pass over all samples
|
||||
sample_probs = numpy.ones(samples_count) * 1000
|
||||
|
||||
print("Training...")
|
||||
# Adam has a tendency to destroy the state of the network when restarted
|
||||
# Plus it's slower
|
||||
optim = SGD(vgg7.get_parameters())
|
||||
|
||||
rnum = 0
|
||||
while True:
|
||||
# The way the -1 option works is that rnum is never -1.
|
||||
if rnum == rounds:
|
||||
break
|
||||
|
||||
sample_idx = 0
|
||||
try:
|
||||
sample_idx = numpy.random.choice(
|
||||
samples_count, p=sample_probs / sample_probs.sum()
|
||||
)
|
||||
except:
|
||||
print("exception occurred (PROBABLY value-probabilities-dont-sum-to-1)")
|
||||
sample_idx = random.randint(0, samples_count - 1)
|
||||
|
||||
x_img = image_load(samples_base + "/" + str(sample_idx) + "a.png")
|
||||
y_img = image_load(samples_base + "/" + str(sample_idx) + "b.png")
|
||||
|
||||
sample_x = Tensor(x_img, requires_grad=False)
|
||||
sample_y = Tensor(y_img, requires_grad=False)
|
||||
|
||||
# magic code roughly from readme example
|
||||
# An explanation, in case anyone else has to go down this path:
|
||||
# This runs the actual network normally
|
||||
out = vgg7.forward(sample_x)
|
||||
# Subtraction determines error here (as this is an image, not classification).
|
||||
# *Abs is the important bit* - at least for me, anyway.
|
||||
# The training process seeks to minimize this 'loss' value.
|
||||
# Minimization of loss *tends towards negative infinity*, so without the abs,
|
||||
# or without an implicit abs (the mul in the README),
|
||||
# loss will always go haywire in one direction or another.
|
||||
# Mean determines how errors are treated.
|
||||
# Do not use Sum. I tried that. It worked while I was using 1x1 patches...
|
||||
# Then it went exponential.
|
||||
# Also, Mean goes *after* abs. I realize this should have been obvious to me.
|
||||
loss = sample_y.sub(out).abs().mean()
|
||||
# This is the bit where tinygrad works backward from the loss
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
# And this updates the parameters
|
||||
optim.step()
|
||||
|
||||
# warning: used by sample probability adjuster
|
||||
loss_indicator = loss.max().numpy()
|
||||
print("Round " + str(rnum) + " : " + str(loss_indicator))
|
||||
|
||||
if (rnum % rounds_per_save) == 0:
|
||||
print("Saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
|
||||
# Update round state
|
||||
# Number
|
||||
rnum = rnum + 1
|
||||
# Probability management
|
||||
# there must always be a probability, no matter how slim, even if loss goes to 0
|
||||
sample_probs[sample_idx] = max(loss_indicator, 1.0e-10)
|
||||
|
||||
# if we were told to save every round, we already saved
|
||||
if rounds_per_save != 1:
|
||||
print("Done with all rounds, saving")
|
||||
load_and_save(model, True)
|
||||
sample_probs.astype("<f8", "C").tofile(sample_probs_path)
|
||||
|
||||
elif cmd == "samplify":
|
||||
a_img = sys.argv[2]
|
||||
b_img = sys.argv[3]
|
||||
samples_base = sys.argv[4]
|
||||
sample_size = int(sys.argv[5])
|
||||
samples_count = get_sample_count(samples_base)
|
||||
a_img = sys.argv[2]
|
||||
b_img = sys.argv[3]
|
||||
samples_base = sys.argv[4]
|
||||
sample_size = int(sys.argv[5])
|
||||
samples_count = get_sample_count(samples_base)
|
||||
|
||||
# This bit is interesting because it actually does some work.
|
||||
# Not much, but some work.
|
||||
a_img = image_load(a_img)
|
||||
b_img = image_load(b_img)
|
||||
# This bit is interesting because it actually does some work.
|
||||
# Not much, but some work.
|
||||
a_img = image_load(a_img)
|
||||
b_img = image_load(b_img)
|
||||
|
||||
# as with the main library body,
|
||||
# Y X order is used here
|
||||
# as with the main library body,
|
||||
# Y X order is used here
|
||||
|
||||
# assertion before pre-upscaling is performed
|
||||
assert a_img.shape[2] == (b_img.shape[2] // 2)
|
||||
assert a_img.shape[3] == (b_img.shape[3] // 2)
|
||||
# assertion before pre-upscaling is performed
|
||||
assert a_img.shape[2] == (b_img.shape[2] // 2)
|
||||
assert a_img.shape[3] == (b_img.shape[3] // 2)
|
||||
|
||||
# pre-upscaling - this matches the sizes (and coordinates)
|
||||
a_img = a_img.repeat(2, 2).repeat(2, 3)
|
||||
# pre-upscaling - this matches the sizes (and coordinates)
|
||||
a_img = a_img.repeat(2, 2).repeat(2, 3)
|
||||
|
||||
samples_added = 0
|
||||
samples_added = 0
|
||||
|
||||
# actual patch extraction
|
||||
for posy in range(CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size):
|
||||
for posx in range(CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size):
|
||||
# this is a viable patch location, add it
|
||||
# note the ranges here:
|
||||
# + there are always CONTEXT pixels *before* the point
|
||||
# + with no subtraction at the end, there'd already be a pixel *at* the point,
|
||||
# as ranges are exclusive
|
||||
# + additionally, there are sample_size - 1 additional sample pixels
|
||||
# + additionally, there are CONTEXT additional pixels
|
||||
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
|
||||
patch_x = a_img[:, :, posy - CONTEXT : posy + CONTEXT + sample_size, posx - CONTEXT : posx + CONTEXT + sample_size]
|
||||
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
|
||||
# actual patch extraction
|
||||
for posy in range(
|
||||
CONTEXT, b_img.shape[2] - (CONTEXT + sample_size - 1), sample_size
|
||||
):
|
||||
for posx in range(
|
||||
CONTEXT, b_img.shape[3] - (CONTEXT + sample_size - 1), sample_size
|
||||
):
|
||||
# this is a viable patch location, add it
|
||||
# note the ranges here:
|
||||
# + there are always CONTEXT pixels *before* the point
|
||||
# + with no subtraction at the end, there'd already be a pixel *at* the point,
|
||||
# as ranges are exclusive
|
||||
# + additionally, there are sample_size - 1 additional sample pixels
|
||||
# + additionally, there are CONTEXT additional pixels
|
||||
# + therefore there are CONTEXT + sample_size pixels *at & after* the point
|
||||
patch_x = a_img[
|
||||
:,
|
||||
:,
|
||||
posy - CONTEXT : posy + CONTEXT + sample_size,
|
||||
posx - CONTEXT : posx + CONTEXT + sample_size,
|
||||
]
|
||||
patch_y = b_img[:, :, posy : posy + sample_size, posx : posx + sample_size]
|
||||
|
||||
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
|
||||
image_save(f"{samples_base}/{str(samples_count)}b.png", patch_y)
|
||||
samples_count += 1
|
||||
samples_added += 1
|
||||
image_save(f"{samples_base}/{str(samples_count)}a.png", patch_x)
|
||||
image_save(f"{samples_base}/{str(samples_count)}b.png", patch_y)
|
||||
samples_count += 1
|
||||
samples_added += 1
|
||||
|
||||
print(f"Added {str(samples_added)} samples")
|
||||
set_sample_count(samples_base, samples_count)
|
||||
print(f"Added {str(samples_added)} samples")
|
||||
set_sample_count(samples_base, samples_count)
|
||||
|
||||
else:
|
||||
print("unknown command")
|
||||
print("unknown command")
|
||||
|
|
|
@ -11,183 +11,211 @@ from tinygrad.helpers import fetch
|
|||
# tinygrad convolution tensor input layout is (1,c,y,x) - and therefore the form for all images used in the project
|
||||
# tinygrad convolution tensor weight layout is (outC,inC,H,W) - this matches NCNN (and therefore KINNE), but not waifu2x json
|
||||
|
||||
|
||||
def image_load(path) -> numpy.ndarray:
|
||||
"""
|
||||
Loads an image in the shape expected by other functions in this module.
|
||||
Doesn't Tensor it, in case you need to do further work with it.
|
||||
"""
|
||||
# file
|
||||
na = numpy.array(Image.open(path))
|
||||
if na.shape[2] == 4:
|
||||
# RGBA -> RGB (covers opaque images with alpha channels)
|
||||
na = na[:,:,0:3]
|
||||
# fix shape
|
||||
na = numpy.moveaxis(na, [2,0,1], [0,1,2])
|
||||
# shape is now (3,h,w), add 1
|
||||
na = na.reshape(1,3,na.shape[1],na.shape[2])
|
||||
# change type
|
||||
na = na.astype("float32") / 255.0
|
||||
return na
|
||||
"""
|
||||
Loads an image in the shape expected by other functions in this module.
|
||||
Doesn't Tensor it, in case you need to do further work with it.
|
||||
"""
|
||||
# file
|
||||
na = numpy.array(Image.open(path))
|
||||
if na.shape[2] == 4:
|
||||
# RGBA -> RGB (covers opaque images with alpha channels)
|
||||
na = na[:, :, 0:3]
|
||||
# fix shape
|
||||
na = numpy.moveaxis(na, [2, 0, 1], [0, 1, 2])
|
||||
# shape is now (3,h,w), add 1
|
||||
na = na.reshape(1, 3, na.shape[1], na.shape[2])
|
||||
# change type
|
||||
na = na.astype("float32") / 255.0
|
||||
return na
|
||||
|
||||
|
||||
def image_save(path, na: numpy.ndarray):
|
||||
"""
|
||||
Saves an image of the shape expected by other functions in this module.
|
||||
However, note this expects a numpy array.
|
||||
"""
|
||||
# change type
|
||||
na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
|
||||
# shape is now (1,3,h,w), remove 1
|
||||
na = na.reshape(3,na.shape[2],na.shape[3])
|
||||
# fix shape
|
||||
na = numpy.moveaxis(na, [0,1,2], [2,0,1])
|
||||
# shape is now (h,w,3)
|
||||
# file
|
||||
Image.fromarray(na).save(path)
|
||||
"""
|
||||
Saves an image of the shape expected by other functions in this module.
|
||||
However, note this expects a numpy array.
|
||||
"""
|
||||
# change type
|
||||
na = numpy.fmax(numpy.fmin(na * 255.0, 255), 0).astype("uint8")
|
||||
# shape is now (1,3,h,w), remove 1
|
||||
na = na.reshape(3, na.shape[2], na.shape[3])
|
||||
# fix shape
|
||||
na = numpy.moveaxis(na, [0, 1, 2], [2, 0, 1])
|
||||
# shape is now (h,w,3)
|
||||
# file
|
||||
Image.fromarray(na).save(path)
|
||||
|
||||
|
||||
# The Model
|
||||
|
||||
|
||||
class Conv3x3Biased:
|
||||
"""
|
||||
A 3x3 convolution layer with some utility functions.
|
||||
"""
|
||||
def __init__(self, inC, outC, last = False):
|
||||
# The properties must be named as "W" and "b".
|
||||
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
|
||||
# though this cannot necessarily account for transposition and other such things.
|
||||
"""
|
||||
A 3x3 convolution layer with some utility functions.
|
||||
"""
|
||||
|
||||
# Massively overstate the weights to get them to be focused on,
|
||||
# since otherwise the biases overrule everything
|
||||
self.W = Tensor.uniform(outC, inC, 3, 3) * 16.0
|
||||
# Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is?
|
||||
# Values-wise, entirely different blatant cheat.
|
||||
# In most cases, use uniform bias, but tiny.
|
||||
# For the last layer, use just 0.5, constant.
|
||||
if last:
|
||||
self.b = Tensor.zeros(1, outC, 1, 1) + 0.5
|
||||
else:
|
||||
self.b = Tensor.uniform(1, outC, 1, 1)
|
||||
def __init__(self, inC, outC, last=False):
|
||||
# The properties must be named as "W" and "b".
|
||||
# This is in an attempt to try and be roughly compatible with https://github.com/FHPythonUtils/Waifu2x
|
||||
# though this cannot necessarily account for transposition and other such things.
|
||||
|
||||
def forward(self, x):
|
||||
# You might be thinking, "but what about padding?"
|
||||
# Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
|
||||
return x.conv2d(self.W).add(self.b)
|
||||
# Massively overstate the weights to get them to be focused on,
|
||||
# since otherwise the biases overrule everything
|
||||
self.W = Tensor.uniform(outC, inC, 3, 3) * 16.0
|
||||
# Layout-wise, blatant cheat, but serious_mnist does it. I'd guess channels either have to have a size of 1 or whatever the target is?
|
||||
# Values-wise, entirely different blatant cheat.
|
||||
# In most cases, use uniform bias, but tiny.
|
||||
# For the last layer, use just 0.5, constant.
|
||||
if last:
|
||||
self.b = Tensor.zeros(1, outC, 1, 1) + 0.5
|
||||
else:
|
||||
self.b = Tensor.uniform(1, outC, 1, 1)
|
||||
|
||||
def get_parameters(self) -> list:
|
||||
return [self.W, self.b]
|
||||
def forward(self, x):
|
||||
# You might be thinking, "but what about padding?"
|
||||
# Answer: Tiling is used to stitch everything back together, though you could pad the image before providing it.
|
||||
return x.conv2d(self.W).add(self.b)
|
||||
|
||||
def get_parameters(self) -> list:
|
||||
return [self.W, self.b]
|
||||
|
||||
def load_waifu2x_json(self, layer: dict):
|
||||
# Weights in this file are outChannel,inChannel,X,Y.
|
||||
# Not outChannel,inChannel,Y,X.
|
||||
# Therefore, transpose it before assignment.
|
||||
# I have long since forgotten how I worked this out.
|
||||
self.W.assign(
|
||||
Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3)
|
||||
)
|
||||
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
|
||||
|
||||
def load_waifu2x_json(self, layer: dict):
|
||||
# Weights in this file are outChannel,inChannel,X,Y.
|
||||
# Not outChannel,inChannel,Y,X.
|
||||
# Therefore, transpose it before assignment.
|
||||
# I have long since forgotten how I worked this out.
|
||||
self.W.assign(Tensor(layer["weight"]).reshape(shape=self.W.shape).transpose(2, 3))
|
||||
self.b.assign(Tensor(layer["bias"]).reshape(shape=self.b.shape))
|
||||
|
||||
class Vgg7:
|
||||
"""
|
||||
The 'vgg7' waifu2x network.
|
||||
Lower quality and slower than even upconv7 (nevermind cunet), but is very easy to implement and test.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.conv1 = Conv3x3Biased(3, 32)
|
||||
self.conv2 = Conv3x3Biased(32, 32)
|
||||
self.conv3 = Conv3x3Biased(32, 64)
|
||||
self.conv4 = Conv3x3Biased(64, 64)
|
||||
self.conv5 = Conv3x3Biased(64, 128)
|
||||
self.conv6 = Conv3x3Biased(128, 128)
|
||||
self.conv7 = Conv3x3Biased(128, 3, True)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass: Actually runs the network.
|
||||
Input format: (1, 3, Y, X)
|
||||
Output format: (1, 3, Y - 14, X - 14)
|
||||
(the - 14 represents the 7-pixel context border that is lost)
|
||||
The 'vgg7' waifu2x network.
|
||||
Lower quality and slower than even upconv7 (nevermind cunet), but is very easy to implement and test.
|
||||
"""
|
||||
x = self.conv1.forward(x).leakyrelu(0.1)
|
||||
x = self.conv2.forward(x).leakyrelu(0.1)
|
||||
x = self.conv3.forward(x).leakyrelu(0.1)
|
||||
x = self.conv4.forward(x).leakyrelu(0.1)
|
||||
x = self.conv5.forward(x).leakyrelu(0.1)
|
||||
x = self.conv6.forward(x).leakyrelu(0.1)
|
||||
x = self.conv7.forward(x)
|
||||
return x
|
||||
|
||||
def get_parameters(self) -> list:
|
||||
return self.conv1.get_parameters() + self.conv2.get_parameters() + self.conv3.get_parameters() + self.conv4.get_parameters() + self.conv5.get_parameters() + self.conv6.get_parameters() + self.conv7.get_parameters()
|
||||
def __init__(self):
|
||||
self.conv1 = Conv3x3Biased(3, 32)
|
||||
self.conv2 = Conv3x3Biased(32, 32)
|
||||
self.conv3 = Conv3x3Biased(32, 64)
|
||||
self.conv4 = Conv3x3Biased(64, 64)
|
||||
self.conv5 = Conv3x3Biased(64, 128)
|
||||
self.conv6 = Conv3x3Biased(128, 128)
|
||||
self.conv7 = Conv3x3Biased(128, 3, True)
|
||||
|
||||
def load_from_pretrained(self, intent = "art", subtype = "scale2.0x"):
|
||||
"""
|
||||
Downloads a nagadomi/waifu2x JSON weight file and loads it.
|
||||
"""
|
||||
import json
|
||||
data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes())
|
||||
self.load_waifu2x_json(data)
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass: Actually runs the network.
|
||||
Input format: (1, 3, Y, X)
|
||||
Output format: (1, 3, Y - 14, X - 14)
|
||||
(the - 14 represents the 7-pixel context border that is lost)
|
||||
"""
|
||||
x = self.conv1.forward(x).leakyrelu(0.1)
|
||||
x = self.conv2.forward(x).leakyrelu(0.1)
|
||||
x = self.conv3.forward(x).leakyrelu(0.1)
|
||||
x = self.conv4.forward(x).leakyrelu(0.1)
|
||||
x = self.conv5.forward(x).leakyrelu(0.1)
|
||||
x = self.conv6.forward(x).leakyrelu(0.1)
|
||||
x = self.conv7.forward(x)
|
||||
return x
|
||||
|
||||
def load_waifu2x_json(self, data: list):
|
||||
"""
|
||||
Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
|
||||
data (passed in) is assumed to be the output of json.load or some similar on such a file
|
||||
"""
|
||||
self.conv1.load_waifu2x_json(data[0])
|
||||
self.conv2.load_waifu2x_json(data[1])
|
||||
self.conv3.load_waifu2x_json(data[2])
|
||||
self.conv4.load_waifu2x_json(data[3])
|
||||
self.conv5.load_waifu2x_json(data[4])
|
||||
self.conv6.load_waifu2x_json(data[5])
|
||||
self.conv7.load_waifu2x_json(data[6])
|
||||
def get_parameters(self) -> list:
|
||||
return (
|
||||
self.conv1.get_parameters()
|
||||
+ self.conv2.get_parameters()
|
||||
+ self.conv3.get_parameters()
|
||||
+ self.conv4.get_parameters()
|
||||
+ self.conv5.get_parameters()
|
||||
+ self.conv6.get_parameters()
|
||||
+ self.conv7.get_parameters()
|
||||
)
|
||||
|
||||
def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
|
||||
"""
|
||||
Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
|
||||
Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
|
||||
"""
|
||||
# Constant that only really gets repeated a ton here.
|
||||
context = 7
|
||||
context2 = context + context
|
||||
def load_from_pretrained(self, intent="art", subtype="scale2.0x"):
|
||||
"""
|
||||
Downloads a nagadomi/waifu2x JSON weight file and loads it.
|
||||
"""
|
||||
import json
|
||||
|
||||
# Notably, numpy is used here because it makes this fine manipulation a lot simpler.
|
||||
# Scaling first - repeat on axis 2 and axis 3 (Y & X)
|
||||
image = image.repeat(2, 2).repeat(2, 3)
|
||||
data = json.loads(
|
||||
fetch(
|
||||
"https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/"
|
||||
+ intent
|
||||
+ "/"
|
||||
+ subtype
|
||||
+ "_model.json"
|
||||
).read_bytes()
|
||||
)
|
||||
self.load_waifu2x_json(data)
|
||||
|
||||
# Resulting image buffer. This is made before the input is padded,
|
||||
# since the input has the padded shape right now.
|
||||
image_out = numpy.zeros(image.shape)
|
||||
def load_waifu2x_json(self, data: list):
|
||||
"""
|
||||
Loads weights from one of the waifu2x JSON files, i.e. waifu2x/models/vgg_7/art/noise0_model.json
|
||||
data (passed in) is assumed to be the output of json.load or some similar on such a file
|
||||
"""
|
||||
self.conv1.load_waifu2x_json(data[0])
|
||||
self.conv2.load_waifu2x_json(data[1])
|
||||
self.conv3.load_waifu2x_json(data[2])
|
||||
self.conv4.load_waifu2x_json(data[3])
|
||||
self.conv5.load_waifu2x_json(data[4])
|
||||
self.conv6.load_waifu2x_json(data[5])
|
||||
self.conv7.load_waifu2x_json(data[6])
|
||||
|
||||
# Padding next. Note that this padding is done on the whole image.
|
||||
# Padding the tiles would lose critical context, cause seams, etc.
|
||||
image = numpy.pad(image, [[0, 0], [0, 0], [context, context], [context, context]], mode = "edge")
|
||||
def forward_tiled(self, image: numpy.ndarray, tile_size: int) -> numpy.ndarray:
|
||||
"""
|
||||
Given an ndarray image as loaded by image_load (NOT a tensor), scales it, pads it, splits it up, forwards the pieces, and reconstitutes it.
|
||||
Note that you really shouldn't try to run anything not (1, 3, *, *) through this.
|
||||
"""
|
||||
# Constant that only really gets repeated a ton here.
|
||||
context = 7
|
||||
context2 = context + context
|
||||
|
||||
# Now for tiling.
|
||||
# The output tile size is the usable output from an input tile (tile_size).
|
||||
# As such, the tiles overlap.
|
||||
out_tile_size = tile_size - context2
|
||||
for out_y in range(0, image_out.shape[2], out_tile_size):
|
||||
for out_x in range(0, image_out.shape[3], out_tile_size):
|
||||
# Input is sourced from the same coordinates, but some stuff ought to be
|
||||
# noted here for future reference:
|
||||
# + out_x/y's equivalent position w/ the padding is out_x + context.
|
||||
# + The output, however, is without context. Input needs context.
|
||||
# + Therefore, the input rectangle is expanded on all sides by context.
|
||||
# + Therefore, the input position has the context subtracted again.
|
||||
# + Therefore:
|
||||
in_y = out_y
|
||||
in_x = out_x
|
||||
# not shown: in_w/in_h = tile_size (as opposed to out_tile_size)
|
||||
# Extract tile.
|
||||
# Note that numpy will auto-crop this at the bottom-right.
|
||||
# This will never be a problem, as tiles are specifically chosen within the padded section.
|
||||
tile = image[:, :, in_y:in_y + tile_size, in_x:in_x + tile_size]
|
||||
# Extracted tile dimensions -> output dimensions
|
||||
# This is important because of said cropping, otherwise it'd be interior tile size.
|
||||
out_h = tile.shape[2] - context2
|
||||
out_w = tile.shape[3] - context2
|
||||
# Process tile.
|
||||
tile_t = Tensor(tile)
|
||||
tile_fwd_t = self.forward(tile_t)
|
||||
# Replace tile.
|
||||
image_out[:, :, out_y:out_y + out_h, out_x:out_x + out_w] = tile_fwd_t.numpy()
|
||||
# Notably, numpy is used here because it makes this fine manipulation a lot simpler.
|
||||
# Scaling first - repeat on axis 2 and axis 3 (Y & X)
|
||||
image = image.repeat(2, 2).repeat(2, 3)
|
||||
|
||||
return image_out
|
||||
# Resulting image buffer. This is made before the input is padded,
|
||||
# since the input has the padded shape right now.
|
||||
image_out = numpy.zeros(image.shape)
|
||||
|
||||
# Padding next. Note that this padding is done on the whole image.
|
||||
# Padding the tiles would lose critical context, cause seams, etc.
|
||||
image = numpy.pad(
|
||||
image, [[0, 0], [0, 0], [context, context], [context, context]], mode="edge"
|
||||
)
|
||||
|
||||
# Now for tiling.
|
||||
# The output tile size is the usable output from an input tile (tile_size).
|
||||
# As such, the tiles overlap.
|
||||
out_tile_size = tile_size - context2
|
||||
for out_y in range(0, image_out.shape[2], out_tile_size):
|
||||
for out_x in range(0, image_out.shape[3], out_tile_size):
|
||||
# Input is sourced from the same coordinates, but some stuff ought to be
|
||||
# noted here for future reference:
|
||||
# + out_x/y's equivalent position w/ the padding is out_x + context.
|
||||
# + The output, however, is without context. Input needs context.
|
||||
# + Therefore, the input rectangle is expanded on all sides by context.
|
||||
# + Therefore, the input position has the context subtracted again.
|
||||
# + Therefore:
|
||||
in_y = out_y
|
||||
in_x = out_x
|
||||
# not shown: in_w/in_h = tile_size (as opposed to out_tile_size)
|
||||
# Extract tile.
|
||||
# Note that numpy will auto-crop this at the bottom-right.
|
||||
# This will never be a problem, as tiles are specifically chosen within the padded section.
|
||||
tile = image[:, :, in_y : in_y + tile_size, in_x : in_x + tile_size]
|
||||
# Extracted tile dimensions -> output dimensions
|
||||
# This is important because of said cropping, otherwise it'd be interior tile size.
|
||||
out_h = tile.shape[2] - context2
|
||||
out_w = tile.shape[3] - context2
|
||||
# Process tile.
|
||||
tile_t = Tensor(tile)
|
||||
tile_fwd_t = self.forward(tile_t)
|
||||
# Replace tile.
|
||||
image_out[
|
||||
:, :, out_y : out_y + out_h, out_x : out_x + out_w
|
||||
] = tile_fwd_t.numpy()
|
||||
|
||||
return image_out
|
||||
|
|
|
@ -4,6 +4,7 @@ from PIL import Image
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv, fetch
|
||||
from extra.models.vit import ViT
|
||||
|
||||
"""
|
||||
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
|
||||
import tensorflow as tf
|
||||
|
@ -15,27 +16,33 @@ with tf.io.gfile.GFile(fn, "rb") as f:
|
|||
|
||||
Tensor.training = False
|
||||
if getenv("LARGE", 0) == 1:
|
||||
m = ViT(embed_dim=768, num_heads=12)
|
||||
m = ViT(embed_dim=768, num_heads=12)
|
||||
else:
|
||||
# tiny
|
||||
m = ViT(embed_dim=192, num_heads=3)
|
||||
# tiny
|
||||
m = ViT(embed_dim=192, num_heads=3)
|
||||
m.load_from_pretrained()
|
||||
|
||||
# category labels
|
||||
lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
|
||||
lbls = ast.literal_eval(
|
||||
fetch(
|
||||
"https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
|
||||
).read_text()
|
||||
)
|
||||
|
||||
#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
|
||||
# url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
|
||||
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
|
||||
|
||||
# junk
|
||||
img = Image.open(fetch(url))
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
|
||||
img = img.resize(
|
||||
(int(224 * max(aspect_ratio, 1.0)), int(224 * max(1.0 / aspect_ratio, 1.0)))
|
||||
)
|
||||
img = np.array(img)
|
||||
y0,x0=(np.asarray(img.shape)[:2]-224)//2
|
||||
img = img[y0:y0+224, x0:x0+224]
|
||||
img = np.moveaxis(img, [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
|
||||
y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
|
||||
img = img[y0 : y0 + 224, x0 : x0 + 224]
|
||||
img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
|
||||
img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
|
||||
img /= 255.0
|
||||
img -= 0.5
|
||||
img /= 0.5
|
||||
|
|
2581
examples/vits.py
2581
examples/vits.py
File diff suppressed because it is too large
Load Diff
|
@ -1,7 +1,13 @@
|
|||
import os
|
||||
from extra.export_model import compile_net, jit_model
|
||||
from examples.stable_diffusion import StableDiffusion
|
||||
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
|
||||
from tinygrad.nn.state import (
|
||||
get_state_dict,
|
||||
safe_save,
|
||||
safe_load_metadata,
|
||||
torch_load,
|
||||
load_state_dict,
|
||||
)
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad import Device
|
||||
from tinygrad.helpers import fetch
|
||||
|
@ -10,102 +16,174 @@ from pathlib import Path
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
||||
def convert_f32_to_f16(input_file, output_file):
|
||||
with open(input_file, 'rb') as f:
|
||||
metadata_length_bytes = f.read(8)
|
||||
metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
|
||||
metadata_json_bytes = f.read(metadata_length)
|
||||
float32_values = np.fromfile(f, dtype=np.float32)
|
||||
with open(input_file, "rb") as f:
|
||||
metadata_length_bytes = f.read(8)
|
||||
metadata_length = int.from_bytes(
|
||||
metadata_length_bytes, byteorder="little", signed=False
|
||||
)
|
||||
metadata_json_bytes = f.read(metadata_length)
|
||||
float32_values = np.fromfile(f, dtype=np.float32)
|
||||
|
||||
first_text_model_offset = 3772703308
|
||||
num_elements = int((first_text_model_offset)/4)
|
||||
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
||||
rest_float32_values = float32_values[num_elements:]
|
||||
first_text_model_offset = 3772703308
|
||||
num_elements = int((first_text_model_offset) / 4)
|
||||
front_float16_values = float32_values[:num_elements].astype(np.float16)
|
||||
rest_float32_values = float32_values[num_elements:]
|
||||
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(metadata_length_bytes)
|
||||
f.write(metadata_json_bytes)
|
||||
front_float16_values.tofile(f)
|
||||
rest_float32_values.tofile(f)
|
||||
|
||||
with open(output_file, 'wb') as f:
|
||||
f.write(metadata_length_bytes)
|
||||
f.write(metadata_json_bytes)
|
||||
front_float16_values.tofile(f)
|
||||
rest_float32_values.tofile(f)
|
||||
|
||||
def split_safetensor(fn):
|
||||
_, json_len, metadata = safe_load_metadata(fn)
|
||||
text_model_offset = 3772703308
|
||||
chunk_size = 536870912
|
||||
_, json_len, metadata = safe_load_metadata(fn)
|
||||
text_model_offset = 3772703308
|
||||
chunk_size = 536870912
|
||||
|
||||
for k in metadata:
|
||||
# safetensor is in fp16, except for text moel
|
||||
if (metadata[k]["data_offsets"][0] < text_model_offset):
|
||||
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
|
||||
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
|
||||
for k in metadata:
|
||||
# safetensor is in fp16, except for text moel
|
||||
if metadata[k]["data_offsets"][0] < text_model_offset:
|
||||
metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0] / 2)
|
||||
metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1] / 2)
|
||||
|
||||
last_offset = 0
|
||||
part_end_offsets = []
|
||||
last_offset = 0
|
||||
part_end_offsets = []
|
||||
|
||||
for k in metadata:
|
||||
offset = metadata[k]['data_offsets'][0]
|
||||
for k in metadata:
|
||||
offset = metadata[k]["data_offsets"][0]
|
||||
|
||||
if offset == text_model_offset:
|
||||
break
|
||||
if offset == text_model_offset:
|
||||
break
|
||||
|
||||
part_offset = offset - last_offset
|
||||
part_offset = offset - last_offset
|
||||
|
||||
if (part_offset >= chunk_size):
|
||||
part_end_offsets.append(8+json_len+offset)
|
||||
last_offset = offset
|
||||
if part_offset >= chunk_size:
|
||||
part_end_offsets.append(8 + json_len + offset)
|
||||
last_offset = offset
|
||||
|
||||
text_model_start = int(text_model_offset/2)
|
||||
net_bytes = bytes(open(fn, 'rb').read())
|
||||
part_end_offsets.append(text_model_start+8+json_len)
|
||||
cur_pos = 0
|
||||
text_model_start = int(text_model_offset / 2)
|
||||
net_bytes = bytes(open(fn, "rb").read())
|
||||
part_end_offsets.append(text_model_start + 8 + json_len)
|
||||
cur_pos = 0
|
||||
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
with open(f'./net_part{i}.safetensors', "wb+") as f:
|
||||
f.write(net_bytes[cur_pos:end_pos])
|
||||
cur_pos = end_pos
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
with open(f"./net_part{i}.safetensors", "wb+") as f:
|
||||
f.write(net_bytes[cur_pos:end_pos])
|
||||
cur_pos = end_pos
|
||||
|
||||
with open(f'./net_textmodel.safetensors', "wb+") as f:
|
||||
f.write(net_bytes[text_model_start+8+json_len:])
|
||||
with open(f"./net_textmodel.safetensors", "wb+") as f:
|
||||
f.write(net_bytes[text_model_start + 8 + json_len :])
|
||||
|
||||
return part_end_offsets
|
||||
|
||||
return part_end_offsets
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--remoteweights', action='store_true', help="Use safetensors from Huggingface, or from local")
|
||||
args = parser.parse_args()
|
||||
Device.DEFAULT = "WEBGPU"
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run Stable Diffusion",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remoteweights",
|
||||
action="store_true",
|
||||
help="Use safetensors from Huggingface, or from local",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
Device.DEFAULT = "WEBGPU"
|
||||
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
Tensor.no_grad = True
|
||||
model = StableDiffusion()
|
||||
|
||||
# load in weights
|
||||
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
|
||||
# load in weights
|
||||
load_state_dict(
|
||||
model,
|
||||
torch_load(
|
||||
fetch(
|
||||
"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt",
|
||||
"sd-v1-4.ckpt",
|
||||
)
|
||||
)["state_dict"],
|
||||
strict=False,
|
||||
)
|
||||
|
||||
class Step(NamedTuple):
|
||||
name: str = ""
|
||||
input: List[Tensor] = []
|
||||
forward: Any = None
|
||||
class Step(NamedTuple):
|
||||
name: str = ""
|
||||
input: List[Tensor] = []
|
||||
forward: Any = None
|
||||
|
||||
sub_steps = [
|
||||
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
|
||||
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
|
||||
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
|
||||
]
|
||||
sub_steps = [
|
||||
Step(
|
||||
name="textModel",
|
||||
input=[Tensor.randn(1, 77)],
|
||||
forward=model.cond_stage_model.transformer.text_model,
|
||||
),
|
||||
Step(
|
||||
name="diffusor",
|
||||
input=[
|
||||
Tensor.randn(1, 77, 768),
|
||||
Tensor.randn(1, 77, 768),
|
||||
Tensor.randn(1, 4, 64, 64),
|
||||
Tensor.rand(1),
|
||||
Tensor.randn(1),
|
||||
Tensor.randn(1),
|
||||
Tensor.randn(1),
|
||||
],
|
||||
forward=model,
|
||||
),
|
||||
Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode),
|
||||
]
|
||||
|
||||
prg = ""
|
||||
prg = ""
|
||||
|
||||
def compile_step(model, step: Step):
|
||||
run, special_names = jit_model(step, *step.input)
|
||||
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _, _, _) in statements])
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
|
||||
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
|
||||
return f"""\n var {step.name} = function() {{
|
||||
def compile_step(model, step: Step):
|
||||
run, special_names = jit_model(step, *step.input)
|
||||
functions, statements, bufs, _ = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weights = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||
kernel_code = "\n\n".join(
|
||||
[
|
||||
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||
for key, code in functions.items()
|
||||
]
|
||||
)
|
||||
kernel_names = ", ".join([name for (name, _, _, _) in statements])
|
||||
kernel_calls = "\n ".join(
|
||||
[
|
||||
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||
]
|
||||
)
|
||||
bufs = "\n ".join(
|
||||
[
|
||||
f"const {name} = "
|
||||
+ (
|
||||
f"createEmptyBuf(device, {size});"
|
||||
if _key not in weights
|
||||
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))"
|
||||
)
|
||||
+ ";"
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
]
|
||||
)
|
||||
gpu_write_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||
for i, (_, value) in enumerate(special_names.items())
|
||||
if "output" not in value
|
||||
]
|
||||
)
|
||||
input_writer = "\n ".join(
|
||||
[
|
||||
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||
+ f"data{i});"
|
||||
+ f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);"
|
||||
for i, (_, value) in enumerate(special_names.items())
|
||||
if value != "output0"
|
||||
]
|
||||
)
|
||||
return f"""\n var {step.name} = function() {{
|
||||
|
||||
{kernel_code}
|
||||
|
||||
|
@ -142,23 +220,25 @@ if __name__ == "__main__":
|
|||
}}
|
||||
"""
|
||||
|
||||
for step in sub_steps:
|
||||
print(f'Executing step={step.name}')
|
||||
prg += compile_step(model, step)
|
||||
for step in sub_steps:
|
||||
print(f"Executing step={step.name}")
|
||||
prg += compile_step(model, step)
|
||||
|
||||
if step.name == "diffusor":
|
||||
if args.remoteweights:
|
||||
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
||||
else:
|
||||
state = get_state_dict(model)
|
||||
safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
|
||||
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
||||
split_safetensor("./net_conv.safetensors")
|
||||
os.remove("net.safetensors")
|
||||
os.remove("net_conv.safetensors")
|
||||
base_url = "."
|
||||
if step.name == "diffusor":
|
||||
if args.remoteweights:
|
||||
base_url = "https://huggingface.co/wpmed/tinygrad-sd-f16/resolve/main"
|
||||
else:
|
||||
state = get_state_dict(model)
|
||||
safe_save(
|
||||
state, os.path.join(os.path.dirname(__file__), "net.safetensors")
|
||||
)
|
||||
convert_f32_to_f16("./net.safetensors", "./net_conv.safetensors")
|
||||
split_safetensor("./net_conv.safetensors")
|
||||
os.remove("net.safetensors")
|
||||
os.remove("net_conv.safetensors")
|
||||
base_url = "."
|
||||
|
||||
prekernel = f"""
|
||||
prekernel = f"""
|
||||
window.MODEL_BASE_URL= "{base_url}";
|
||||
const getTensorMetadata = (safetensorBuffer) => {{
|
||||
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
||||
|
@ -227,5 +307,5 @@ if __name__ == "__main__":
|
|||
passEncoder.end();
|
||||
}};"""
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
|
||||
text_file.write(prekernel + prg)
|
||||
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
|
||||
text_file.write(prekernel + prg)
|
||||
|
|
|
@ -15,338 +15,562 @@ from tinygrad.tensor import Tensor
|
|||
import itertools
|
||||
import librosa
|
||||
|
||||
|
||||
class MultiHeadAttention:
|
||||
def __init__(self, n_state, n_head, kv_caching: Literal['cross', 'self']=None, max_self_attn_cache_len=None):
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
self.value = nn.Linear(n_state, n_state)
|
||||
self.out = nn.Linear(n_state, n_state)
|
||||
def __init__(
|
||||
self,
|
||||
n_state,
|
||||
n_head,
|
||||
kv_caching: Literal["cross", "self"] = None,
|
||||
max_self_attn_cache_len=None,
|
||||
):
|
||||
self.n_head = n_head
|
||||
self.query = nn.Linear(n_state, n_state)
|
||||
self.key = nn.Linear(n_state, n_state, bias=False)
|
||||
self.value = nn.Linear(n_state, n_state)
|
||||
self.out = nn.Linear(n_state, n_state)
|
||||
|
||||
self.kv_caching = kv_caching
|
||||
self.max_self_attn_cache_len = max_self_attn_cache_len
|
||||
self.kv_caching = kv_caching
|
||||
self.max_self_attn_cache_len = max_self_attn_cache_len
|
||||
|
||||
def __call__(self, x:Tensor, xa:Optional[Tensor]=None, mask:Optional[Tensor]=None, len: Union[Variable,int]=None):
|
||||
if self.kv_caching == 'cross':
|
||||
if xa is not None:
|
||||
k, v = self.key(xa), self.value(xa)
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k, self.cache_v = k, v
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
len: Union[Variable, int] = None,
|
||||
):
|
||||
if self.kv_caching == "cross":
|
||||
if xa is not None:
|
||||
k, v = self.key(xa), self.value(xa)
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = k, v
|
||||
else:
|
||||
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
||||
self.cache_k.assign(k + 1 - 1).realize()
|
||||
self.cache_v.assign(v + 1 - 1).realize()
|
||||
else:
|
||||
k, v = self.cache_k, self.cache_v
|
||||
else:
|
||||
# see test_jitted_read_assign in test_jit.py. more context https://github.com/tinygrad/tinygrad/pull/2360#issuecomment-1817989994
|
||||
self.cache_k.assign(k+1-1).realize()
|
||||
self.cache_v.assign(v+1-1).realize()
|
||||
else:
|
||||
k, v = self.cache_k, self.cache_v
|
||||
else:
|
||||
k, v = self.key(x), self.value(x)
|
||||
if self.kv_caching == 'self':
|
||||
if not hasattr(self, 'cache_k'):
|
||||
self.cache_k = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
self.cache_v = Tensor.zeros(x.shape[0], self.max_self_attn_cache_len, x.shape[2])
|
||||
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
||||
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
||||
padding = self.max_self_attn_cache_len-len-x.shape[1]
|
||||
self.cache_k.assign(k.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
self.cache_v.assign(v.pad((None, (0, padding), None)).contiguous()).realize()
|
||||
k, v = self.key(x), self.value(x)
|
||||
if self.kv_caching == "self":
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k = Tensor.zeros(
|
||||
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||
)
|
||||
self.cache_v = Tensor.zeros(
|
||||
x.shape[0], self.max_self_attn_cache_len, x.shape[2]
|
||||
)
|
||||
k = self.cache_k.shrink((None, (0, len), None)).cat(k, dim=1)
|
||||
v = self.cache_v.shrink((None, (0, len), None)).cat(v, dim=1)
|
||||
padding = self.max_self_attn_cache_len - len - x.shape[1]
|
||||
self.cache_k.assign(
|
||||
k.pad((None, (0, padding), None)).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
v.pad((None, (0, padding), None)).contiguous()
|
||||
).realize()
|
||||
|
||||
q = self.query(x)
|
||||
n_ctx = q.shape[1]
|
||||
assert(q.shape[-1] == k.shape[-1] == v.shape[-1])
|
||||
head_dim = q.shape[-1] // self.n_head
|
||||
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
attn = Tensor.scaled_dot_product_attention(q, k, v, mask[:n_ctx,:n_ctx] if mask is not None else None)
|
||||
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
return self.out(wv)
|
||||
q = self.query(x)
|
||||
n_ctx = q.shape[1]
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
||||
head_dim = q.shape[-1] // self.n_head
|
||||
q = q.reshape(*q.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
k = k.reshape(*k.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
v = v.reshape(*v.shape[:2], self.n_head, head_dim).permute(0, 2, 1, 3)
|
||||
attn = Tensor.scaled_dot_product_attention(
|
||||
q, k, v, mask[:n_ctx, :n_ctx] if mask is not None else None
|
||||
)
|
||||
wv = attn.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
return self.out(wv)
|
||||
|
||||
|
||||
class ResidualAttentionBlock:
|
||||
def __init__(self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None):
|
||||
self.attn = MultiHeadAttention(n_state, n_head, kv_caching='self' if is_decoder_block else None, max_self_attn_cache_len=max_self_attn_cache_len)
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
def __init__(
|
||||
self, n_state, n_head, is_decoder_block=False, max_self_attn_cache_len=None
|
||||
):
|
||||
self.attn = MultiHeadAttention(
|
||||
n_state,
|
||||
n_head,
|
||||
kv_caching="self" if is_decoder_block else None,
|
||||
max_self_attn_cache_len=max_self_attn_cache_len,
|
||||
)
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head, kv_caching='cross') if is_decoder_block else None
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head, kv_caching="cross")
|
||||
if is_decoder_block
|
||||
else None
|
||||
)
|
||||
self.cross_attn_ln = nn.LayerNorm(n_state) if is_decoder_block else None
|
||||
|
||||
self.mlp = [nn.Linear(n_state, n_state*4), Tensor.gelu, nn.Linear(n_state*4, n_state)]
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
self.mlp = [
|
||||
nn.Linear(n_state, n_state * 4),
|
||||
Tensor.gelu,
|
||||
nn.Linear(n_state * 4, n_state),
|
||||
]
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int] = None):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||
x = x + self.mlp_ln(x).sequential(self.mlp)
|
||||
return x.realize()
|
||||
|
||||
def __call__(self, x, xa=None, mask=None, len: Union[Variable, int]=None):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, len=len)
|
||||
if self.cross_attn: x = x + self.cross_attn(self.cross_attn_ln(x), xa)
|
||||
x = x + self.mlp_ln(x).sequential(self.mlp)
|
||||
return x.realize()
|
||||
|
||||
class AudioEncoder:
|
||||
def __init__(self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_):
|
||||
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1)
|
||||
self.blocks = [ResidualAttentionBlock(n_audio_state, n_audio_head) for _ in range(n_audio_layer)]
|
||||
self.ln_post = nn.LayerNorm(n_audio_state)
|
||||
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
||||
self.encode = TinyJit(self.__call__)
|
||||
def __init__(
|
||||
self, n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer, **_
|
||||
):
|
||||
self.conv1 = nn.Conv1d(n_mels, n_audio_state, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(
|
||||
n_audio_state, n_audio_state, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.blocks = [
|
||||
ResidualAttentionBlock(n_audio_state, n_audio_head)
|
||||
for _ in range(n_audio_layer)
|
||||
]
|
||||
self.ln_post = nn.LayerNorm(n_audio_state)
|
||||
self.positional_embedding = Tensor.empty(n_audio_ctx, n_audio_state)
|
||||
self.encode = TinyJit(self.__call__)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x).gelu()
|
||||
x = self.conv2(x).gelu()
|
||||
x = x.permute(0, 2, 1)
|
||||
x = x + self.positional_embedding[: x.shape[1]]
|
||||
x = x.sequential(self.blocks)
|
||||
x = self.ln_post(x)
|
||||
return x.realize()
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x).gelu()
|
||||
x = self.conv2(x).gelu()
|
||||
x = x.permute(0, 2, 1)
|
||||
x = x + self.positional_embedding[:x.shape[1]]
|
||||
x = x.sequential(self.blocks)
|
||||
x = self.ln_post(x)
|
||||
return x.realize()
|
||||
|
||||
class TextDecoder:
|
||||
def __init__(self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_):
|
||||
self.max_tokens_to_sample = n_text_ctx // 2
|
||||
self.max_self_attn_cache_len = self.max_tokens_to_sample * 2 + 5 # roughly prompt + start toks + max_tokens_to_sample
|
||||
def __init__(
|
||||
self, n_vocab, n_text_ctx, n_text_state, n_text_head, n_text_layer, **_
|
||||
):
|
||||
self.max_tokens_to_sample = n_text_ctx // 2
|
||||
self.max_self_attn_cache_len = (
|
||||
self.max_tokens_to_sample * 2 + 5
|
||||
) # roughly prompt + start toks + max_tokens_to_sample
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||
self.blocks = [ResidualAttentionBlock(n_text_state, n_text_head, is_decoder_block=True, max_self_attn_cache_len=self.max_self_attn_cache_len) for _ in range(n_text_layer)]
|
||||
self.ln = nn.LayerNorm(n_text_state)
|
||||
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
||||
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||
self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||
self.start_output_tok = TinyJit(self.output_tok)
|
||||
self.after_start_output_tok = TinyJit(self.output_tok)
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_text_state)
|
||||
self.positional_embedding = Tensor.empty(n_text_ctx, n_text_state)
|
||||
self.blocks = [
|
||||
ResidualAttentionBlock(
|
||||
n_text_state,
|
||||
n_text_head,
|
||||
is_decoder_block=True,
|
||||
max_self_attn_cache_len=self.max_self_attn_cache_len,
|
||||
)
|
||||
for _ in range(n_text_layer)
|
||||
]
|
||||
self.ln = nn.LayerNorm(n_text_state)
|
||||
self.mask = Tensor.full((n_text_ctx, n_text_ctx), -np.inf).triu(1).realize()
|
||||
self.blocks_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||
self.blocks_after_start_tok = [TinyJit(block.__call__) for block in self.blocks]
|
||||
self.start_output_tok = TinyJit(self.output_tok)
|
||||
self.after_start_output_tok = TinyJit(self.output_tok)
|
||||
|
||||
# if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
|
||||
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
|
||||
seqlen = x.shape[-1]
|
||||
x = self.token_embedding(x) + self.positional_embedding[pos:pos+seqlen]
|
||||
if pos == 0:
|
||||
for block in (self.blocks if streaming else self.blocks_start_tok):
|
||||
x = block(x, xa=encoded_audio, mask=self.mask, len=0) # pass xa for cross attn kv caching
|
||||
return self.output_tok(x) if streaming else self.start_output_tok(x)
|
||||
else:
|
||||
for block in self.blocks_after_start_tok:
|
||||
len_v = Variable("self_attn_cache_len", 1, self.max_self_attn_cache_len).bind(pos)
|
||||
x = block(x, mask=self.mask, len=len_v)
|
||||
return self.after_start_output_tok(x)
|
||||
# if layernorm supported symbolic shapes, we wouldn't need this hacky 'streaming' param (which should be called something more descriptive like 'x_is_start_toks_only')
|
||||
def __call__(self, x: Tensor, pos: int, encoded_audio: Tensor, streaming=False):
|
||||
seqlen = x.shape[-1]
|
||||
x = self.token_embedding(x) + self.positional_embedding[pos : pos + seqlen]
|
||||
if pos == 0:
|
||||
for block in self.blocks if streaming else self.blocks_start_tok:
|
||||
x = block(
|
||||
x, xa=encoded_audio, mask=self.mask, len=0
|
||||
) # pass xa for cross attn kv caching
|
||||
return self.output_tok(x) if streaming else self.start_output_tok(x)
|
||||
else:
|
||||
for block in self.blocks_after_start_tok:
|
||||
len_v = Variable(
|
||||
"self_attn_cache_len", 1, self.max_self_attn_cache_len
|
||||
).bind(pos)
|
||||
x = block(x, mask=self.mask, len=len_v)
|
||||
return self.after_start_output_tok(x)
|
||||
|
||||
def output_tok(self, x):
|
||||
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
||||
|
||||
def output_tok(self, x):
|
||||
return (self.ln(x) @ self.token_embedding.weight.T).realize()
|
||||
|
||||
class Whisper:
|
||||
def __init__(self, dims, batch_size=1):
|
||||
self.encoder = AudioEncoder(**dims)
|
||||
self.decoder = TextDecoder(**dims)
|
||||
self.is_multilingual = dims["n_vocab"] == 51865
|
||||
self.batch_size = batch_size
|
||||
def __init__(self, dims, batch_size=1):
|
||||
self.encoder = AudioEncoder(**dims)
|
||||
self.decoder = TextDecoder(**dims)
|
||||
self.is_multilingual = dims["n_vocab"] == 51865
|
||||
self.batch_size = batch_size
|
||||
|
||||
|
||||
RATE = 16000
|
||||
SEGMENT_SECONDS=30
|
||||
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
|
||||
SEGMENT_SECONDS = 30
|
||||
SAMPLES_PER_SEGMENT = RATE * SEGMENT_SECONDS # 480000
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
N_MELS = 80
|
||||
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
|
||||
FRAMES_PER_SEGMENT = SAMPLES_PER_SEGMENT // HOP_LENGTH # 3000
|
||||
|
||||
def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> np.ndarray:
|
||||
"""
|
||||
:param waveforms: A list of possibly variable length 16000Hz audio samples
|
||||
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
|
||||
Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
|
||||
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
|
||||
:return: mel spectrogram of the given waveforms
|
||||
"""
|
||||
def pad_or_trim(arr, target_len):
|
||||
curr_len = len(arr)
|
||||
if curr_len == target_len:
|
||||
return arr
|
||||
elif curr_len < target_len:
|
||||
return np.pad(arr, (0, target_len - curr_len), 'constant')
|
||||
else:
|
||||
return arr[:target_len]
|
||||
|
||||
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
|
||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0: max_len += SAMPLES_PER_SEGMENT - r
|
||||
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
||||
assert waveforms.shape[0] <= batch_size
|
||||
if waveforms.shape[0] < batch_size:
|
||||
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
|
||||
waveforms = np.pad(waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0)))
|
||||
def prep_audio(
|
||||
waveforms: List[np.ndarray], batch_size: int, truncate=False
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
:param waveforms: A list of possibly variable length 16000Hz audio samples
|
||||
:param batch_size: The batch_size associated with the Whisper model being used to transcribe the audio.
|
||||
Used to prevent JIT mismatch errors since the encoder does not accept symbolic shapes
|
||||
:param truncate: If true, truncates (or pads) audio to exactly 30s for a single encoder pass
|
||||
:return: mel spectrogram of the given waveforms
|
||||
"""
|
||||
|
||||
stft = librosa.stft(waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann', dtype=np.csingle)
|
||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
def pad_or_trim(arr, target_len):
|
||||
curr_len = len(arr)
|
||||
if curr_len == target_len:
|
||||
return arr
|
||||
elif curr_len < target_len:
|
||||
return np.pad(arr, (0, target_len - curr_len), "constant")
|
||||
else:
|
||||
return arr[:target_len]
|
||||
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
max_len = SAMPLES_PER_SEGMENT if truncate else max(len(wav) for wav in waveforms)
|
||||
if (r := max_len % SAMPLES_PER_SEGMENT) > 0:
|
||||
max_len += SAMPLES_PER_SEGMENT - r
|
||||
waveforms = np.array(list(map(lambda w: pad_or_trim(w, max_len), waveforms)))
|
||||
assert waveforms.shape[0] <= batch_size
|
||||
if waveforms.shape[0] < batch_size:
|
||||
# we could have a symbolic batch_size dim instead of manually padding here if conv/layernorm supported symbolic shapes
|
||||
waveforms = np.pad(
|
||||
waveforms, pad_width=((0, batch_size - waveforms.shape[0]), (0, 0))
|
||||
)
|
||||
|
||||
stft = librosa.stft(
|
||||
waveforms, n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann", dtype=np.csingle
|
||||
)
|
||||
magnitudes = np.absolute(stft[..., :-1]) ** 2
|
||||
mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes
|
||||
|
||||
log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
|
||||
return log_spec
|
||||
|
||||
return log_spec
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish",
|
||||
"pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
|
||||
"he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian",
|
||||
"th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu",
|
||||
"fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian",
|
||||
"br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
|
||||
"gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian",
|
||||
"be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole",
|
||||
"ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy",
|
||||
"as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese",
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
|
||||
def get_encoding(encoding_name):
|
||||
with fetch(f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken").open() as f:
|
||||
ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in f if line)}
|
||||
n_vocab = len(ranks)
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
||||
n_vocab += len(specials)
|
||||
import tiktoken
|
||||
return tiktoken.Encoding(
|
||||
name=encoding_name,
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens)
|
||||
with fetch(
|
||||
f"https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/{encoding_name}.tiktoken"
|
||||
).open() as f:
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in f if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
special_tokens = dict(zip(specials, itertools.count(n_vocab)))
|
||||
n_vocab += len(specials)
|
||||
import tiktoken
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=encoding_name,
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
MODEL_URLS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
}
|
||||
def init_whisper(model_name="tiny.en", batch_size=1):
|
||||
assert MODEL_URLS[model_name] is not None
|
||||
|
||||
filename = fetch(MODEL_URLS[model_name])
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state['dims'], batch_size)
|
||||
load_state_dict(model, state['model_state_dict'], strict=False)
|
||||
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
||||
return model, enc
|
||||
|
||||
def init_whisper(model_name="tiny.en", batch_size=1):
|
||||
assert MODEL_URLS[model_name] is not None
|
||||
|
||||
filename = fetch(MODEL_URLS[model_name])
|
||||
state = torch_load(filename)
|
||||
model = Whisper(state["dims"], batch_size)
|
||||
load_state_dict(model, state["model_state_dict"], strict=False)
|
||||
enc = get_encoding("multilingual" if model.is_multilingual else "gpt2")
|
||||
return model, enc
|
||||
|
||||
|
||||
def load_file_waveform(filename):
|
||||
waveform, _ = librosa.load(filename, sr=RATE)
|
||||
return waveform
|
||||
waveform, _ = librosa.load(filename, sr=RATE)
|
||||
return waveform
|
||||
|
||||
|
||||
def transcribe_file(model, enc, filename):
|
||||
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
|
||||
return transcribe_waveform(model, enc, [load_file_waveform(filename)])
|
||||
|
||||
|
||||
def transcribe_waveform(model, enc, waveforms, truncate=False):
|
||||
"""
|
||||
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
|
||||
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
|
||||
"""
|
||||
N_audio = len(waveforms)
|
||||
log_spec = prep_audio(waveforms, model.batch_size, truncate)
|
||||
"""
|
||||
Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples
|
||||
Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided
|
||||
"""
|
||||
N_audio = len(waveforms)
|
||||
log_spec = prep_audio(waveforms, model.batch_size, truncate)
|
||||
|
||||
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
|
||||
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
|
||||
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
|
||||
raise Exception("Multi-segment transcription not supported with batch audio input")
|
||||
if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1:
|
||||
# we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch
|
||||
# if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent
|
||||
raise Exception(
|
||||
"Multi-segment transcription not supported with batch audio input"
|
||||
)
|
||||
|
||||
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
if model.is_multilingual:
|
||||
# TODO detect language
|
||||
language_token = enc._special_tokens["<|startoftranscript|>"] + 1 + tuple(LANGUAGES.keys()).index("en")
|
||||
start_tokens.append(language_token)
|
||||
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
||||
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
||||
transcription_start_index = len(start_tokens)
|
||||
eot = enc._special_tokens["<|endoftext|>"]
|
||||
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
|
||||
start_tokens = [enc._special_tokens["<|startoftranscript|>"]]
|
||||
if model.is_multilingual:
|
||||
# TODO detect language
|
||||
language_token = (
|
||||
enc._special_tokens["<|startoftranscript|>"]
|
||||
+ 1
|
||||
+ tuple(LANGUAGES.keys()).index("en")
|
||||
)
|
||||
start_tokens.append(language_token)
|
||||
start_tokens.append(enc._special_tokens["<|transcribe|>"])
|
||||
start_tokens.append(enc._special_tokens["<|notimestamps|>"])
|
||||
transcription_start_index = len(start_tokens)
|
||||
eot = enc._special_tokens["<|endoftext|>"]
|
||||
transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0]
|
||||
|
||||
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT]))
|
||||
pos = 0
|
||||
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
|
||||
if curr_frame > 0:
|
||||
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt = np.concatenate((
|
||||
[enc._special_tokens["<|startofprev|>"]],
|
||||
transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:],
|
||||
start_tokens))
|
||||
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
|
||||
transcription_start_index = len(curr_segment_tokens[0])
|
||||
for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT):
|
||||
encoded_audio = model.encoder.encode(
|
||||
Tensor(log_spec[:, :, curr_frame : curr_frame + FRAMES_PER_SEGMENT])
|
||||
)
|
||||
pos = 0
|
||||
curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1))
|
||||
if curr_frame > 0:
|
||||
# pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt = np.concatenate(
|
||||
(
|
||||
[enc._special_tokens["<|startofprev|>"]],
|
||||
transcription_tokens[0][-model.decoder.max_tokens_to_sample + 1 :],
|
||||
start_tokens,
|
||||
)
|
||||
)
|
||||
curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1))
|
||||
transcription_start_index = len(curr_segment_tokens[0])
|
||||
|
||||
for i in range(model.decoder.max_tokens_to_sample):
|
||||
out = model.decoder(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, encoded_audio, streaming=curr_frame > 0)
|
||||
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
|
||||
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
|
||||
curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1)
|
||||
pos = curr_segment_tokens.shape[-1] - 1
|
||||
if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens)))
|
||||
if (curr_segment_tokens[:, -1] == eot).all():
|
||||
break
|
||||
for i in range(model.decoder.max_tokens_to_sample):
|
||||
out = model.decoder(
|
||||
Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]),
|
||||
pos,
|
||||
encoded_audio,
|
||||
streaming=curr_frame > 0,
|
||||
)
|
||||
next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32)
|
||||
next_tokens[curr_segment_tokens[:, -1] == eot] = eot
|
||||
curr_segment_tokens = np.concatenate(
|
||||
(curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1
|
||||
)
|
||||
pos = curr_segment_tokens.shape[-1] - 1
|
||||
if DEBUG >= 1:
|
||||
print(
|
||||
i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))
|
||||
)
|
||||
if (curr_segment_tokens[:, -1] == eot).all():
|
||||
break
|
||||
|
||||
for i, t in enumerate(curr_segment_tokens):
|
||||
eot_index = np.where(t == eot)[0]
|
||||
eot_index = None if len(eot_index) == 0 else eot_index[0]
|
||||
transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index]))
|
||||
for i, t in enumerate(curr_segment_tokens):
|
||||
eot_index = np.where(t == eot)[0]
|
||||
eot_index = None if len(eot_index) == 0 else eot_index[0]
|
||||
transcription_tokens[i] = np.concatenate(
|
||||
(transcription_tokens[i], t[transcription_start_index:eot_index])
|
||||
)
|
||||
|
||||
transcriptions = list(
|
||||
map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)
|
||||
)
|
||||
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
|
||||
|
||||
transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens))
|
||||
return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0]
|
||||
|
||||
CHUNK = 1600
|
||||
RECORD_SECONDS = 10
|
||||
|
||||
|
||||
def listener(q):
|
||||
import pyaudio
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16, channels=1, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
print("listening")
|
||||
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
data = stream.read(CHUNK)
|
||||
waveform = ((np.frombuffer(data, np.int16)/32768).astype(np.float32)*3)
|
||||
q.put(waveform)
|
||||
print("done listening")
|
||||
import pyaudio
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(
|
||||
format=pyaudio.paInt16,
|
||||
channels=1,
|
||||
rate=RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK,
|
||||
)
|
||||
print("listening")
|
||||
for _ in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
data = stream.read(CHUNK)
|
||||
waveform = (np.frombuffer(data, np.int16) / 32768).astype(np.float32) * 3
|
||||
q.put(waveform)
|
||||
print("done listening")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model, enc = init_whisper("small.en" if getenv("SMALL") else "tiny.en", batch_size=1)
|
||||
model, enc = init_whisper(
|
||||
"small.en" if getenv("SMALL") else "tiny.en", batch_size=1
|
||||
)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
print(transcribe_file(model, enc, sys.argv[1]))
|
||||
else:
|
||||
# online
|
||||
q = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=listener, args=(q,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
if len(sys.argv) > 1:
|
||||
print(transcribe_file(model, enc, sys.argv[1]))
|
||||
else:
|
||||
# online
|
||||
q = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=listener, args=(q,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
lst = [enc._special_tokens["<|startoftranscript|>"], enc._special_tokens["<|notimestamps|>"]]
|
||||
total = None
|
||||
did_read = False
|
||||
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
while not q.empty() or total is None:
|
||||
waveform = q.get()
|
||||
if total is None: total = waveform
|
||||
else: total = np.concatenate([total, waveform])
|
||||
did_read = True
|
||||
if did_read:
|
||||
log_spec = prep_audio(total.reshape(1, -1), model.batch_size, truncate=True)
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec))
|
||||
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
out = model.decoder(Tensor([lst]), 0, encoded_audio, streaming=True).realize()
|
||||
idx = int(out[0,-1].argmax().numpy().item())
|
||||
lst.append(idx)
|
||||
dec = enc.decode(lst)
|
||||
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
|
||||
if dec.endswith("<|endoftext|>"):
|
||||
lst.pop()
|
||||
lst = [
|
||||
enc._special_tokens["<|startoftranscript|>"],
|
||||
enc._special_tokens["<|notimestamps|>"],
|
||||
]
|
||||
total = None
|
||||
did_read = False
|
||||
for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
|
||||
while not q.empty() or total is None:
|
||||
waveform = q.get()
|
||||
if total is None:
|
||||
total = waveform
|
||||
else:
|
||||
total = np.concatenate([total, waveform])
|
||||
did_read = True
|
||||
if did_read:
|
||||
log_spec = prep_audio(
|
||||
total.reshape(1, -1), model.batch_size, truncate=True
|
||||
)
|
||||
encoded_audio = model.encoder.encode(Tensor(log_spec))
|
||||
# pass the previously inferred tokens as 'prefix' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
out = model.decoder(
|
||||
Tensor([lst]), 0, encoded_audio, streaming=True
|
||||
).realize()
|
||||
idx = int(out[0, -1].argmax().numpy().item())
|
||||
lst.append(idx)
|
||||
dec = enc.decode(lst)
|
||||
print(dec) # DO NOT REMOVE PRINT. IT'S VERY IMPORTANT
|
||||
if dec.endswith("<|endoftext|>"):
|
||||
lst.pop()
|
||||
|
|
|
@ -10,397 +10,462 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import BatchNorm2d, Conv2d
|
||||
from tinygrad.helpers import fetch
|
||||
|
||||
|
||||
def show_labels(prediction, confidence=0.5, num_classes=80):
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes()
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
prediction *= np.expand_dims(conf_mask, 2)
|
||||
labels = []
|
||||
# Iterate over batches
|
||||
for img_pred in prediction:
|
||||
max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1)
|
||||
max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1)
|
||||
max_conf_score = np.expand_dims(max_conf_score, axis=1)
|
||||
max_conf = np.expand_dims(max_conf, axis=1)
|
||||
seq = (img_pred[:,:5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0]
|
||||
assert all(image_pred[non_zero_ind,0] > 0)
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
|
||||
for index, coco_class in enumerate(classes):
|
||||
label, probability = coco_labels[int(coco_class)], image_pred_[indexes[index]][4] * 100
|
||||
print(f"Detected {label} {probability:.2f}")
|
||||
labels.append(label)
|
||||
return labels
|
||||
coco_labels = fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
).read_bytes()
|
||||
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = prediction[:, :, 4] > confidence
|
||||
prediction *= np.expand_dims(conf_mask, 2)
|
||||
labels = []
|
||||
# Iterate over batches
|
||||
for img_pred in prediction:
|
||||
max_conf = np.amax(img_pred[:, 5 : 5 + num_classes], axis=1)
|
||||
max_conf_score = np.argmax(img_pred[:, 5 : 5 + num_classes], axis=1)
|
||||
max_conf_score = np.expand_dims(max_conf_score, axis=1)
|
||||
max_conf = np.expand_dims(max_conf, axis=1)
|
||||
seq = (img_pred[:, :5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
non_zero_ind = np.nonzero(image_pred[:, 4])[0]
|
||||
assert all(image_pred[non_zero_ind, 0] > 0)
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
|
||||
classes, indexes = np.unique(image_pred_[:, -1], return_index=True)
|
||||
for index, coco_class in enumerate(classes):
|
||||
label, probability = (
|
||||
coco_labels[int(coco_class)],
|
||||
image_pred_[indexes[index]][4] * 100,
|
||||
)
|
||||
print(f"Detected {label} {probability:.2f}")
|
||||
labels.append(label)
|
||||
return labels
|
||||
|
||||
|
||||
def add_boxes(img, prediction):
|
||||
if isinstance(prediction, int): # no predictions
|
||||
if isinstance(prediction, int): # no predictions
|
||||
return img
|
||||
coco_labels = fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
)
|
||||
coco_labels = coco_labels.decode("utf-8").split("\n")
|
||||
height, width = img.shape[0:2]
|
||||
scale_factor = 608 / width
|
||||
prediction[:, [1, 3]] -= (608 - scale_factor * width) / 2
|
||||
prediction[:, [2, 4]] -= (608 - scale_factor * height) / 2
|
||||
for pred in prediction:
|
||||
corner1 = tuple(pred[1:3].astype(int))
|
||||
corner2 = tuple(pred[3:5].astype(int))
|
||||
w = corner2[0] - corner1[0]
|
||||
h = corner2[1] - corner1[1]
|
||||
corner2 = (corner2[0] + w, corner2[1] + h)
|
||||
label = coco_labels[int(pred[-1])]
|
||||
img = cv2.rectangle(img, corner1, corner2, (255, 0, 0), 2)
|
||||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
|
||||
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
|
||||
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
|
||||
img = cv2.putText(
|
||||
img,
|
||||
label,
|
||||
(corner1[0], corner1[1] + t_size[1] + 4),
|
||||
cv2.FONT_HERSHEY_PLAIN,
|
||||
1,
|
||||
[225, 255, 255],
|
||||
1,
|
||||
)
|
||||
return img
|
||||
coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names')
|
||||
coco_labels = coco_labels.decode('utf-8').split('\n')
|
||||
height, width = img.shape[0:2]
|
||||
scale_factor = 608 / width
|
||||
prediction[:,[1,3]] -= (608 - scale_factor * width) / 2
|
||||
prediction[:,[2,4]] -= (608 - scale_factor * height) / 2
|
||||
for pred in prediction:
|
||||
corner1 = tuple(pred[1:3].astype(int))
|
||||
corner2 = tuple(pred[3:5].astype(int))
|
||||
w = corner2[0] - corner1[0]
|
||||
h = corner2[1] - corner1[1]
|
||||
corner2 = (corner2[0] + w, corner2[1] + h)
|
||||
label = coco_labels[int(pred[-1])]
|
||||
img = cv2.rectangle(img, corner1, corner2, (255, 0, 0), 2)
|
||||
t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0]
|
||||
c2 = corner1[0] + t_size[0] + 3, corner1[1] + t_size[1] + 4
|
||||
img = cv2.rectangle(img, corner1, c2, (255, 0, 0), -1)
|
||||
img = cv2.putText(img, label, (corner1[0], corner1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1)
|
||||
return img
|
||||
|
||||
|
||||
def bbox_iou(box1, box2):
|
||||
"""
|
||||
Returns the IoU of two bounding boxes
|
||||
IoU: IoU = Area Of Overlap / Area of Union -> How close the predicted bounding box is
|
||||
to the ground truth bounding box. Higher IoU = Better accuracy
|
||||
In training, used to track accuracy. with inference, using to remove duplicate bounding boxes
|
||||
"""
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3]
|
||||
# get the coordinates of the intersection rectangle
|
||||
inter_rect_x1 = np.maximum(b1_x1, b2_x1)
|
||||
inter_rect_y1 = np.maximum(b1_y1, b2_y1)
|
||||
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
||||
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
||||
#Intersection area
|
||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999)
|
||||
#Union Area
|
||||
b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1)
|
||||
b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1)
|
||||
iou = inter_area / (b1_area + b2_area - inter_area)
|
||||
return iou
|
||||
"""
|
||||
Returns the IoU of two bounding boxes
|
||||
IoU: IoU = Area Of Overlap / Area of Union -> How close the predicted bounding box is
|
||||
to the ground truth bounding box. Higher IoU = Better accuracy
|
||||
In training, used to track accuracy. with inference, using to remove duplicate bounding boxes
|
||||
"""
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
|
||||
# get the coordinates of the intersection rectangle
|
||||
inter_rect_x1 = np.maximum(b1_x1, b2_x1)
|
||||
inter_rect_y1 = np.maximum(b1_y1, b2_y1)
|
||||
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
||||
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
||||
# Intersection area
|
||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(
|
||||
inter_rect_y2 - inter_rect_y1 + 1, 0, 99999
|
||||
)
|
||||
# Union Area
|
||||
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
|
||||
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
|
||||
iou = inter_area / (b1_area + b2_area - inter_area)
|
||||
return iou
|
||||
|
||||
|
||||
def process_results(prediction, confidence=0.9, num_classes=80, nms_conf=0.4):
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = (prediction[:,:,4] > confidence)
|
||||
conf_mask = np.expand_dims(conf_mask, 2)
|
||||
prediction = prediction * conf_mask
|
||||
# Non max suppression
|
||||
box_corner = prediction
|
||||
box_corner[:,:,0] = (prediction[:,:,0] - prediction[:,:,2]/2)
|
||||
box_corner[:,:,1] = (prediction[:,:,1] - prediction[:,:,3]/2)
|
||||
box_corner[:,:,2] = (prediction[:,:,0] + prediction[:,:,2]/2)
|
||||
box_corner[:,:,3] = (prediction[:,:,1] + prediction[:,:,3]/2)
|
||||
prediction[:,:,:4] = box_corner[:,:,:4]
|
||||
write = False
|
||||
# Process img
|
||||
img_pred = prediction[0]
|
||||
max_conf = np.amax(img_pred[:,5:5+num_classes], axis=1)
|
||||
max_conf_score = np.argmax(img_pred[:,5:5+num_classes], axis=1)
|
||||
max_conf_score = np.expand_dims(max_conf_score, axis=1)
|
||||
max_conf = np.expand_dims(max_conf, axis=1)
|
||||
seq = (img_pred[:,:5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
non_zero_ind = np.nonzero(image_pred[:,4])[0]
|
||||
assert all(image_pred[non_zero_ind,0] > 0)
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind),:], (-1, 7))
|
||||
if image_pred_.shape[0] == 0:
|
||||
print("No detections found!")
|
||||
return 0
|
||||
for cls in np.unique(image_pred_[:, -1]):
|
||||
# perform NMS, get the detections with one particular class
|
||||
cls_mask = image_pred_*np.expand_dims(image_pred_[:, -1] == cls, axis=1)
|
||||
class_mask_ind = np.squeeze(np.nonzero(cls_mask[:,-2]))
|
||||
# class_mask_ind = np.nonzero()
|
||||
image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7))
|
||||
# sort the detections such that the entry with the maximum objectness
|
||||
# confidence is at the top
|
||||
conf_sort_index = np.argsort(image_pred_class[:,4])
|
||||
image_pred_class = image_pred_class[conf_sort_index]
|
||||
for i in range(image_pred_class.shape[0]):
|
||||
# Get the IOUs of all boxes that come after the one we are looking at in the loop
|
||||
try:
|
||||
ious = bbox_iou(np.expand_dims(image_pred_class[i], axis=0), image_pred_class[i+1:])
|
||||
except:
|
||||
break
|
||||
# Zero out all the detections that have IoU > threshold
|
||||
iou_mask = np.expand_dims((ious < nms_conf), axis=1)
|
||||
image_pred_class[i+1:] *= iou_mask
|
||||
# Remove the non-zero entries
|
||||
non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:,4]))
|
||||
image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7))
|
||||
batch_ind = np.array([[0]])
|
||||
seq = (batch_ind, image_pred_class)
|
||||
if not write:
|
||||
output, write = np.concatenate(seq, axis=1), True
|
||||
else:
|
||||
out = np.concatenate(seq, axis=1)
|
||||
output = np.concatenate((output,out))
|
||||
return output
|
||||
prediction = prediction.detach().numpy()
|
||||
conf_mask = prediction[:, :, 4] > confidence
|
||||
conf_mask = np.expand_dims(conf_mask, 2)
|
||||
prediction = prediction * conf_mask
|
||||
# Non max suppression
|
||||
box_corner = prediction
|
||||
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
|
||||
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
|
||||
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
|
||||
prediction[:, :, :4] = box_corner[:, :, :4]
|
||||
write = False
|
||||
# Process img
|
||||
img_pred = prediction[0]
|
||||
max_conf = np.amax(img_pred[:, 5 : 5 + num_classes], axis=1)
|
||||
max_conf_score = np.argmax(img_pred[:, 5 : 5 + num_classes], axis=1)
|
||||
max_conf_score = np.expand_dims(max_conf_score, axis=1)
|
||||
max_conf = np.expand_dims(max_conf, axis=1)
|
||||
seq = (img_pred[:, :5], max_conf, max_conf_score)
|
||||
image_pred = np.concatenate(seq, axis=1)
|
||||
non_zero_ind = np.nonzero(image_pred[:, 4])[0]
|
||||
assert all(image_pred[non_zero_ind, 0] > 0)
|
||||
image_pred_ = np.reshape(image_pred[np.squeeze(non_zero_ind), :], (-1, 7))
|
||||
if image_pred_.shape[0] == 0:
|
||||
print("No detections found!")
|
||||
return 0
|
||||
for cls in np.unique(image_pred_[:, -1]):
|
||||
# perform NMS, get the detections with one particular class
|
||||
cls_mask = image_pred_ * np.expand_dims(image_pred_[:, -1] == cls, axis=1)
|
||||
class_mask_ind = np.squeeze(np.nonzero(cls_mask[:, -2]))
|
||||
# class_mask_ind = np.nonzero()
|
||||
image_pred_class = np.reshape(image_pred_[class_mask_ind], (-1, 7))
|
||||
# sort the detections such that the entry with the maximum objectness
|
||||
# confidence is at the top
|
||||
conf_sort_index = np.argsort(image_pred_class[:, 4])
|
||||
image_pred_class = image_pred_class[conf_sort_index]
|
||||
for i in range(image_pred_class.shape[0]):
|
||||
# Get the IOUs of all boxes that come after the one we are looking at in the loop
|
||||
try:
|
||||
ious = bbox_iou(
|
||||
np.expand_dims(image_pred_class[i], axis=0),
|
||||
image_pred_class[i + 1 :],
|
||||
)
|
||||
except:
|
||||
break
|
||||
# Zero out all the detections that have IoU > threshold
|
||||
iou_mask = np.expand_dims((ious < nms_conf), axis=1)
|
||||
image_pred_class[i + 1 :] *= iou_mask
|
||||
# Remove the non-zero entries
|
||||
non_zero_ind = np.squeeze(np.nonzero(image_pred_class[:, 4]))
|
||||
image_pred_class = np.reshape(image_pred_class[non_zero_ind], (-1, 7))
|
||||
batch_ind = np.array([[0]])
|
||||
seq = (batch_ind, image_pred_class)
|
||||
if not write:
|
||||
output, write = np.concatenate(seq, axis=1), True
|
||||
else:
|
||||
out = np.concatenate(seq, axis=1)
|
||||
output = np.concatenate((output, out))
|
||||
return output
|
||||
|
||||
|
||||
def infer(model, img):
|
||||
img = np.array(Image.fromarray(img).resize((608, 608)))
|
||||
img = img[:,:,::-1].transpose((2,0,1))
|
||||
img = img[np.newaxis,:,:,:]/255.0
|
||||
prediction = model.forward(Tensor(img.astype(np.float32)))
|
||||
return prediction
|
||||
img = np.array(Image.fromarray(img).resize((608, 608)))
|
||||
img = img[:, :, ::-1].transpose((2, 0, 1))
|
||||
img = img[np.newaxis, :, :, :] / 255.0
|
||||
prediction = model.forward(Tensor(img.astype(np.float32)))
|
||||
return prediction
|
||||
|
||||
|
||||
def parse_cfg(cfg):
|
||||
# Return a list of blocks
|
||||
lines = cfg.decode("utf-8").split('\n')
|
||||
lines = [x for x in lines if len(x) > 0]
|
||||
lines = [x for x in lines if x[0] != '#']
|
||||
lines = [x.rstrip().lstrip() for x in lines]
|
||||
block, blocks = {}, []
|
||||
for line in lines:
|
||||
if line[0] == "[":
|
||||
if len(block) != 0:
|
||||
blocks.append(block)
|
||||
block = {}
|
||||
block["type"] = line[1:-1].rstrip()
|
||||
else:
|
||||
key,value = line.split("=")
|
||||
block[key.rstrip()] = value.lstrip()
|
||||
blocks.append(block)
|
||||
return blocks
|
||||
# Return a list of blocks
|
||||
lines = cfg.decode("utf-8").split("\n")
|
||||
lines = [x for x in lines if len(x) > 0]
|
||||
lines = [x for x in lines if x[0] != "#"]
|
||||
lines = [x.rstrip().lstrip() for x in lines]
|
||||
block, blocks = {}, []
|
||||
for line in lines:
|
||||
if line[0] == "[":
|
||||
if len(block) != 0:
|
||||
blocks.append(block)
|
||||
block = {}
|
||||
block["type"] = line[1:-1].rstrip()
|
||||
else:
|
||||
key, value = line.split("=")
|
||||
block[key.rstrip()] = value.lstrip()
|
||||
blocks.append(block)
|
||||
return blocks
|
||||
|
||||
|
||||
# TODO: Speed up this function, avoid copying stuff from GPU to CPU
|
||||
def predict_transform(prediction, inp_dim, anchors, num_classes):
|
||||
batch_size = prediction.shape[0]
|
||||
stride = inp_dim // prediction.shape[2]
|
||||
grid_size = inp_dim // stride
|
||||
bbox_attrs = 5 + num_classes
|
||||
num_anchors = len(anchors)
|
||||
prediction = prediction.reshape(shape=(batch_size, bbox_attrs*num_anchors, grid_size*grid_size))
|
||||
prediction = prediction.transpose(1, 2)
|
||||
prediction = prediction.reshape(shape=(batch_size, grid_size*grid_size*num_anchors, bbox_attrs))
|
||||
prediction_cpu = prediction.numpy()
|
||||
for i in (0, 1, 4):
|
||||
prediction_cpu[:,:,i] = 1 / (1 + np.exp(-prediction_cpu[:,:,i]))
|
||||
# Add the center offsets
|
||||
grid = np.arange(grid_size)
|
||||
a, b = np.meshgrid(grid, grid)
|
||||
x_offset = a.reshape((-1, 1))
|
||||
y_offset = b.reshape((-1, 1))
|
||||
x_y_offset = np.concatenate((x_offset, y_offset), 1)
|
||||
x_y_offset = np.tile(x_y_offset, (1, num_anchors))
|
||||
x_y_offset = x_y_offset.reshape((-1,2))
|
||||
x_y_offset = np.expand_dims(x_y_offset, 0)
|
||||
anchors = [(a[0]/stride, a[1]/stride) for a in anchors]
|
||||
anchors = np.tile(anchors, (grid_size*grid_size, 1))
|
||||
anchors = np.expand_dims(anchors, 0)
|
||||
prediction_cpu[:,:,:2] += x_y_offset
|
||||
prediction_cpu[:,:,2:4] = np.exp(prediction_cpu[:,:,2:4])*anchors
|
||||
prediction_cpu[:,:,5:5+num_classes] = 1 / (1 + np.exp(-prediction_cpu[:,:,5:5+num_classes]))
|
||||
prediction_cpu[:,:,:4] *= stride
|
||||
return Tensor(prediction_cpu)
|
||||
batch_size = prediction.shape[0]
|
||||
stride = inp_dim // prediction.shape[2]
|
||||
grid_size = inp_dim // stride
|
||||
bbox_attrs = 5 + num_classes
|
||||
num_anchors = len(anchors)
|
||||
prediction = prediction.reshape(
|
||||
shape=(batch_size, bbox_attrs * num_anchors, grid_size * grid_size)
|
||||
)
|
||||
prediction = prediction.transpose(1, 2)
|
||||
prediction = prediction.reshape(
|
||||
shape=(batch_size, grid_size * grid_size * num_anchors, bbox_attrs)
|
||||
)
|
||||
prediction_cpu = prediction.numpy()
|
||||
for i in (0, 1, 4):
|
||||
prediction_cpu[:, :, i] = 1 / (1 + np.exp(-prediction_cpu[:, :, i]))
|
||||
# Add the center offsets
|
||||
grid = np.arange(grid_size)
|
||||
a, b = np.meshgrid(grid, grid)
|
||||
x_offset = a.reshape((-1, 1))
|
||||
y_offset = b.reshape((-1, 1))
|
||||
x_y_offset = np.concatenate((x_offset, y_offset), 1)
|
||||
x_y_offset = np.tile(x_y_offset, (1, num_anchors))
|
||||
x_y_offset = x_y_offset.reshape((-1, 2))
|
||||
x_y_offset = np.expand_dims(x_y_offset, 0)
|
||||
anchors = [(a[0] / stride, a[1] / stride) for a in anchors]
|
||||
anchors = np.tile(anchors, (grid_size * grid_size, 1))
|
||||
anchors = np.expand_dims(anchors, 0)
|
||||
prediction_cpu[:, :, :2] += x_y_offset
|
||||
prediction_cpu[:, :, 2:4] = np.exp(prediction_cpu[:, :, 2:4]) * anchors
|
||||
prediction_cpu[:, :, 5 : 5 + num_classes] = 1 / (
|
||||
1 + np.exp(-prediction_cpu[:, :, 5 : 5 + num_classes])
|
||||
)
|
||||
prediction_cpu[:, :, :4] *= stride
|
||||
return Tensor(prediction_cpu)
|
||||
|
||||
|
||||
class Darknet:
|
||||
def __init__(self, cfg):
|
||||
self.blocks = parse_cfg(cfg)
|
||||
self.net_info, self.module_list = self.create_modules(self.blocks)
|
||||
print("Modules length:", len(self.module_list))
|
||||
def __init__(self, cfg):
|
||||
self.blocks = parse_cfg(cfg)
|
||||
self.net_info, self.module_list = self.create_modules(self.blocks)
|
||||
print("Modules length:", len(self.module_list))
|
||||
|
||||
def create_modules(self, blocks):
|
||||
net_info = blocks[0] # Info about model hyperparameters
|
||||
prev_filters, filters = 3, None
|
||||
output_filters, module_list = [], []
|
||||
## module
|
||||
for index, x in enumerate(blocks[1:]):
|
||||
module_type = x["type"]
|
||||
module = []
|
||||
if module_type == "convolutional":
|
||||
try:
|
||||
batch_normalize, bias = int(x["batch_normalize"]), False
|
||||
except:
|
||||
batch_normalize, bias = 0, True
|
||||
# layer
|
||||
activation = x["activation"]
|
||||
filters = int(x["filters"])
|
||||
padding = int(x["pad"])
|
||||
pad = (int(x["size"]) - 1) // 2 if padding else 0
|
||||
module.append(Conv2d(prev_filters, filters, int(x["size"]), int(x["stride"]), pad, bias=bias))
|
||||
# BatchNorm2d
|
||||
if batch_normalize:
|
||||
module.append(BatchNorm2d(filters, eps=1e-05, track_running_stats=True))
|
||||
# LeakyReLU activation
|
||||
if activation == "leaky":
|
||||
module.append(lambda x: x.leakyrelu(0.1))
|
||||
elif module_type == "maxpool":
|
||||
size, stride = int(x["size"]), int(x["stride"])
|
||||
module.append(lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride))
|
||||
elif module_type == "upsample":
|
||||
module.append(lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1)))
|
||||
elif module_type == "route":
|
||||
x["layers"] = x["layers"].split(",")
|
||||
# Start of route
|
||||
start = int(x["layers"][0])
|
||||
# End if it exists
|
||||
try:
|
||||
end = int(x["layers"][1])
|
||||
except:
|
||||
end = 0
|
||||
if start > 0: start -= index
|
||||
if end > 0: end -= index
|
||||
module.append(lambda x: x)
|
||||
if end < 0:
|
||||
filters = output_filters[index + start] + output_filters[index + end]
|
||||
else:
|
||||
filters = output_filters[index + start]
|
||||
# Shortcut corresponds to skip connection
|
||||
elif module_type == "shortcut":
|
||||
module.append(lambda x: x)
|
||||
elif module_type == "yolo":
|
||||
mask = list(map(int, x["mask"].split(",")))
|
||||
anchors = [int(a) for a in x["anchors"].split(",")]
|
||||
anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
|
||||
module.append([anchors[i] for i in mask])
|
||||
# Append to module_list
|
||||
module_list.append(module)
|
||||
if filters is not None:
|
||||
prev_filters = filters
|
||||
output_filters.append(filters)
|
||||
return (net_info, module_list)
|
||||
def create_modules(self, blocks):
|
||||
net_info = blocks[0] # Info about model hyperparameters
|
||||
prev_filters, filters = 3, None
|
||||
output_filters, module_list = [], []
|
||||
## module
|
||||
for index, x in enumerate(blocks[1:]):
|
||||
module_type = x["type"]
|
||||
module = []
|
||||
if module_type == "convolutional":
|
||||
try:
|
||||
batch_normalize, bias = int(x["batch_normalize"]), False
|
||||
except:
|
||||
batch_normalize, bias = 0, True
|
||||
# layer
|
||||
activation = x["activation"]
|
||||
filters = int(x["filters"])
|
||||
padding = int(x["pad"])
|
||||
pad = (int(x["size"]) - 1) // 2 if padding else 0
|
||||
module.append(
|
||||
Conv2d(
|
||||
prev_filters,
|
||||
filters,
|
||||
int(x["size"]),
|
||||
int(x["stride"]),
|
||||
pad,
|
||||
bias=bias,
|
||||
)
|
||||
)
|
||||
# BatchNorm2d
|
||||
if batch_normalize:
|
||||
module.append(
|
||||
BatchNorm2d(filters, eps=1e-05, track_running_stats=True)
|
||||
)
|
||||
# LeakyReLU activation
|
||||
if activation == "leaky":
|
||||
module.append(lambda x: x.leakyrelu(0.1))
|
||||
elif module_type == "maxpool":
|
||||
size, stride = int(x["size"]), int(x["stride"])
|
||||
module.append(
|
||||
lambda x: x.max_pool2d(kernel_size=(size, size), stride=stride)
|
||||
)
|
||||
elif module_type == "upsample":
|
||||
module.append(
|
||||
lambda x: Tensor(x.numpy().repeat(2, axis=-2).repeat(2, axis=-1))
|
||||
)
|
||||
elif module_type == "route":
|
||||
x["layers"] = x["layers"].split(",")
|
||||
# Start of route
|
||||
start = int(x["layers"][0])
|
||||
# End if it exists
|
||||
try:
|
||||
end = int(x["layers"][1])
|
||||
except:
|
||||
end = 0
|
||||
if start > 0:
|
||||
start -= index
|
||||
if end > 0:
|
||||
end -= index
|
||||
module.append(lambda x: x)
|
||||
if end < 0:
|
||||
filters = (
|
||||
output_filters[index + start] + output_filters[index + end]
|
||||
)
|
||||
else:
|
||||
filters = output_filters[index + start]
|
||||
# Shortcut corresponds to skip connection
|
||||
elif module_type == "shortcut":
|
||||
module.append(lambda x: x)
|
||||
elif module_type == "yolo":
|
||||
mask = list(map(int, x["mask"].split(",")))
|
||||
anchors = [int(a) for a in x["anchors"].split(",")]
|
||||
anchors = [
|
||||
(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)
|
||||
]
|
||||
module.append([anchors[i] for i in mask])
|
||||
# Append to module_list
|
||||
module_list.append(module)
|
||||
if filters is not None:
|
||||
prev_filters = filters
|
||||
output_filters.append(filters)
|
||||
return (net_info, module_list)
|
||||
|
||||
def dump_weights(self):
|
||||
for i in range(len(self.module_list)):
|
||||
module_type = self.blocks[i + 1]["type"]
|
||||
if module_type == "convolutional":
|
||||
print(self.blocks[i + 1]["type"], "weights", i)
|
||||
model = self.module_list[i]
|
||||
conv = model[0]
|
||||
print(conv.weight.numpy()[0][0][0])
|
||||
if conv.bias is not None:
|
||||
print("biases")
|
||||
print(conv.bias.shape)
|
||||
print(conv.bias.numpy()[0][0:5])
|
||||
else:
|
||||
print("None biases for layer", i)
|
||||
def dump_weights(self):
|
||||
for i in range(len(self.module_list)):
|
||||
module_type = self.blocks[i + 1]["type"]
|
||||
if module_type == "convolutional":
|
||||
print(self.blocks[i + 1]["type"], "weights", i)
|
||||
model = self.module_list[i]
|
||||
conv = model[0]
|
||||
print(conv.weight.numpy()[0][0][0])
|
||||
if conv.bias is not None:
|
||||
print("biases")
|
||||
print(conv.bias.shape)
|
||||
print(conv.bias.numpy()[0][0:5])
|
||||
else:
|
||||
print("None biases for layer", i)
|
||||
|
||||
def load_weights(self, url):
|
||||
weights = np.frombuffer(fetch(url), dtype=np.float32)[5:]
|
||||
ptr = 0
|
||||
for i in range(len(self.module_list)):
|
||||
module_type = self.blocks[i + 1]["type"]
|
||||
if module_type == "convolutional":
|
||||
model = self.module_list[i]
|
||||
try: # we have batchnorm, load conv weights without biases, and batchnorm values
|
||||
batch_normalize = int(self.blocks[i+1]["batch_normalize"])
|
||||
except: # no batchnorm, load conv weights + biases
|
||||
batch_normalize = 0
|
||||
conv = model[0]
|
||||
if batch_normalize:
|
||||
bn = model[1]
|
||||
# Get the number of weights of batchnorm
|
||||
num_bn_biases = math.prod(bn.bias.shape)
|
||||
# Load weights
|
||||
bn_biases = Tensor(weights[ptr:ptr + num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_weights = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_running_mean = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_running_var = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
# Cast the loaded weights into dims of model weights
|
||||
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
||||
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
|
||||
bn_running_mean = bn_running_mean.reshape(shape=tuple(bn.running_mean.shape))
|
||||
bn_running_var = bn_running_var.reshape(shape=tuple(bn.running_var.shape))
|
||||
# Copy data
|
||||
bn.bias = bn_biases
|
||||
bn.weight = bn_weights
|
||||
bn.running_mean = bn_running_mean
|
||||
bn.running_var = bn_running_var
|
||||
else:
|
||||
# load biases of the conv layer
|
||||
num_biases = math.prod(conv.bias.shape)
|
||||
# Load weights
|
||||
conv_biases = Tensor(weights[ptr: ptr+num_biases])
|
||||
ptr += num_biases
|
||||
# Reshape
|
||||
conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape))
|
||||
# Copy
|
||||
conv.bias = conv_biases
|
||||
# Load weighys for conv layers
|
||||
num_weights = math.prod(conv.weight.shape)
|
||||
conv_weights = Tensor(weights[ptr:ptr+num_weights])
|
||||
ptr += num_weights
|
||||
conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape))
|
||||
conv.weight = conv_weights
|
||||
def load_weights(self, url):
|
||||
weights = np.frombuffer(fetch(url), dtype=np.float32)[5:]
|
||||
ptr = 0
|
||||
for i in range(len(self.module_list)):
|
||||
module_type = self.blocks[i + 1]["type"]
|
||||
if module_type == "convolutional":
|
||||
model = self.module_list[i]
|
||||
try: # we have batchnorm, load conv weights without biases, and batchnorm values
|
||||
batch_normalize = int(self.blocks[i + 1]["batch_normalize"])
|
||||
except: # no batchnorm, load conv weights + biases
|
||||
batch_normalize = 0
|
||||
conv = model[0]
|
||||
if batch_normalize:
|
||||
bn = model[1]
|
||||
# Get the number of weights of batchnorm
|
||||
num_bn_biases = math.prod(bn.bias.shape)
|
||||
# Load weights
|
||||
bn_biases = Tensor(weights[ptr : ptr + num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_weights = Tensor(weights[ptr : ptr + num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_running_mean = Tensor(weights[ptr : ptr + num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
bn_running_var = Tensor(weights[ptr : ptr + num_bn_biases])
|
||||
ptr += num_bn_biases
|
||||
# Cast the loaded weights into dims of model weights
|
||||
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
||||
bn_weights = bn_weights.reshape(shape=tuple(bn.weight.shape))
|
||||
bn_running_mean = bn_running_mean.reshape(
|
||||
shape=tuple(bn.running_mean.shape)
|
||||
)
|
||||
bn_running_var = bn_running_var.reshape(
|
||||
shape=tuple(bn.running_var.shape)
|
||||
)
|
||||
# Copy data
|
||||
bn.bias = bn_biases
|
||||
bn.weight = bn_weights
|
||||
bn.running_mean = bn_running_mean
|
||||
bn.running_var = bn_running_var
|
||||
else:
|
||||
# load biases of the conv layer
|
||||
num_biases = math.prod(conv.bias.shape)
|
||||
# Load weights
|
||||
conv_biases = Tensor(weights[ptr : ptr + num_biases])
|
||||
ptr += num_biases
|
||||
# Reshape
|
||||
conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape))
|
||||
# Copy
|
||||
conv.bias = conv_biases
|
||||
# Load weighys for conv layers
|
||||
num_weights = math.prod(conv.weight.shape)
|
||||
conv_weights = Tensor(weights[ptr : ptr + num_weights])
|
||||
ptr += num_weights
|
||||
conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape))
|
||||
conv.weight = conv_weights
|
||||
|
||||
def forward(self, x):
|
||||
modules = self.blocks[1:]
|
||||
outputs = {} # Cached outputs for route layer
|
||||
detections, write = None, False
|
||||
for i, module in enumerate(modules):
|
||||
module_type = module["type"]
|
||||
if module_type == "convolutional" or module_type == "upsample":
|
||||
for layer in self.module_list[i]:
|
||||
x = layer(x)
|
||||
elif module_type == "route":
|
||||
layers = module["layers"]
|
||||
layers = [int(a) for a in layers]
|
||||
if (layers[0]) > 0:
|
||||
layers[0] = layers[0] - i
|
||||
if len(layers) == 1:
|
||||
x = outputs[i + (layers[0])]
|
||||
else:
|
||||
if (layers[1]) > 0:
|
||||
layers[1] = layers[1] - i
|
||||
map1 = outputs[i + layers[0]]
|
||||
map2 = outputs[i + layers[1]]
|
||||
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
|
||||
elif module_type == "shortcut":
|
||||
from_ = int(module["from"])
|
||||
x = outputs[i - 1] + outputs[i + from_]
|
||||
elif module_type == "yolo":
|
||||
anchors = self.module_list[i][0]
|
||||
inp_dim = int(self.net_info["height"]) # 416
|
||||
num_classes = int(module["classes"])
|
||||
x = predict_transform(x, inp_dim, anchors, num_classes)
|
||||
if not write:
|
||||
detections, write = x, True
|
||||
else:
|
||||
detections = Tensor(
|
||||
np.concatenate((detections.numpy(), x.numpy()), axis=1)
|
||||
)
|
||||
outputs[i] = x
|
||||
return detections
|
||||
|
||||
def forward(self, x):
|
||||
modules = self.blocks[1:]
|
||||
outputs = {} # Cached outputs for route layer
|
||||
detections, write = None, False
|
||||
for i, module in enumerate(modules):
|
||||
module_type = (module["type"])
|
||||
if module_type == "convolutional" or module_type == "upsample":
|
||||
for layer in self.module_list[i]:
|
||||
x = layer(x)
|
||||
elif module_type == "route":
|
||||
layers = module["layers"]
|
||||
layers = [int(a) for a in layers]
|
||||
if (layers[0]) > 0:
|
||||
layers[0] = layers[0] - i
|
||||
if len(layers) == 1:
|
||||
x = outputs[i + (layers[0])]
|
||||
else:
|
||||
if (layers[1]) > 0: layers[1] = layers[1] - i
|
||||
map1 = outputs[i + layers[0]]
|
||||
map2 = outputs[i + layers[1]]
|
||||
x = Tensor(np.concatenate((map1.numpy(), map2.numpy()), axis=1))
|
||||
elif module_type == "shortcut":
|
||||
from_ = int(module["from"])
|
||||
x = outputs[i - 1] + outputs[i + from_]
|
||||
elif module_type == "yolo":
|
||||
anchors = self.module_list[i][0]
|
||||
inp_dim = int(self.net_info["height"]) # 416
|
||||
num_classes = int(module["classes"])
|
||||
x = predict_transform(x, inp_dim, anchors, num_classes)
|
||||
if not write:
|
||||
detections, write = x, True
|
||||
else:
|
||||
detections = Tensor(np.concatenate((detections.numpy(), x.numpy()), axis=1))
|
||||
outputs[i] = x
|
||||
return detections
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg'))
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||
if url == 'webcam':
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
prediction = process_results(infer(model, frame))
|
||||
img = Image.fromarray(frame[:, :, [2,1,0]])
|
||||
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
|
||||
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow('yolo', boxes)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
elif url.startswith('http'):
|
||||
img_stream = io.BytesIO(fetch(url))
|
||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||
else:
|
||||
img = cv2.imread(url)
|
||||
st = time.time()
|
||||
print('running inference…')
|
||||
prediction = infer(model, img)
|
||||
print(f'did inference in {(time.time() - st):2f}s')
|
||||
show_labels(prediction)
|
||||
prediction = process_results(prediction)
|
||||
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
|
||||
cv2.imwrite('boxes.jpg', boxes)
|
||||
model = Darknet(
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg"
|
||||
)
|
||||
)
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
model.load_weights("https://pjreddie.com/media/files/yolov3.weights")
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
url = "https://github.com/ayooshkathuria/pytorch-yolo-v3/raw/master/dog-cycle-car.png"
|
||||
if url == "webcam":
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
while 1:
|
||||
_ = cap.grab() # discard one frame to circumvent capture buffering
|
||||
ret, frame = cap.read()
|
||||
prediction = process_results(infer(model, frame))
|
||||
img = Image.fromarray(frame[:, :, [2, 1, 0]])
|
||||
boxes = add_boxes(np.array(img.resize((608, 608))), prediction)
|
||||
boxes = cv2.cvtColor(boxes, cv2.COLOR_RGB2BGR)
|
||||
cv2.imshow("yolo", boxes)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
elif url.startswith("http"):
|
||||
img_stream = io.BytesIO(fetch(url))
|
||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||
else:
|
||||
img = cv2.imread(url)
|
||||
st = time.time()
|
||||
print("running inference…")
|
||||
prediction = infer(model, img)
|
||||
print(f"did inference in {(time.time() - st):2f}s")
|
||||
show_labels(prediction)
|
||||
prediction = process_results(prediction)
|
||||
boxes = add_boxes(np.array(Image.fromarray(img).resize((608, 608))), prediction)
|
||||
cv2.imwrite("boxes.jpg", boxes)
|
||||
|
|
|
@ -8,11 +8,14 @@ from tinygrad.tensor import Tensor
|
|||
|
||||
os.chdir("/tmp")
|
||||
if not Path("yolov8n-seg.onnx").is_file():
|
||||
model = YOLO("yolov8n-seg.pt")
|
||||
model.export(format="onnx", imgsz=[480,640])
|
||||
model = YOLO("yolov8n-seg.pt")
|
||||
model.export(format="onnx", imgsz=[480, 640])
|
||||
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
|
||||
# TODO: move get example inputs to onnx
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
input_shapes = {
|
||||
inp.name: tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
|
||||
for inp in onnx_model.graph.input
|
||||
}
|
||||
print(input_shapes)
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True)
|
||||
run_onnx({"images": Tensor.zeros(1, 3, 480, 640)}, debug=True)
|
||||
|
|
|
@ -9,424 +9,646 @@ import time, sys
|
|||
from tinygrad.helpers import fetch
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
#Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
#The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
|
||||
# Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
# The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this)
|
||||
|
||||
|
||||
# Pre processing image functions.
|
||||
def compute_transform(
|
||||
image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32
|
||||
):
|
||||
shape = image.shape[:2] # current shape [height, width]
|
||||
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
r = min(r, 1.0) if not scaleup else r
|
||||
new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r)))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
|
||||
dw, dh = (np.mod(dw, stride), np.mod(dh, stride)) if auto else (0.0, 0.0)
|
||||
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
|
||||
dw /= 2
|
||||
dh /= 2
|
||||
image = (
|
||||
cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
if shape[::-1] != new_unpad
|
||||
else image
|
||||
)
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
image = cv2.copyMakeBorder(
|
||||
image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||
)
|
||||
return image
|
||||
|
||||
#Pre processing image functions.
|
||||
def compute_transform(image, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
|
||||
shape = image.shape[:2] # current shape [height, width]
|
||||
new_shape = (new_shape, new_shape) if isinstance(new_shape, int) else new_shape
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
r = min(r, 1.0) if not scaleup else r
|
||||
new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r)))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
|
||||
dw, dh = (np.mod(dw, stride), np.mod(dh, stride)) if auto else (0.0, 0.0)
|
||||
new_unpad = (new_shape[1], new_shape[0]) if scaleFill else new_unpad
|
||||
dw /= 2
|
||||
dh /= 2
|
||||
image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR) if shape[::-1] != new_unpad else image
|
||||
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
||||
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
|
||||
return image
|
||||
|
||||
def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
auto = same_shapes and model_pt
|
||||
im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im])
|
||||
im = Tensor.stack(im) if im.shape[0] > 1 else im
|
||||
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return im
|
||||
same_shapes = all(x.shape == im[0].shape for x in im)
|
||||
auto = same_shapes and model_pt
|
||||
im = Tensor(
|
||||
[
|
||||
compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride)
|
||||
for x in im
|
||||
]
|
||||
)
|
||||
im = Tensor.stack(im) if im.shape[0] > 1 else im
|
||||
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return im
|
||||
|
||||
|
||||
# Post Processing functions
|
||||
def box_area(box):
|
||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||
return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
|
||||
|
||||
|
||||
def box_iou(box1, box2):
|
||||
lt = np.maximum(box1[:, None, :2], box2[:, :2])
|
||||
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
|
||||
wh = np.clip(rb - lt, 0, None)
|
||||
inter = wh[:, :, 0] * wh[:, :, 1]
|
||||
area1 = box_area(box1)[:, None]
|
||||
area2 = box_area(box2)[None, :]
|
||||
iou = inter / (area1 + area2 - inter)
|
||||
return iou
|
||||
lt = np.maximum(box1[:, None, :2], box2[:, :2])
|
||||
rb = np.minimum(box1[:, None, 2:], box2[:, 2:])
|
||||
wh = np.clip(rb - lt, 0, None)
|
||||
inter = wh[:, :, 0] * wh[:, :, 1]
|
||||
area1 = box_area(box1)[:, None]
|
||||
area2 = box_area(box2)[None, :]
|
||||
iou = inter / (area1 + area2 - inter)
|
||||
return iou
|
||||
|
||||
|
||||
def compute_nms(boxes, scores, iou_threshold):
|
||||
order, keep = scores.argsort()[::-1], []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
if order.size == 1:
|
||||
break
|
||||
iou = box_iou(boxes[i][None, :], boxes[order[1:]])
|
||||
inds = np.where(iou.squeeze() <= iou_threshold)[0]
|
||||
order = order[inds + 1]
|
||||
return np.array(keep)
|
||||
order, keep = scores.argsort()[::-1], []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
if order.size == 1:
|
||||
break
|
||||
iou = box_iou(boxes[i][None, :], boxes[order[1:]])
|
||||
inds = np.where(iou.squeeze() <= iou_threshold)[0]
|
||||
order = order[inds + 1]
|
||||
return np.array(keep)
|
||||
|
||||
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, agnostic=False, max_det=300, nc=0, max_wh=7680):
|
||||
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
|
||||
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
|
||||
xc = np.amax(prediction[:, 4:4 + nc], axis=1) > conf_thres
|
||||
nm = prediction.shape[1] - nc - 4
|
||||
output = [np.zeros((0, 6 + nm))] * bs
|
||||
|
||||
for xi, x in enumerate(prediction):
|
||||
x = x.swapaxes(0, -1)[xc[xi]]
|
||||
if not x.shape[0]: continue
|
||||
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
|
||||
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(cls, axis=1, keepdims=True)
|
||||
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
|
||||
x = x[conf.ravel() > conf_thres]
|
||||
if not x.shape[0]: continue
|
||||
x = x[np.argsort(-x[:, 4])]
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh)
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
i = compute_nms(boxes, scores, iou_thres)[:max_det]
|
||||
output[xi] = x[i]
|
||||
return output
|
||||
def non_max_suppression(
|
||||
prediction,
|
||||
conf_thres=0.25,
|
||||
iou_thres=0.45,
|
||||
agnostic=False,
|
||||
max_det=300,
|
||||
nc=0,
|
||||
max_wh=7680,
|
||||
):
|
||||
prediction = prediction[0] if isinstance(prediction, (list, tuple)) else prediction
|
||||
bs, nc = prediction.shape[0], nc or (prediction.shape[1] - 4)
|
||||
xc = np.amax(prediction[:, 4 : 4 + nc], axis=1) > conf_thres
|
||||
nm = prediction.shape[1] - nc - 4
|
||||
output = [np.zeros((0, 6 + nm))] * bs
|
||||
|
||||
for xi, x in enumerate(prediction):
|
||||
x = x.swapaxes(0, -1)[xc[xi]]
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
box, cls, mask = np.split(x, [4, 4 + nc], axis=1)
|
||||
conf, j = np.max(cls, axis=1, keepdims=True), np.argmax(
|
||||
cls, axis=1, keepdims=True
|
||||
)
|
||||
x = np.concatenate((xywh2xyxy(box), conf, j.astype(np.float32), mask), axis=1)
|
||||
x = x[conf.ravel() > conf_thres]
|
||||
if not x.shape[0]:
|
||||
continue
|
||||
x = x[np.argsort(-x[:, 4])]
|
||||
c = x[:, 5:6] * (0 if agnostic else max_wh)
|
||||
boxes, scores = x[:, :4] + c, x[:, 4]
|
||||
i = compute_nms(boxes, scores, iou_thres)[:max_det]
|
||||
output[xi] = x[i]
|
||||
return output
|
||||
|
||||
|
||||
def postprocess(preds, img, orig_imgs):
|
||||
print('copying to CPU now for post processing')
|
||||
#if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
|
||||
# TODO: make non_max_suppression in tinygrad - to make this faster
|
||||
preds = preds.numpy() if isinstance(preds, Tensor) else preds
|
||||
preds = non_max_suppression(prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300)
|
||||
all_preds = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
if not isinstance(orig_imgs, Tensor):
|
||||
pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
all_preds.append(pred)
|
||||
return all_preds
|
||||
print("copying to CPU now for post processing")
|
||||
# if you are on CPU, this causes an overflow runtime error. doesn't "seem" to make any difference in the predictions though.
|
||||
# TODO: make non_max_suppression in tinygrad - to make this faster
|
||||
preds = preds.numpy() if isinstance(preds, Tensor) else preds
|
||||
preds = non_max_suppression(
|
||||
prediction=preds, conf_thres=0.25, iou_thres=0.7, agnostic=False, max_det=300
|
||||
)
|
||||
all_preds = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
if not isinstance(orig_imgs, Tensor):
|
||||
pred[:, :4] = scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
all_preds.append(pred)
|
||||
return all_preds
|
||||
|
||||
def draw_bounding_boxes_and_save(orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5):
|
||||
color_dict = {label: tuple((((i+1) * 50) % 256, ((i+1) * 100) % 256, ((i+1) * 150) % 256)) for i, label in enumerate(class_labels)}
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
|
||||
def is_bright_color(color):
|
||||
r, g, b = color
|
||||
brightness = (r * 299 + g * 587 + b * 114) / 1000
|
||||
return brightness > 127
|
||||
def draw_bounding_boxes_and_save(
|
||||
orig_img_paths, output_img_paths, all_predictions, class_labels, iou_threshold=0.5
|
||||
):
|
||||
color_dict = {
|
||||
label: tuple(
|
||||
(((i + 1) * 50) % 256, ((i + 1) * 100) % 256, ((i + 1) * 150) % 256)
|
||||
)
|
||||
for i, label in enumerate(class_labels)
|
||||
}
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
|
||||
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(zip(orig_img_paths, output_img_paths, all_predictions)):
|
||||
predictions = np.array(predictions)
|
||||
orig_img = cv2.imread(orig_img_path) if not isinstance(orig_img_path, np.ndarray) else cv2.imdecode(orig_img_path, 1)
|
||||
height, width, _ = orig_img.shape
|
||||
box_thickness = int((height + width) / 400)
|
||||
font_scale = (height + width) / 2500
|
||||
def is_bright_color(color):
|
||||
r, g, b = color
|
||||
brightness = (r * 299 + g * 587 + b * 114) / 1000
|
||||
return brightness > 127
|
||||
|
||||
grouped_preds = defaultdict(list)
|
||||
object_count = defaultdict(int)
|
||||
for img_idx, (orig_img_path, output_img_path, predictions) in enumerate(
|
||||
zip(orig_img_paths, output_img_paths, all_predictions)
|
||||
):
|
||||
predictions = np.array(predictions)
|
||||
orig_img = (
|
||||
cv2.imread(orig_img_path)
|
||||
if not isinstance(orig_img_path, np.ndarray)
|
||||
else cv2.imdecode(orig_img_path, 1)
|
||||
)
|
||||
height, width, _ = orig_img.shape
|
||||
box_thickness = int((height + width) / 400)
|
||||
font_scale = (height + width) / 2500
|
||||
|
||||
for pred_np in predictions:
|
||||
grouped_preds[int(pred_np[-1])].append(pred_np)
|
||||
grouped_preds = defaultdict(list)
|
||||
object_count = defaultdict(int)
|
||||
|
||||
def draw_box_and_label(pred, color):
|
||||
x1, y1, x2, y2, conf, _ = pred
|
||||
x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
|
||||
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
|
||||
label = f"{class_labels[class_id]} {conf:.2f}"
|
||||
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
|
||||
label_y, bg_y = (y1 - 4, y1 - text_size[1] - 4) if y1 - text_size[1] - 4 > 0 else (y1 + text_size[1], y1)
|
||||
cv2.rectangle(orig_img, (x1, bg_y), (x1 + text_size[0], bg_y + text_size[1]), color, -1)
|
||||
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
|
||||
cv2.putText(orig_img, label, (x1, label_y), font, font_scale, font_color, 1, cv2.LINE_AA)
|
||||
for pred_np in predictions:
|
||||
grouped_preds[int(pred_np[-1])].append(pred_np)
|
||||
|
||||
for class_id, pred_list in grouped_preds.items():
|
||||
pred_list = np.array(pred_list)
|
||||
while len(pred_list) > 0:
|
||||
max_conf_idx = np.argmax(pred_list[:, 4])
|
||||
max_conf_pred = pred_list[max_conf_idx]
|
||||
pred_list = np.delete(pred_list, max_conf_idx, axis=0)
|
||||
color = color_dict[class_labels[class_id]]
|
||||
draw_box_and_label(max_conf_pred, color)
|
||||
object_count[class_labels[class_id]] += 1
|
||||
iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4])
|
||||
low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0]
|
||||
pred_list = pred_list[low_iou_indices]
|
||||
for low_conf_pred in pred_list:
|
||||
draw_box_and_label(low_conf_pred, color)
|
||||
def draw_box_and_label(pred, color):
|
||||
x1, y1, x2, y2, conf, _ = pred
|
||||
x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
|
||||
cv2.rectangle(orig_img, (x1, y1), (x2, y2), color, box_thickness)
|
||||
label = f"{class_labels[class_id]} {conf:.2f}"
|
||||
text_size, _ = cv2.getTextSize(label, font, font_scale, 1)
|
||||
label_y, bg_y = (
|
||||
(y1 - 4, y1 - text_size[1] - 4)
|
||||
if y1 - text_size[1] - 4 > 0
|
||||
else (y1 + text_size[1], y1)
|
||||
)
|
||||
cv2.rectangle(
|
||||
orig_img,
|
||||
(x1, bg_y),
|
||||
(x1 + text_size[0], bg_y + text_size[1]),
|
||||
color,
|
||||
-1,
|
||||
)
|
||||
font_color = (0, 0, 0) if is_bright_color(color) else (255, 255, 255)
|
||||
cv2.putText(
|
||||
orig_img,
|
||||
label,
|
||||
(x1, label_y),
|
||||
font,
|
||||
font_scale,
|
||||
font_color,
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
print(f"Image {img_idx + 1}:")
|
||||
print("Objects detected:")
|
||||
for obj, count in object_count.items():
|
||||
print(f"- {obj}: {count}")
|
||||
for class_id, pred_list in grouped_preds.items():
|
||||
pred_list = np.array(pred_list)
|
||||
while len(pred_list) > 0:
|
||||
max_conf_idx = np.argmax(pred_list[:, 4])
|
||||
max_conf_pred = pred_list[max_conf_idx]
|
||||
pred_list = np.delete(pred_list, max_conf_idx, axis=0)
|
||||
color = color_dict[class_labels[class_id]]
|
||||
draw_box_and_label(max_conf_pred, color)
|
||||
object_count[class_labels[class_id]] += 1
|
||||
iou_scores = box_iou(np.array([max_conf_pred[:4]]), pred_list[:, :4])
|
||||
low_iou_indices = np.where(iou_scores[0] < iou_threshold)[0]
|
||||
pred_list = pred_list[low_iou_indices]
|
||||
for low_conf_pred in pred_list:
|
||||
draw_box_and_label(low_conf_pred, color)
|
||||
|
||||
print(f"Image {img_idx + 1}:")
|
||||
print("Objects detected:")
|
||||
for obj, count in object_count.items():
|
||||
print(f"- {obj}: {count}")
|
||||
|
||||
cv2.imwrite(output_img_path, orig_img)
|
||||
print(f"saved detections at {output_img_path}")
|
||||
|
||||
cv2.imwrite(output_img_path, orig_img)
|
||||
print(f'saved detections at {output_img_path}')
|
||||
|
||||
# utility functions for forward pass.
|
||||
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return c_xy.cat(wh, dim=1)
|
||||
return x1y1.cat(x2y2, dim=1)
|
||||
lt, rb = distance.chunk(2, dim)
|
||||
x1y1 = anchor_points - lt
|
||||
x2y2 = anchor_points + rb
|
||||
if xywh:
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
wh = x2y2 - x1y1
|
||||
return c_xy.cat(wh, dim=1)
|
||||
return x1y1.cat(x2y2, dim=1)
|
||||
|
||||
|
||||
def make_anchors(feats, strides, grid_cell_offset=0.5):
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
for i, stride in enumerate(strides):
|
||||
_, _, h, w = feats[i].shape
|
||||
sx = Tensor.arange(w) + grid_cell_offset
|
||||
sy = Tensor.arange(h) + grid_cell_offset
|
||||
anchor_points, stride_tensor = [], []
|
||||
assert feats is not None
|
||||
for i, stride in enumerate(strides):
|
||||
_, _, h, w = feats[i].shape
|
||||
sx = Tensor.arange(w) + grid_cell_offset
|
||||
sy = Tensor.arange(h) + grid_cell_offset
|
||||
|
||||
# this is np.meshgrid but in tinygrad
|
||||
sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1)
|
||||
sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1)
|
||||
# this is np.meshgrid but in tinygrad
|
||||
sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1)
|
||||
sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1)
|
||||
|
||||
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
|
||||
stride_tensor.append(Tensor.full((h * w), stride))
|
||||
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
|
||||
stride_tensor = (
|
||||
stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
||||
)
|
||||
return anchor_points, stride_tensor
|
||||
|
||||
anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
|
||||
stride_tensor.append(Tensor.full((h * w), stride))
|
||||
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
|
||||
stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
|
||||
return anchor_points, stride_tensor
|
||||
|
||||
# this function is from the original implementation
|
||||
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
||||
if d > 1:
|
||||
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
||||
return p
|
||||
if d > 1:
|
||||
k = (
|
||||
d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
|
||||
) # actual kernel-size
|
||||
if p is None:
|
||||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
||||
return p
|
||||
|
||||
|
||||
def clip_boxes(boxes, shape):
|
||||
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
|
||||
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
|
||||
return boxes
|
||||
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
|
||||
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
|
||||
return boxes
|
||||
|
||||
|
||||
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
||||
gain = ratio_pad if ratio_pad else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
||||
pad = ((img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2)
|
||||
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
|
||||
boxes_np[..., [0, 2]] -= pad[0]
|
||||
boxes_np[..., [1, 3]] -= pad[1]
|
||||
boxes_np[..., :4] /= gain
|
||||
boxes_np = clip_boxes(boxes_np, img0_shape)
|
||||
return boxes_np
|
||||
gain = (
|
||||
ratio_pad
|
||||
if ratio_pad
|
||||
else min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
|
||||
)
|
||||
pad = (
|
||||
(img1_shape[1] - img0_shape[1] * gain) / 2,
|
||||
(img1_shape[0] - img0_shape[0] * gain) / 2,
|
||||
)
|
||||
boxes_np = boxes.numpy() if isinstance(boxes, Tensor) else boxes
|
||||
boxes_np[..., [0, 2]] -= pad[0]
|
||||
boxes_np[..., [1, 3]] -= pad[1]
|
||||
boxes_np[..., :4] /= gain
|
||||
boxes_np = clip_boxes(boxes_np, img0_shape)
|
||||
return boxes_np
|
||||
|
||||
|
||||
def xywh2xyxy(x):
|
||||
xy = x[..., :2] # center x, y
|
||||
wh = x[..., 2:4] # width, height
|
||||
xy1 = xy - wh / 2 # top left x, y
|
||||
xy2 = xy + wh / 2 # bottom right x, y
|
||||
result = np.concatenate((xy1, xy2), axis=-1)
|
||||
return Tensor(result) if isinstance(x, Tensor) else result
|
||||
xy = x[..., :2] # center x, y
|
||||
wh = x[..., 2:4] # width, height
|
||||
xy1 = xy - wh / 2 # top left x, y
|
||||
xy2 = xy + wh / 2 # bottom right x, y
|
||||
result = np.concatenate((xy1, xy2), axis=-1)
|
||||
return Tensor(result) if isinstance(x, Tensor) else result
|
||||
|
||||
|
||||
def get_variant_multiples(variant):
|
||||
return {'n':(0.33, 0.25, 2.0), 's':(0.33, 0.50, 2.0), 'm':(0.67, 0.75, 1.5), 'l':(1.0, 1.0, 1.0), 'x':(1, 1.25, 1.0) }.get(variant, None)
|
||||
return {
|
||||
"n": (0.33, 0.25, 2.0),
|
||||
"s": (0.33, 0.50, 2.0),
|
||||
"m": (0.67, 0.75, 1.5),
|
||||
"l": (1.0, 1.0, 1.0),
|
||||
"x": (1, 1.25, 1.0),
|
||||
}.get(variant, None)
|
||||
|
||||
|
||||
def label_predictions(all_predictions):
|
||||
class_index_count = defaultdict(int)
|
||||
for predictions in all_predictions:
|
||||
predictions = np.array(predictions)
|
||||
for pred_np in predictions:
|
||||
class_id = int(pred_np[-1])
|
||||
class_index_count[class_id] += 1
|
||||
class_index_count = defaultdict(int)
|
||||
for predictions in all_predictions:
|
||||
predictions = np.array(predictions)
|
||||
for pred_np in predictions:
|
||||
class_id = int(pred_np[-1])
|
||||
class_index_count[class_id] += 1
|
||||
|
||||
return dict(class_index_count)
|
||||
return dict(class_index_count)
|
||||
|
||||
#this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
|
||||
|
||||
# this is taken from https://github.com/tinygrad/tinygrad/pull/784/files by dc-dc-dc (Now 2 models use upsampling)
|
||||
class Upsample:
|
||||
def __init__(self, scale_factor:int, mode: str = "nearest") -> None:
|
||||
assert mode == "nearest" # only mode supported for now
|
||||
self.mode = mode
|
||||
self.scale_factor = scale_factor
|
||||
def __init__(self, scale_factor: int, mode: str = "nearest") -> None:
|
||||
assert mode == "nearest" # only mode supported for now
|
||||
self.mode = mode
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
assert len(x.shape) > 2 and len(x.shape) <= 5
|
||||
(b, c), _lens = x.shape[:2], len(x.shape[2:])
|
||||
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(
|
||||
*[1, 1, 1] + [self.scale_factor] * _lens
|
||||
)
|
||||
return (
|
||||
tmp.reshape(list(x.shape) + [self.scale_factor] * _lens)
|
||||
.permute(
|
||||
[0, 1]
|
||||
+ list(
|
||||
chain.from_iterable([[y + 2, y + 2 + _lens] for y in range(_lens)])
|
||||
)
|
||||
)
|
||||
.reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
||||
)
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
assert len(x.shape) > 2 and len(x.shape) <= 5
|
||||
(b, c), _lens = x.shape[:2], len(x.shape[2:])
|
||||
tmp = x.reshape([b, c, -1] + [1] * _lens) * Tensor.ones(*[1, 1, 1] + [self.scale_factor] * _lens)
|
||||
return tmp.reshape(list(x.shape) + [self.scale_factor] * _lens).permute([0, 1] + list(chain.from_iterable([[y+2, y+2+_lens] for y in range(_lens)]))).reshape([b, c] + [x * self.scale_factor for x in x.shape[2:]])
|
||||
|
||||
class Conv_Block:
|
||||
def __init__(self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None):
|
||||
self.conv = Conv2d(c1,c2, kernel_size, stride, padding=autopad(kernel_size, padding, dilation), bias=False, groups=groups, dilation=dilation)
|
||||
self.bn = BatchNorm2d(c2, eps=0.001)
|
||||
def __init__(
|
||||
self, c1, c2, kernel_size=1, stride=1, groups=1, dilation=1, padding=None
|
||||
):
|
||||
self.conv = Conv2d(
|
||||
c1,
|
||||
c2,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=autopad(kernel_size, padding, dilation),
|
||||
bias=False,
|
||||
groups=groups,
|
||||
dilation=dilation,
|
||||
)
|
||||
self.bn = BatchNorm2d(c2, eps=0.001)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.bn(self.conv(x)).silu()
|
||||
|
||||
def __call__(self, x):
|
||||
return self.bn(self.conv(x)).silu()
|
||||
|
||||
class Bottleneck:
|
||||
def __init__(self, c1, c2 , shortcut: bool, g=1, kernels: list = (3,3), channel_factor=0.5):
|
||||
c_ = int(c2 * channel_factor)
|
||||
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
|
||||
self.cv2 = Conv_Block(c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g)
|
||||
self.residual = c1 == c2 and shortcut
|
||||
def __init__(
|
||||
self, c1, c2, shortcut: bool, g=1, kernels: list = (3, 3), channel_factor=0.5
|
||||
):
|
||||
c_ = int(c2 * channel_factor)
|
||||
self.cv1 = Conv_Block(c1, c_, kernel_size=kernels[0], stride=1, padding=None)
|
||||
self.cv2 = Conv_Block(
|
||||
c_, c2, kernel_size=kernels[1], stride=1, padding=None, groups=g
|
||||
)
|
||||
self.residual = c1 == c2 and shortcut
|
||||
|
||||
def __call__(self, x):
|
||||
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
|
||||
|
||||
def __call__(self, x):
|
||||
return x + self.cv2(self.cv1(x)) if self.residual else self.cv2(self.cv1(x))
|
||||
|
||||
class C2f:
|
||||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
|
||||
self.c = int(c2 * e)
|
||||
self.cv1 = Conv_Block(c1, 2 * self.c, 1,)
|
||||
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
|
||||
self.bottleneck = [Bottleneck(self.c, self.c, shortcut, g, kernels=[(3, 3), (3, 3)], channel_factor=1.0) for _ in range(n)]
|
||||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
|
||||
self.c = int(c2 * e)
|
||||
self.cv1 = Conv_Block(
|
||||
c1,
|
||||
2 * self.c,
|
||||
1,
|
||||
)
|
||||
self.cv2 = Conv_Block((2 + n) * self.c, c2, 1)
|
||||
self.bottleneck = [
|
||||
Bottleneck(
|
||||
self.c,
|
||||
self.c,
|
||||
shortcut,
|
||||
g,
|
||||
kernels=[(3, 3), (3, 3)],
|
||||
channel_factor=1.0,
|
||||
)
|
||||
for _ in range(n)
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
y = list(self.cv1(x).chunk(2, 1))
|
||||
y.extend(m(y[-1]) for m in self.bottleneck)
|
||||
z = y[0]
|
||||
for i in y[1:]:
|
||||
z = z.cat(i, dim=1)
|
||||
return self.cv2(z)
|
||||
|
||||
def __call__(self, x):
|
||||
y= list(self.cv1(x).chunk(2, 1))
|
||||
y.extend(m(y[-1]) for m in self.bottleneck)
|
||||
z = y[0]
|
||||
for i in y[1:]: z = z.cat(i, dim=1)
|
||||
return self.cv2(z)
|
||||
|
||||
class SPPF:
|
||||
def __init__(self, c1, c2, k=5):
|
||||
c_ = c1 // 2 # hidden channels
|
||||
self.cv1 = Conv_Block(c1, c_, 1, 1, padding=None)
|
||||
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
|
||||
def __init__(self, c1, c2, k=5):
|
||||
c_ = c1 // 2 # hidden channels
|
||||
self.cv1 = Conv_Block(c1, c_, 1, 1, padding=None)
|
||||
self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None)
|
||||
|
||||
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
|
||||
self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1)
|
||||
# TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually.
|
||||
self.maxpool = lambda x: x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(
|
||||
kernel_size=k, stride=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.cv1(x)
|
||||
x2 = self.maxpool(x)
|
||||
x3 = self.maxpool(x2)
|
||||
x4 = self.maxpool(x3)
|
||||
return self.cv2(x.cat(x2, x3, x4, dim=1))
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.cv1(x)
|
||||
x2 = self.maxpool(x)
|
||||
x3 = self.maxpool(x2)
|
||||
x4 = self.maxpool(x3)
|
||||
return self.cv2(x.cat(x2, x3, x4, dim=1))
|
||||
|
||||
class DFL:
|
||||
def __init__(self, c1=16):
|
||||
self.conv = Conv2d(c1, 1, 1, bias=False)
|
||||
x = Tensor.arange(c1)
|
||||
self.conv.weight.assign(x.reshape(1, c1, 1, 1))
|
||||
self.c1 = c1
|
||||
def __init__(self, c1=16):
|
||||
self.conv = Conv2d(c1, 1, 1, bias=False)
|
||||
x = Tensor.arange(c1)
|
||||
self.conv.weight.assign(x.reshape(1, c1, 1, 1))
|
||||
self.c1 = c1
|
||||
|
||||
def __call__(self, x):
|
||||
b, c, a = x.shape # batch, channels, anchors
|
||||
return self.conv(x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)).reshape(b, 4, a)
|
||||
def __call__(self, x):
|
||||
b, c, a = x.shape # batch, channels, anchors
|
||||
return self.conv(
|
||||
x.reshape(b, 4, self.c1, a).transpose(2, 1).softmax(1)
|
||||
).reshape(b, 4, a)
|
||||
|
||||
#backbone
|
||||
|
||||
# backbone
|
||||
class Darknet:
|
||||
def __init__(self, w, r, d):
|
||||
self.b1 = [Conv_Block(c1=3, c2= int(64*w), kernel_size=3, stride=2, padding=1), Conv_Block(int(64*w), int(128*w), kernel_size=3, stride=2, padding=1)]
|
||||
self.b2 = [C2f(c1=int(128*w), c2=int(128*w), n=round(3*d), shortcut=True), Conv_Block(int(128*w), int(256*w), 3, 2, 1), C2f(int(256*w), int(256*w), round(6*d), True)]
|
||||
self.b3 = [Conv_Block(int(256*w), int(512*w), kernel_size=3, stride=2, padding=1), C2f(int(512*w), int(512*w), round(6*d), True)]
|
||||
self.b4 = [Conv_Block(int(512*w), int(512*w*r), kernel_size=3, stride=2, padding=1), C2f(int(512*w*r), int(512*w*r), round(3*d), True)]
|
||||
self.b5 = [SPPF(int(512*w*r), int(512*w*r), 5)]
|
||||
def __init__(self, w, r, d):
|
||||
self.b1 = [
|
||||
Conv_Block(c1=3, c2=int(64 * w), kernel_size=3, stride=2, padding=1),
|
||||
Conv_Block(int(64 * w), int(128 * w), kernel_size=3, stride=2, padding=1),
|
||||
]
|
||||
self.b2 = [
|
||||
C2f(c1=int(128 * w), c2=int(128 * w), n=round(3 * d), shortcut=True),
|
||||
Conv_Block(int(128 * w), int(256 * w), 3, 2, 1),
|
||||
C2f(int(256 * w), int(256 * w), round(6 * d), True),
|
||||
]
|
||||
self.b3 = [
|
||||
Conv_Block(int(256 * w), int(512 * w), kernel_size=3, stride=2, padding=1),
|
||||
C2f(int(512 * w), int(512 * w), round(6 * d), True),
|
||||
]
|
||||
self.b4 = [
|
||||
Conv_Block(
|
||||
int(512 * w), int(512 * w * r), kernel_size=3, stride=2, padding=1
|
||||
),
|
||||
C2f(int(512 * w * r), int(512 * w * r), round(3 * d), True),
|
||||
]
|
||||
self.b5 = [SPPF(int(512 * w * r), int(512 * w * r), 5)]
|
||||
|
||||
def return_modules(self):
|
||||
return [*self.b1, *self.b2, *self.b3, *self.b4, *self.b5]
|
||||
def return_modules(self):
|
||||
return [*self.b1, *self.b2, *self.b3, *self.b4, *self.b5]
|
||||
|
||||
def __call__(self, x):
|
||||
x1 = x.sequential(self.b1)
|
||||
x2 = x1.sequential(self.b2)
|
||||
x3 = x2.sequential(self.b3)
|
||||
x4 = x3.sequential(self.b4)
|
||||
x5 = x4.sequential(self.b5)
|
||||
return (x2, x3, x5)
|
||||
def __call__(self, x):
|
||||
x1 = x.sequential(self.b1)
|
||||
x2 = x1.sequential(self.b2)
|
||||
x3 = x2.sequential(self.b3)
|
||||
x4 = x3.sequential(self.b4)
|
||||
x5 = x4.sequential(self.b5)
|
||||
return (x2, x3, x5)
|
||||
|
||||
#yolo fpn (neck)
|
||||
|
||||
# yolo fpn (neck)
|
||||
class Yolov8NECK:
|
||||
def __init__(self, w, r, d): #width_multiple, ratio_multiple, depth_multiple
|
||||
self.up = Upsample(2, mode='nearest')
|
||||
self.n1 = C2f(c1=int(512*w*(1+r)), c2=int(512*w), n=round(3*d), shortcut=False)
|
||||
self.n2 = C2f(c1=int(768*w), c2=int(256*w), n=round(3*d), shortcut=False)
|
||||
self.n3 = Conv_Block(c1=int(256*w), c2=int(256*w), kernel_size=3, stride=2, padding=1)
|
||||
self.n4 = C2f(c1=int(768*w), c2=int(512*w), n=round(3*d), shortcut=False)
|
||||
self.n5 = Conv_Block(c1=int(512* w), c2=int(512 * w), kernel_size=3, stride=2, padding=1)
|
||||
self.n6 = C2f(c1=int(512*w*(1+r)), c2=int(512*w*r), n=round(3*d), shortcut=False)
|
||||
def __init__(self, w, r, d): # width_multiple, ratio_multiple, depth_multiple
|
||||
self.up = Upsample(2, mode="nearest")
|
||||
self.n1 = C2f(
|
||||
c1=int(512 * w * (1 + r)), c2=int(512 * w), n=round(3 * d), shortcut=False
|
||||
)
|
||||
self.n2 = C2f(c1=int(768 * w), c2=int(256 * w), n=round(3 * d), shortcut=False)
|
||||
self.n3 = Conv_Block(
|
||||
c1=int(256 * w), c2=int(256 * w), kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.n4 = C2f(c1=int(768 * w), c2=int(512 * w), n=round(3 * d), shortcut=False)
|
||||
self.n5 = Conv_Block(
|
||||
c1=int(512 * w), c2=int(512 * w), kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.n6 = C2f(
|
||||
c1=int(512 * w * (1 + r)),
|
||||
c2=int(512 * w * r),
|
||||
n=round(3 * d),
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
def return_modules(self):
|
||||
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
|
||||
def return_modules(self):
|
||||
return [self.n1, self.n2, self.n3, self.n4, self.n5, self.n6]
|
||||
|
||||
def __call__(self, p3, p4, p5):
|
||||
x = self.n1(self.up(p5).cat(p4, dim=1))
|
||||
head_1 = self.n2(self.up(x).cat(p3, dim=1))
|
||||
head_2 = self.n4(self.n3(head_1).cat(x, dim=1))
|
||||
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
|
||||
return [head_1, head_2, head_3]
|
||||
def __call__(self, p3, p4, p5):
|
||||
x = self.n1(self.up(p5).cat(p4, dim=1))
|
||||
head_1 = self.n2(self.up(x).cat(p3, dim=1))
|
||||
head_2 = self.n4(self.n3(head_1).cat(x, dim=1))
|
||||
head_3 = self.n6(self.n5(head_2).cat(p5, dim=1))
|
||||
return [head_1, head_2, head_3]
|
||||
|
||||
#task specific head.
|
||||
|
||||
# task specific head.
|
||||
class DetectionHead:
|
||||
def __init__(self, nc=80, filters=()):
|
||||
self.ch = 16
|
||||
self.nc = nc # number of classes
|
||||
self.nl = len(filters)
|
||||
self.no = nc + self.ch * 4 #
|
||||
self.stride = [8, 16, 32]
|
||||
c1 = max(filters[0], self.nc)
|
||||
c2 = max((filters[0] // 4, self.ch * 4))
|
||||
self.dfl = DFL(self.ch)
|
||||
self.cv3 = [[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)] for x in filters]
|
||||
self.cv2 = [[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)] for x in filters]
|
||||
def __init__(self, nc=80, filters=()):
|
||||
self.ch = 16
|
||||
self.nc = nc # number of classes
|
||||
self.nl = len(filters)
|
||||
self.no = nc + self.ch * 4 #
|
||||
self.stride = [8, 16, 32]
|
||||
c1 = max(filters[0], self.nc)
|
||||
c2 = max((filters[0] // 4, self.ch * 4))
|
||||
self.dfl = DFL(self.ch)
|
||||
self.cv3 = [
|
||||
[Conv_Block(x, c1, 3), Conv_Block(c1, c1, 3), Conv2d(c1, self.nc, 1)]
|
||||
for x in filters
|
||||
]
|
||||
self.cv2 = [
|
||||
[Conv_Block(x, c2, 3), Conv_Block(c2, c2, 3), Conv2d(c2, 4 * self.ch, 1)]
|
||||
for x in filters
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
for i in range(self.nl):
|
||||
x[i] = x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1)
|
||||
self.anchors, self.strides = (
|
||||
x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)
|
||||
)
|
||||
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
|
||||
x_cat = y[0].cat(y[1], y[2], dim=2)
|
||||
box, cls = x_cat[:, : self.ch * 4], x_cat[:, self.ch * 4 :]
|
||||
dbox = (
|
||||
dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1)
|
||||
* self.strides
|
||||
)
|
||||
z = dbox.cat(cls.sigmoid(), dim=1)
|
||||
return z
|
||||
|
||||
def __call__(self, x):
|
||||
for i in range(self.nl):
|
||||
x[i] = (x[i].sequential(self.cv2[i]).cat(x[i].sequential(self.cv3[i]), dim=1))
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
|
||||
x_cat = y[0].cat(y[1], y[2], dim=2)
|
||||
box, cls = x_cat[:, :self.ch * 4], x_cat[:, self.ch * 4:]
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
z = dbox.cat(cls.sigmoid(), dim=1)
|
||||
return z
|
||||
|
||||
class YOLOv8:
|
||||
def __init__(self, w, r, d, num_classes): #width_multiple, ratio_multiple, depth_multiple
|
||||
self.net = Darknet(w, r, d)
|
||||
self.fpn = Yolov8NECK(w, r, d)
|
||||
self.head = DetectionHead(num_classes, filters=(int(256*w), int(512*w), int(512*w*r)))
|
||||
def __init__(
|
||||
self, w, r, d, num_classes
|
||||
): # width_multiple, ratio_multiple, depth_multiple
|
||||
self.net = Darknet(w, r, d)
|
||||
self.fpn = Yolov8NECK(w, r, d)
|
||||
self.head = DetectionHead(
|
||||
num_classes, filters=(int(256 * w), int(512 * w), int(512 * w * r))
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.net(x)
|
||||
x = self.fpn(*x)
|
||||
return self.head(x)
|
||||
def __call__(self, x):
|
||||
x = self.net(x)
|
||||
x = self.fpn(*x)
|
||||
return self.head(x)
|
||||
|
||||
def return_all_trainable_modules(self):
|
||||
backbone_modules = [*range(10)]
|
||||
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
|
||||
yolov8_head_weights = [(22, self.head)]
|
||||
return [*zip(backbone_modules, self.net.return_modules()), *zip(yolov8neck_modules, self.fpn.return_modules()), *yolov8_head_weights]
|
||||
def return_all_trainable_modules(self):
|
||||
backbone_modules = [*range(10)]
|
||||
yolov8neck_modules = [12, 15, 16, 18, 19, 21]
|
||||
yolov8_head_weights = [(22, self.head)]
|
||||
return [
|
||||
*zip(backbone_modules, self.net.return_modules()),
|
||||
*zip(yolov8neck_modules, self.fpn.return_modules()),
|
||||
*yolov8_head_weights,
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
|
||||
if len(sys.argv) < 2:
|
||||
print("Error: Image URL or path not provided.")
|
||||
sys.exit(1)
|
||||
if __name__ == "__main__":
|
||||
# usage : python3 yolov8.py "image_URL OR image_path" "v8 variant" (optional, n is default)
|
||||
if len(sys.argv) < 2:
|
||||
print("Error: Image URL or path not provided.")
|
||||
sys.exit(1)
|
||||
|
||||
img_path = sys.argv[1]
|
||||
yolo_variant = sys.argv[2] if len(sys.argv) >= 3 else (print("No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']") or 'n')
|
||||
print(f'running inference for YOLO version {yolo_variant}')
|
||||
img_path = sys.argv[1]
|
||||
yolo_variant = (
|
||||
sys.argv[2]
|
||||
if len(sys.argv) >= 3
|
||||
else (
|
||||
print(
|
||||
"No variant given, so choosing 'n' as the default. Yolov8 has different variants, you can choose from ['n', 's', 'm', 'l', 'x']"
|
||||
)
|
||||
or "n"
|
||||
)
|
||||
)
|
||||
print(f"running inference for YOLO version {yolo_variant}")
|
||||
|
||||
output_folder_path = Path('./outputs_yolov8')
|
||||
output_folder_path.mkdir(parents=True, exist_ok=True)
|
||||
#absolute image path or URL
|
||||
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
|
||||
image = [cv2.imdecode(image_location[0], 1)]
|
||||
out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()]
|
||||
if not isinstance(image[0], np.ndarray):
|
||||
print('Error in image loading. Check your image file.')
|
||||
sys.exit(1)
|
||||
pre_processed_image = preprocess(image)
|
||||
output_folder_path = Path("./outputs_yolov8")
|
||||
output_folder_path.mkdir(parents=True, exist_ok=True)
|
||||
# absolute image path or URL
|
||||
image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)]
|
||||
image = [cv2.imdecode(image_location[0], 1)]
|
||||
out_paths = [
|
||||
(
|
||||
output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}"
|
||||
).as_posix()
|
||||
]
|
||||
if not isinstance(image[0], np.ndarray):
|
||||
print("Error in image loading. Check your image file.")
|
||||
sys.exit(1)
|
||||
pre_processed_image = preprocess(image)
|
||||
|
||||
# Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/v8/yolov8.yaml
|
||||
depth, width, ratio = get_variant_multiples(yolo_variant)
|
||||
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
# Different YOLOv8 variants use different w , r, and d multiples. For a list , refer to this yaml file (the scales section) https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/v8/yolov8.yaml
|
||||
depth, width, ratio = get_variant_multiples(yolo_variant)
|
||||
yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
|
||||
state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors'))
|
||||
load_state_dict(yolo_infer, state_dict)
|
||||
state_dict = safe_load(
|
||||
fetch(
|
||||
f"https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors"
|
||||
)
|
||||
)
|
||||
load_state_dict(yolo_infer, state_dict)
|
||||
|
||||
st = time.time()
|
||||
predictions = yolo_infer(pre_processed_image)
|
||||
print(f'did inference in {int(round(((time.time() - st) * 1000)))}ms')
|
||||
st = time.time()
|
||||
predictions = yolo_infer(pre_processed_image)
|
||||
print(f"did inference in {int(round(((time.time() - st) * 1000)))}ms")
|
||||
|
||||
post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image)
|
||||
post_predictions = postprocess(
|
||||
preds=predictions, img=pre_processed_image, orig_imgs=image
|
||||
)
|
||||
|
||||
#v8 and v3 have same 80 class names for Object Detection
|
||||
class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n")
|
||||
# v8 and v3 have same 80 class names for Object Detection
|
||||
class_labels = (
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
|
||||
)
|
||||
.read_text()
|
||||
.split("\n")
|
||||
)
|
||||
|
||||
draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels)
|
||||
draw_bounding_boxes_and_save(
|
||||
orig_img_paths=image_location,
|
||||
output_img_paths=out_paths,
|
||||
all_predictions=post_predictions,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# TODO for later:
|
||||
# 1. Fix SPPF minor difference due to maxpool
|
||||
# 2. AST exp overflow warning while on cpu
|
||||
# 3. Make NMS faster
|
||||
# 4. Add video inference and webcam support
|
||||
# 4. Add video inference and webcam support
|
||||
|
|
|
@ -6,25 +6,32 @@ from coremltools.models.neural_network import datatypes, NeuralNetworkBuilder
|
|||
# KxK GEMM with bias
|
||||
K = 64
|
||||
|
||||
input_features = [('image', datatypes.Array(K))]
|
||||
input_features2 = [('image2', datatypes.Array(K))]
|
||||
output_features = [('probs', datatypes.Array(K))]
|
||||
input_features = [("image", datatypes.Array(K))]
|
||||
input_features2 = [("image2", datatypes.Array(K))]
|
||||
output_features = [("probs", datatypes.Array(K))]
|
||||
|
||||
weights = np.zeros((K, K)) + 3
|
||||
bias = np.ones(K)
|
||||
|
||||
builder = NeuralNetworkBuilder(input_features+input_features2, output_features)
|
||||
builder = NeuralNetworkBuilder(input_features + input_features2, output_features)
|
||||
|
||||
#builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
|
||||
#builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
|
||||
builder.add_elementwise(name='element', input_names=['image', 'image2'], output_name='probs', mode='ADD')
|
||||
#builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
|
||||
#builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
|
||||
# builder.add_inner_product(name='ip_layer', W=weights, b=None, input_channels=K, output_channels=K, has_bias=False, input_name='image', output_name='med')
|
||||
# builder.add_inner_product(name='ip_layer_2', W=weights, b=None, input_channels=3, output_channels=3, has_bias=False, input_name='med', output_name='probs')
|
||||
builder.add_elementwise(
|
||||
name="element", input_names=["image", "image2"], output_name="probs", mode="ADD"
|
||||
)
|
||||
# builder.add_bias(name='bias', b=bias, input_name='med', output_name='probs', shape_bias=(K,))
|
||||
# builder.add_activation(name='act_layer', non_linearity='SIGMOID', input_name='med', output_name='probs')
|
||||
|
||||
# compile the spec
|
||||
mlmodel = ct.models.MLModel(builder.spec)
|
||||
|
||||
# trigger the ANE!
|
||||
out = mlmodel.predict({"image": np.zeros(K, dtype=np.float32)+1, "image2": np.zeros(K, dtype=np.float32)+2})
|
||||
out = mlmodel.predict(
|
||||
{
|
||||
"image": np.zeros(K, dtype=np.float32) + 1,
|
||||
"image2": np.zeros(K, dtype=np.float32) + 2,
|
||||
}
|
||||
)
|
||||
print(out)
|
||||
mlmodel.save('test.mlmodel')
|
||||
mlmodel.save("test.mlmodel")
|
||||
|
|
|
@ -5,13 +5,13 @@ import networkx as nx
|
|||
import pylab as plt
|
||||
from networkx.drawing.nx_pydot import read_dot
|
||||
|
||||
ret = os.system("./a.out "+sys.argv[1]+" debug")
|
||||
assert(ret == 0)
|
||||
ret = os.system("./a.out " + sys.argv[1] + " debug")
|
||||
assert ret == 0
|
||||
|
||||
df = "debug/model.hwx.zinir_graph_after_reg_spill.dot"
|
||||
|
||||
#from graphviz import render
|
||||
#render('dot', 'png', df)
|
||||
# from graphviz import render
|
||||
# render('dot', 'png', df)
|
||||
|
||||
#plt = Image(pdot.create_png()
|
||||
#display(plt)
|
||||
# plt = Image(pdot.create_png()
|
||||
# display(plt)
|
||||
|
|
|
@ -3,138 +3,155 @@ import sys
|
|||
from hexdump import hexdump
|
||||
from macholib import MachO
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
def get_macho(fn):
|
||||
# mod to make the header okay
|
||||
# MH_CIGAM_64 is good
|
||||
dat = open(fn, "rb").read()
|
||||
dat = b"\xcf\xfa\xed\xfe"+dat[4:]
|
||||
from tempfile import NamedTemporaryFile
|
||||
with NamedTemporaryFile(delete=False) as f:
|
||||
f.write(dat)
|
||||
f.close()
|
||||
return MachO.MachO(f.name)
|
||||
# mod to make the header okay
|
||||
# MH_CIGAM_64 is good
|
||||
dat = open(fn, "rb").read()
|
||||
dat = b"\xcf\xfa\xed\xfe" + dat[4:]
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
with NamedTemporaryFile(delete=False) as f:
|
||||
f.write(dat)
|
||||
f.close()
|
||||
return MachO.MachO(f.name)
|
||||
|
||||
|
||||
a = get_macho("model.hwx.golden")
|
||||
|
||||
# load commands
|
||||
for c in a.headers[0].commands:
|
||||
print("command", c[0], c[1])
|
||||
if c[0].cmd == 4:
|
||||
hexdump(c[2])
|
||||
pass
|
||||
if c[0].cmd == 6:
|
||||
print("name:", c[2].decode('utf-8'))
|
||||
if c[0].cmd == 8:
|
||||
print(c[2].decode('utf-8'))
|
||||
if c[0].cmd == 25:
|
||||
for section in c[2]:
|
||||
print(section.segname.strip(b'\0'), section.sectname.strip(b'\0'), hex(section.addr), hex(section.size), "@", hex(c[1].fileoff))
|
||||
#print(dir(section))
|
||||
if c[1].filesize > 0:
|
||||
if len(section.section_data) < 0x100:
|
||||
hexdump(section.section_data)
|
||||
else:
|
||||
print("in file, not dumping 0x%x" % len(section.section_data))
|
||||
print("command", c[0], c[1])
|
||||
if c[0].cmd == 4:
|
||||
hexdump(c[2])
|
||||
pass
|
||||
if c[0].cmd == 6:
|
||||
print("name:", c[2].decode("utf-8"))
|
||||
if c[0].cmd == 8:
|
||||
print(c[2].decode("utf-8"))
|
||||
if c[0].cmd == 25:
|
||||
for section in c[2]:
|
||||
print(
|
||||
section.segname.strip(b"\0"),
|
||||
section.sectname.strip(b"\0"),
|
||||
hex(section.addr),
|
||||
hex(section.size),
|
||||
"@",
|
||||
hex(c[1].fileoff),
|
||||
)
|
||||
# print(dir(section))
|
||||
if c[1].filesize > 0:
|
||||
if len(section.section_data) < 0x100:
|
||||
hexdump(section.section_data)
|
||||
else:
|
||||
print("in file, not dumping 0x%x" % len(section.section_data))
|
||||
|
||||
# this parser is wrong (fixed with 64-bit one)
|
||||
from macholib import SymbolTable
|
||||
|
||||
sym = SymbolTable.SymbolTable(a)
|
||||
|
||||
syms = {}
|
||||
for l in sym.nlists:
|
||||
print(l)
|
||||
if l[0].n_value != 0:
|
||||
syms[l[1]] = l[0].n_value
|
||||
print(l)
|
||||
if l[0].n_value != 0:
|
||||
syms[l[1]] = l[0].n_value
|
||||
|
||||
for k,v in syms.items():
|
||||
print(k, hex(v))
|
||||
for k, v in syms.items():
|
||||
print(k, hex(v))
|
||||
|
||||
|
||||
# **** document what we know ***
|
||||
from ane import ANE_Struct, ANE
|
||||
|
||||
ane = ANE()
|
||||
|
||||
aneb = set()
|
||||
for typ, num, nam in ANE_Struct:
|
||||
ltyp = {"u32": 4, "u16": 2, "u8": 1}[typ]
|
||||
for l in range(num, num+ltyp):
|
||||
aneb.add(l)
|
||||
ltyp = {"u32": 4, "u16": 2, "u8": 1}[typ]
|
||||
for l in range(num, num + ltyp):
|
||||
aneb.add(l)
|
||||
|
||||
# we understand these too
|
||||
for l in range(0x34, 0xF4):
|
||||
aneb.add(l)
|
||||
aneb.add(l)
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def compare(x, y):
|
||||
ss = []
|
||||
ln = []
|
||||
ln2 = []
|
||||
ss = []
|
||||
ln = []
|
||||
ln2 = []
|
||||
|
||||
ll = (max(len(x), len(y)) + 0xF)//0x10 * 0x10
|
||||
ll = (max(len(x), len(y)) + 0xF) // 0x10 * 0x10
|
||||
|
||||
highlight = False
|
||||
next_highlight = 0x2b
|
||||
for i in range(ll+1):
|
||||
if i == next_highlight:
|
||||
highlight = True
|
||||
if i < len(y):
|
||||
next_highlight += y[i]+8
|
||||
else:
|
||||
next_highlight = None
|
||||
else:
|
||||
highlight = False
|
||||
a = "%02X" % x[i] if i < len(x) else "--", \
|
||||
"%02X" % y[i] if i < len(y) else "--"
|
||||
def fj(x):
|
||||
ss = []
|
||||
for i in range(0, 0x10, 4):
|
||||
ss.append(' '.join(x[i:i+4]))
|
||||
return ' '.join(ss)
|
||||
|
||||
if i!=0 and i%0x10 == 0:
|
||||
ss.append("%8X: " % (i-0x10)+fj(ln)+" | "+fj(ln2)+"\n")
|
||||
ln = []
|
||||
ln2 = []
|
||||
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
|
||||
ln.append(colored(a[0], 'green'))
|
||||
ln2.append(colored(a[1], 'red'))
|
||||
else:
|
||||
if highlight:
|
||||
ln.append(colored(a[0], 'yellow'))
|
||||
ln2.append(colored(a[1], 'yellow'))
|
||||
else:
|
||||
if i in aneb:
|
||||
ln.append(colored(a[0], 'white'))
|
||||
ln2.append(colored(a[1], 'white'))
|
||||
highlight = False
|
||||
next_highlight = 0x2B
|
||||
for i in range(ll + 1):
|
||||
if i == next_highlight:
|
||||
highlight = True
|
||||
if i < len(y):
|
||||
next_highlight += y[i] + 8
|
||||
else:
|
||||
next_highlight = None
|
||||
else:
|
||||
ln.append(a[0])
|
||||
ln2.append(a[1])
|
||||
return ''.join(ss)
|
||||
highlight = False
|
||||
a = "%02X" % x[i] if i < len(x) else "--", "%02X" % y[i] if i < len(y) else "--"
|
||||
|
||||
def fj(x):
|
||||
ss = []
|
||||
for i in range(0, 0x10, 4):
|
||||
ss.append(" ".join(x[i : i + 4]))
|
||||
return " ".join(ss)
|
||||
|
||||
if i != 0 and i % 0x10 == 0:
|
||||
ss.append("%8X: " % (i - 0x10) + fj(ln) + " | " + fj(ln2) + "\n")
|
||||
ln = []
|
||||
ln2 = []
|
||||
if a[0] != a[1] and a[0] != "--" and a[1] != "--":
|
||||
ln.append(colored(a[0], "green"))
|
||||
ln2.append(colored(a[1], "red"))
|
||||
else:
|
||||
if highlight:
|
||||
ln.append(colored(a[0], "yellow"))
|
||||
ln2.append(colored(a[1], "yellow"))
|
||||
else:
|
||||
if i in aneb:
|
||||
ln.append(colored(a[0], "white"))
|
||||
ln2.append(colored(a[1], "white"))
|
||||
else:
|
||||
ln.append(a[0])
|
||||
ln2.append(a[1])
|
||||
return "".join(ss)
|
||||
|
||||
|
||||
import json
|
||||
|
||||
aneregs = dict(json.load(open("aneregs.json")))
|
||||
g = get_macho("model.hwx.golden" if len(sys.argv) < 2 else sys.argv[1])
|
||||
f1 = g.headers[0].commands[1][2][0].section_data
|
||||
f2 = a.headers[0].commands[1][2][0].section_data
|
||||
for i in range(0, len(f2), 0x300):
|
||||
print("===== op %d =====" % (i//0x300))
|
||||
if len(f1) < 0x300:
|
||||
c1, c2 = f1, f2[i:i+0x300]
|
||||
else:
|
||||
c1, c2 = f1[i:i+0x300], f2[i:i+0x300]
|
||||
dbg1 = ane.debug(c1, 16)
|
||||
dbg2 = ane.debug(c2, 16)
|
||||
if getenv("PRINTALL"):
|
||||
for k in dbg2:
|
||||
if k in aneregs:
|
||||
rr = aneregs[k] if k in aneregs else (-1,-1,-1)
|
||||
print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k])
|
||||
else:
|
||||
for k in dbg1:
|
||||
if dbg1[k] != dbg2[k]:
|
||||
rr = aneregs[k] if k in aneregs else (-1,-1,-1)
|
||||
print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k])
|
||||
print("===== op %d =====" % (i // 0x300))
|
||||
if len(f1) < 0x300:
|
||||
c1, c2 = f1, f2[i : i + 0x300]
|
||||
else:
|
||||
c1, c2 = f1[i : i + 0x300], f2[i : i + 0x300]
|
||||
dbg1 = ane.debug(c1, 16)
|
||||
dbg2 = ane.debug(c2, 16)
|
||||
if getenv("PRINTALL"):
|
||||
for k in dbg2:
|
||||
if k in aneregs:
|
||||
rr = aneregs[k] if k in aneregs else (-1, -1, -1)
|
||||
print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k])
|
||||
else:
|
||||
for k in dbg1:
|
||||
if dbg1[k] != dbg2[k]:
|
||||
rr = aneregs[k] if k in aneregs else (-1, -1, -1)
|
||||
print("0x%3x %d %2d" % tuple(rr), k, dbg1[k], "->", dbg2[k])
|
||||
|
||||
print(compare(c1, c2))
|
||||
#open("/tmp/data.section", "wb").write(f2)
|
||||
#print(compare(open("model.hwx.golden", "rb").read(), open("model.hwx", "rb").read()))
|
||||
print(compare(c1, c2))
|
||||
# open("/tmp/data.section", "wb").write(f2)
|
||||
# print(compare(open("model.hwx.golden", "rb").read(), open("model.hwx", "rb").read()))
|
||||
|
|
|
@ -1,36 +1,37 @@
|
|||
#!/usr/bin/env python3
|
||||
from ane import ANE
|
||||
|
||||
ane = ANE()
|
||||
|
||||
lens = {}
|
||||
|
||||
dat = b"\xff"*0x300
|
||||
dat = b"\xff" * 0x300
|
||||
ret = ane.debug(dat, 16)
|
||||
for k,v in ret.items():
|
||||
found = None
|
||||
for i in range(33):
|
||||
#print(v, (1 << i) - 1)
|
||||
if v == (1 << i) - 1:
|
||||
found = i
|
||||
break
|
||||
#print(k, hex(v), found)
|
||||
lens[k] = found
|
||||
for k, v in ret.items():
|
||||
found = None
|
||||
for i in range(33):
|
||||
# print(v, (1 << i) - 1)
|
||||
if v == (1 << i) - 1:
|
||||
found = i
|
||||
break
|
||||
# print(k, hex(v), found)
|
||||
lens[k] = found
|
||||
|
||||
pos = []
|
||||
dat = b"\x00"*0x300
|
||||
dat = b"\x00" * 0x300
|
||||
for i in range(0x300):
|
||||
for j in range(8):
|
||||
dat = b"\x00"*i
|
||||
dat += bytes([1 << j])
|
||||
dat += b"\x00"*(0x300-len(dat))
|
||||
ret = ane.debug(dat, 16)
|
||||
for k,v in ret.items():
|
||||
if v == 1:
|
||||
print("0x%3x %d %2d" % (i, j, lens[k]), k)
|
||||
pos.append((k, (i,j, lens[k])))
|
||||
for j in range(8):
|
||||
dat = b"\x00" * i
|
||||
dat += bytes([1 << j])
|
||||
dat += b"\x00" * (0x300 - len(dat))
|
||||
ret = ane.debug(dat, 16)
|
||||
for k, v in ret.items():
|
||||
if v == 1:
|
||||
print("0x%3x %d %2d" % (i, j, lens[k]), k)
|
||||
pos.append((k, (i, j, lens[k])))
|
||||
|
||||
import json
|
||||
|
||||
jpos = json.dumps(pos, indent=2)
|
||||
with open("aneregs.json", "w") as f:
|
||||
f.write(jpos)
|
||||
|
||||
f.write(jpos)
|
||||
|
|
|
@ -2,15 +2,18 @@ import ctypes
|
|||
from subprocess import check_output
|
||||
from hexdump import hexdump
|
||||
|
||||
|
||||
def get_pid(name):
|
||||
try:
|
||||
output = check_output(["pgrep", name])
|
||||
return int(output)
|
||||
except:
|
||||
return None
|
||||
try:
|
||||
output = check_output(["pgrep", name])
|
||||
return int(output)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
from ctypes.util import find_library
|
||||
libc = ctypes.CDLL(find_library('c'))
|
||||
|
||||
libc = ctypes.CDLL(find_library("c"))
|
||||
|
||||
amfid_pid = get_pid("amfid")
|
||||
|
||||
|
@ -19,25 +22,28 @@ mytask = libc.mach_task_self()
|
|||
ret = libc.task_for_pid(mytask, ctypes.c_int(amfid_pid), ctypes.pointer(task))
|
||||
print(amfid_pid, ret, task, mytask)
|
||||
|
||||
#myport = libc.mach_task_self()
|
||||
# myport = libc.mach_task_self()
|
||||
|
||||
|
||||
class vm_region_submap_short_info_data_64(ctypes.Structure):
|
||||
_pack_ = 1
|
||||
_fields_ = [
|
||||
("protection", ctypes.c_uint32),
|
||||
("max_protection", ctypes.c_uint32),
|
||||
("inheritance", ctypes.c_uint32),
|
||||
("offset", ctypes.c_ulonglong),
|
||||
("user_tag", ctypes.c_uint32),
|
||||
("ref_count", ctypes.c_uint32),
|
||||
("shadow_depth", ctypes.c_uint16),
|
||||
("external_pager", ctypes.c_byte),
|
||||
("share_mode", ctypes.c_byte),
|
||||
("is_submap", ctypes.c_uint32),
|
||||
("behavior", ctypes.c_uint32),
|
||||
("object_id", ctypes.c_uint32),
|
||||
("user_wired_count", ctypes.c_uint32),
|
||||
]
|
||||
_pack_ = 1
|
||||
_fields_ = [
|
||||
("protection", ctypes.c_uint32),
|
||||
("max_protection", ctypes.c_uint32),
|
||||
("inheritance", ctypes.c_uint32),
|
||||
("offset", ctypes.c_ulonglong),
|
||||
("user_tag", ctypes.c_uint32),
|
||||
("ref_count", ctypes.c_uint32),
|
||||
("shadow_depth", ctypes.c_uint16),
|
||||
("external_pager", ctypes.c_byte),
|
||||
("share_mode", ctypes.c_byte),
|
||||
("is_submap", ctypes.c_uint32),
|
||||
("behavior", ctypes.c_uint32),
|
||||
("object_id", ctypes.c_uint32),
|
||||
("user_wired_count", ctypes.c_uint32),
|
||||
]
|
||||
|
||||
|
||||
submap_info_size = ctypes.sizeof(vm_region_submap_short_info_data_64) // 4
|
||||
|
||||
address = ctypes.c_ulong(0)
|
||||
|
@ -48,27 +54,37 @@ depth = 0
|
|||
|
||||
c_depth = ctypes.c_uint32(depth)
|
||||
for i in range(1):
|
||||
ret = libc.mach_vm_region_recurse(task,
|
||||
ctypes.pointer(address), ctypes.pointer(mapsize),
|
||||
ctypes.pointer(c_depth), ctypes.pointer(sub_info),
|
||||
ctypes.pointer(count))
|
||||
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
|
||||
#address.value += mapsize.value
|
||||
#exit(0)
|
||||
ret = libc.mach_vm_region_recurse(
|
||||
task,
|
||||
ctypes.pointer(address),
|
||||
ctypes.pointer(mapsize),
|
||||
ctypes.pointer(c_depth),
|
||||
ctypes.pointer(sub_info),
|
||||
ctypes.pointer(count),
|
||||
)
|
||||
print("aslr", hex(ret), hex(address.value), mapsize, count, sub_info.protection)
|
||||
# address.value += mapsize.value
|
||||
# exit(0)
|
||||
|
||||
patch_address = address.value + 0x8e38
|
||||
patch_address = address.value + 0x8E38
|
||||
patch = b"\x00\x00\x80\xd2"
|
||||
|
||||
pdata = ctypes.c_void_p(0)
|
||||
data_cnt = ctypes.c_uint32(0)
|
||||
|
||||
ret = libc.mach_vm_read(task, ctypes.c_ulong(patch_address), 4, ctypes.pointer(pdata), ctypes.pointer(data_cnt))
|
||||
ret = libc.mach_vm_read(
|
||||
task,
|
||||
ctypes.c_ulong(patch_address),
|
||||
4,
|
||||
ctypes.pointer(pdata),
|
||||
ctypes.pointer(data_cnt),
|
||||
)
|
||||
buf = ctypes.string_at(pdata.value, data_cnt.value)
|
||||
hexdump(buf)
|
||||
|
||||
#ret = libc.mach_vm_wire(mytask, task, patch_address, 4, 3)
|
||||
#print(ret)
|
||||
#exit(0)
|
||||
# ret = libc.mach_vm_wire(mytask, task, patch_address, 4, 3)
|
||||
# print(ret)
|
||||
# exit(0)
|
||||
|
||||
"""
|
||||
ret = libc.mach_vm_read(task, address, mapsize, ctypes.pointer(pdata), ctypes.pointer(data_cnt))
|
||||
|
@ -86,17 +102,17 @@ ret = libc.mach_vm_protect(task, ctypes.c_ulong(patch_address), 4, True, 3)
|
|||
print("protect", ret)
|
||||
|
||||
longptr = ctypes.POINTER(ctypes.c_ulong)
|
||||
#shellcodePtr = ctypes.cast(buf, longptr)
|
||||
#ret = libc.mach_vm_write(task, address, shellcodePtr, len(buf))
|
||||
#print("write", ret)
|
||||
# shellcodePtr = ctypes.cast(buf, longptr)
|
||||
# ret = libc.mach_vm_write(task, address, shellcodePtr, len(buf))
|
||||
# print("write", ret)
|
||||
|
||||
shellcodePtr = ctypes.cast(patch, longptr)
|
||||
ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(buf))
|
||||
print("write", ret)
|
||||
|
||||
#libc.mach_vm_write.argtypes = [ctypes.c_uint32, ctypes.c_ulong, longptr, ctypes.c_uint32]
|
||||
#libc.mach_vm_write.restype = ctypes.c_uint32
|
||||
#ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(patch))
|
||||
# libc.mach_vm_write.argtypes = [ctypes.c_uint32, ctypes.c_ulong, longptr, ctypes.c_uint32]
|
||||
# libc.mach_vm_write.restype = ctypes.c_uint32
|
||||
# ret = libc.mach_vm_write(task, ctypes.c_ulong(patch_address), shellcodePtr, len(patch))
|
||||
|
||||
ret = libc.mach_vm_protect(task, ctypes.c_ulong(patch_address), 4, False, 5)
|
||||
print("protect", ret)
|
||||
print("protect", ret)
|
||||
|
|
|
@ -6,217 +6,214 @@ import collections
|
|||
import numpy as np
|
||||
import faulthandler
|
||||
import struct
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
basedir = Path(__file__).resolve().parent
|
||||
|
||||
libane = None
|
||||
aneregs = None
|
||||
|
||||
|
||||
def init_libane():
|
||||
global libane, aneregs
|
||||
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
|
||||
global libane, aneregs
|
||||
libane = cdll.LoadLibrary((basedir / "libane.dylib").as_posix())
|
||||
|
||||
libane.ANE_Compile.argtypes = [c_char_p, c_int]
|
||||
libane.ANE_Compile.restype = c_void_p
|
||||
libane.ANE_Compile.argtypes = [c_char_p, c_int]
|
||||
libane.ANE_Compile.restype = c_void_p
|
||||
|
||||
libane.ANE_TensorCreate.restype = c_void_p
|
||||
libane.ANE_TensorCreate.restype = c_void_p
|
||||
|
||||
libane.ANE_TensorData.argtypes = [c_void_p]
|
||||
libane.ANE_TensorData.restype = POINTER(c_uint16)
|
||||
libane.ANE_TensorData.argtypes = [c_void_p]
|
||||
libane.ANE_TensorData.restype = POINTER(c_uint16)
|
||||
|
||||
libane.ANE_Run.argtypes = [c_void_p]*4
|
||||
libane.ANE_Run.restype = c_int
|
||||
libane.ANE_Run.argtypes = [c_void_p] * 4
|
||||
libane.ANE_Run.restype = c_int
|
||||
|
||||
#libane.ANE_RegDebug.restype = c_char_p
|
||||
# libane.ANE_RegDebug.restype = c_char_p
|
||||
|
||||
with open(basedir / "aneregs.json") as f:
|
||||
aneregs = json.load(f)
|
||||
|
||||
with open(basedir / "aneregs.json") as f:
|
||||
aneregs = json.load(f)
|
||||
|
||||
ANE_Struct = [
|
||||
# aneTD.Header
|
||||
("u32", 0x1C, "NextCommandOffset"),
|
||||
|
||||
# KernelDMASrc @ section @ 0x2C len 0xF4
|
||||
# reloc 0x2c-0x34?? = weights
|
||||
# u32[16] 0x34-0x74 = 0x80 | 1 if used
|
||||
# u32[16] 0x74-0xB4 = <channel data offset>
|
||||
# u32[16] 0xB4-0xF4 = <channel data length>
|
||||
|
||||
# Common @ section @ 0x128 len 0x3C (conv)
|
||||
("u16", 0x128, "InputWidth"),
|
||||
("u16", 0x12A, "InputHeight"),
|
||||
("u16", 0x12C, "InputDepth"),
|
||||
|
||||
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
|
||||
# UInt8 = 0, Int8 = 1, Float16 = 2
|
||||
|
||||
("u32", 0x134, "InputChannels"),
|
||||
("u32", 0x138, "OutputChannels"),
|
||||
|
||||
("u16", 0x13C, "OutputWidth"),
|
||||
("u16", 0x13E, "OutputHeight"),
|
||||
("u16", 0x140, "OutputDepth"),
|
||||
|
||||
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
|
||||
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
|
||||
|
||||
("u16", 0x14C, "BatchSize"),
|
||||
|
||||
# TileDMASrc @ section @ 0x16C len 0x6C (input)
|
||||
# reloc 0x16c-0x174 = image
|
||||
("u32", 0x178, "InputRowStride"),
|
||||
("u32", 0x17C, "InputPlaneStride"),
|
||||
("u32", 0x180, "InputDepthStride"),
|
||||
("u32", 0x184, "InputBatchStride"),
|
||||
|
||||
("u8", 0x1A7, "InputInterleave"),
|
||||
|
||||
# L2 @ section @ 0x1E0 len 0x44
|
||||
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
|
||||
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
|
||||
# [0x21c, 0x220, 0x224] = engines for outconv?
|
||||
|
||||
# NE @ section @ 0x22c len 0xC (scaling)
|
||||
("u16", 0x230, "BiasScalar"),
|
||||
("u16", 0x232, "ScaleScalar"),
|
||||
|
||||
# section @ 0x240 len 0x10
|
||||
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
|
||||
("u32", 0x250, "PostScale"),
|
||||
|
||||
# TileDMADst @ section @ 0x258 len 0x18
|
||||
|
||||
# HandleTileDmaDstConfig
|
||||
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
|
||||
# (GetCacheHintRegisterValue & 0xf) << 6;
|
||||
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
|
||||
|
||||
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
|
||||
("u32", 0x260, "OutputRowStride"),
|
||||
("u32", 0x264, "OutputPlaneStride"),
|
||||
("u32", 0x268, "OutputDepthStride"),
|
||||
("u32", 0x26C, "OutputBatchStride"),
|
||||
|
||||
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
|
||||
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
|
||||
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
|
||||
("u8", 0x273, "OutputInterleave"), # i also have this at 0x211?
|
||||
# aneTD.Header
|
||||
("u32", 0x1C, "NextCommandOffset"),
|
||||
# KernelDMASrc @ section @ 0x2C len 0xF4
|
||||
# reloc 0x2c-0x34?? = weights
|
||||
# u32[16] 0x34-0x74 = 0x80 | 1 if used
|
||||
# u32[16] 0x74-0xB4 = <channel data offset>
|
||||
# u32[16] 0xB4-0xF4 = <channel data length>
|
||||
# Common @ section @ 0x128 len 0x3C (conv)
|
||||
("u16", 0x128, "InputWidth"),
|
||||
("u16", 0x12A, "InputHeight"),
|
||||
("u16", 0x12C, "InputDepth"),
|
||||
("u32", 0x130, "InputOutputType"), # (OutputType * 0x10) | InputType
|
||||
# UInt8 = 0, Int8 = 1, Float16 = 2
|
||||
("u32", 0x134, "InputChannels"),
|
||||
("u32", 0x138, "OutputChannels"),
|
||||
("u16", 0x13C, "OutputWidth"),
|
||||
("u16", 0x13E, "OutputHeight"),
|
||||
("u16", 0x140, "OutputDepth"),
|
||||
("u16", 0x144, "KernelSize"), # 0xa000 | (KernelHeight * 0x20) | KernelWidth
|
||||
("u16", 0x146, "Padding"), # 0x5000 | (PadTop * 0x40) | (PadLeft * 2)
|
||||
("u16", 0x14C, "BatchSize"),
|
||||
# TileDMASrc @ section @ 0x16C len 0x6C (input)
|
||||
# reloc 0x16c-0x174 = image
|
||||
("u32", 0x178, "InputRowStride"),
|
||||
("u32", 0x17C, "InputPlaneStride"),
|
||||
("u32", 0x180, "InputDepthStride"),
|
||||
("u32", 0x184, "InputBatchStride"),
|
||||
("u8", 0x1A7, "InputInterleave"),
|
||||
# L2 @ section @ 0x1E0 len 0x44
|
||||
# [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214] = number of engines
|
||||
# [0x1f0, 0x1f4, 0x1f8, 0x214] = engines for inconv?
|
||||
# [0x21c, 0x220, 0x224] = engines for outconv?
|
||||
# NE @ section @ 0x22c len 0xC (scaling)
|
||||
("u16", 0x230, "BiasScalar"),
|
||||
("u16", 0x232, "ScaleScalar"),
|
||||
# section @ 0x240 len 0x10
|
||||
("u16", 0x246, "NeuronType"), # 0x10 = copy, 0x11 = ReLU, 0x12 = custom
|
||||
("u32", 0x250, "PostScale"),
|
||||
# TileDMADst @ section @ 0x258 len 0x18
|
||||
# HandleTileDmaDstConfig
|
||||
# 0x258 -- *(uint *)(this + 0x334) = *(uint *)(this + 0x334) & 0xfffffc3f | 0xc0;
|
||||
# (GetCacheHintRegisterValue & 0xf) << 6;
|
||||
("u32", 0x25C, "OutputOffset"), # offset into output buffer to write at?
|
||||
# 0x260 -- *(uint *)(this + 0x33c) = *(uint *)(this + 0x33c) & 0x3f | (int)uVar10 << 6;
|
||||
("u32", 0x260, "OutputRowStride"),
|
||||
("u32", 0x264, "OutputPlaneStride"),
|
||||
("u32", 0x268, "OutputDepthStride"),
|
||||
("u32", 0x26C, "OutputBatchStride"),
|
||||
# 0x270 -- *(uint *)(this + 0x34c) = *(uint *)(this + 0x34c) & 0xf0ffffff | 0x1000000;
|
||||
# uVar6 = *(uint *)(this + 0x34c) & 0xffffcfcc | 0x2031;
|
||||
# (ZinTensorDescriptorDmaInterleave & 0xf) << 0x18;
|
||||
("u8", 0x273, "OutputInterleave"), # i also have this at 0x211?
|
||||
]
|
||||
|
||||
ANE_Struct_Dict = {}
|
||||
for typ, num, nam in ANE_Struct:
|
||||
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
|
||||
ANE_Struct_Dict[nam] = (styp, num)
|
||||
styp = {"u32": "I", "u16": "H", "u8": "B"}[typ]
|
||||
ANE_Struct_Dict[nam] = (styp, num)
|
||||
|
||||
|
||||
class ANETensor:
|
||||
def __init__(self, *shape):
|
||||
self.shape = shape
|
||||
self.dtype = np.float16
|
||||
self.sz = int(np.prod(shape))
|
||||
assert(self.sz <= 0x4000)
|
||||
self.tt = libane.ANE_TensorCreate(self.sz, 1)
|
||||
assert(self.tt is not None)
|
||||
def __init__(self, *shape):
|
||||
self.shape = shape
|
||||
self.dtype = np.float16
|
||||
self.sz = int(np.prod(shape))
|
||||
assert self.sz <= 0x4000
|
||||
self.tt = libane.ANE_TensorCreate(self.sz, 1)
|
||||
assert self.tt is not None
|
||||
|
||||
def data(self):
|
||||
data = libane.ANE_TensorData(self.tt)
|
||||
assert data is not None
|
||||
# print(hex(addressof(data.contents)))
|
||||
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
|
||||
ret = np.frombuffer(buf, dtype=self.dtype)
|
||||
# print(ret.data)
|
||||
return ret
|
||||
|
||||
def data(self):
|
||||
data = libane.ANE_TensorData(self.tt)
|
||||
assert(data is not None)
|
||||
#print(hex(addressof(data.contents)))
|
||||
buf = np.ctypeslib.as_array(data, shape=(self.sz,))
|
||||
ret = np.frombuffer(buf, dtype=self.dtype)
|
||||
#print(ret.data)
|
||||
return ret
|
||||
|
||||
class ANE:
|
||||
def __init__(self):
|
||||
init_libane()
|
||||
libane.ANE_Open()
|
||||
def __init__(self):
|
||||
init_libane()
|
||||
libane.ANE_Open()
|
||||
|
||||
def compile(self, dat):
|
||||
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
|
||||
assert(ret is not None)
|
||||
return ret
|
||||
def compile(self, dat):
|
||||
ret = libane.ANE_Compile(create_string_buffer(dat), len(dat))
|
||||
assert ret is not None
|
||||
return ret
|
||||
|
||||
def run(self, prog, tin, tout, tweights=None):
|
||||
libane.ANE_Run(prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0)
|
||||
def run(self, prog, tin, tout, tweights=None):
|
||||
libane.ANE_Run(
|
||||
prog, tin.tt, tout.tt, tweights.tt if tweights is not None else 0
|
||||
)
|
||||
|
||||
def tensor(self, shape):
|
||||
return ANETensor(shape)
|
||||
def tensor(self, shape):
|
||||
return ANETensor(shape)
|
||||
|
||||
def unpack(self, dat):
|
||||
dat = struct.unpack("Q"*(len(dat)//8), dat)
|
||||
ret = {}
|
||||
for k,v in aneregs:
|
||||
by,bi,sz = v
|
||||
bi += (by%8)*8
|
||||
by //= 8
|
||||
rv = (dat[by] >> bi) & ((1 << sz)-1)
|
||||
ret[k] = rv
|
||||
return ret
|
||||
def unpack(self, dat):
|
||||
dat = struct.unpack("Q" * (len(dat) // 8), dat)
|
||||
ret = {}
|
||||
for k, v in aneregs:
|
||||
by, bi, sz = v
|
||||
bi += (by % 8) * 8
|
||||
by //= 8
|
||||
rv = (dat[by] >> bi) & ((1 << sz) - 1)
|
||||
ret[k] = rv
|
||||
return ret
|
||||
|
||||
def pack(self, pk, dat):
|
||||
dat = list(struct.unpack("Q"*(len(dat)//8), dat))
|
||||
for k,v in aneregs:
|
||||
by,bi,sz = v
|
||||
bi += (by%8)*8
|
||||
by //= 8
|
||||
dat[by] &= ~(((1 << sz)-1) << bi)
|
||||
dat[by] |= pk[k] << bi
|
||||
dat = struct.pack("Q"*len(dat), *dat)
|
||||
return dat
|
||||
def pack(self, pk, dat):
|
||||
dat = list(struct.unpack("Q" * (len(dat) // 8), dat))
|
||||
for k, v in aneregs:
|
||||
by, bi, sz = v
|
||||
bi += (by % 8) * 8
|
||||
by //= 8
|
||||
dat[by] &= ~(((1 << sz) - 1) << bi)
|
||||
dat[by] |= pk[k] << bi
|
||||
dat = struct.pack("Q" * len(dat), *dat)
|
||||
return dat
|
||||
|
||||
def debug(self, dat, mems=0):
|
||||
add = [0x30, 0x1d4, 0x220, 0x29c, 0x2f0, 0x30c, 0x32c]
|
||||
lens = [244, 60, 108, 68, 12, 16, 24]
|
||||
ptr = 0x2b
|
||||
ddat = dat[0:0x28]
|
||||
for a, pm in zip(add, lens):
|
||||
#assert pm == dat[ptr]
|
||||
ddat += b"\x00" * (a-len(ddat))
|
||||
ddat += dat[ptr+1:ptr+1+pm+4]
|
||||
ptr += pm+8
|
||||
ddat += b"\x00" * 0x100
|
||||
ret = collections.OrderedDict()
|
||||
for ln in libane.ANE_RegDebug(0, create_string_buffer(ddat), mems).decode('utf-8').strip().split("\n"):
|
||||
lnn = ln.split(" = ")
|
||||
if len(lnn) == 2:
|
||||
ret[lnn[0]] = int(lnn[1])
|
||||
return ret
|
||||
def debug(self, dat, mems=0):
|
||||
add = [0x30, 0x1D4, 0x220, 0x29C, 0x2F0, 0x30C, 0x32C]
|
||||
lens = [244, 60, 108, 68, 12, 16, 24]
|
||||
ptr = 0x2B
|
||||
ddat = dat[0:0x28]
|
||||
for a, pm in zip(add, lens):
|
||||
# assert pm == dat[ptr]
|
||||
ddat += b"\x00" * (a - len(ddat))
|
||||
ddat += dat[ptr + 1 : ptr + 1 + pm + 4]
|
||||
ptr += pm + 8
|
||||
ddat += b"\x00" * 0x100
|
||||
ret = collections.OrderedDict()
|
||||
for ln in (
|
||||
libane.ANE_RegDebug(0, create_string_buffer(ddat), mems)
|
||||
.decode("utf-8")
|
||||
.strip()
|
||||
.split("\n")
|
||||
):
|
||||
lnn = ln.split(" = ")
|
||||
if len(lnn) == 2:
|
||||
ret[lnn[0]] = int(lnn[1])
|
||||
return ret
|
||||
|
||||
def filln(self, dat, nvdict, base=0x4000):
|
||||
for n,v in nvdict.items():
|
||||
styp, num = ANE_Struct_Dict[n]
|
||||
dat = self.fill(dat, [num], styp, v)
|
||||
return dat
|
||||
def filln(self, dat, nvdict, base=0x4000):
|
||||
for n, v in nvdict.items():
|
||||
styp, num = ANE_Struct_Dict[n]
|
||||
dat = self.fill(dat, [num], styp, v)
|
||||
return dat
|
||||
|
||||
def fill(self, dat, addrs, type, val, base=0x4000):
|
||||
x = struct.pack(type, val)
|
||||
for a in addrs:
|
||||
dat[base + a : base + a + len(x)] = x
|
||||
return dat
|
||||
|
||||
def fill(self, dat, addrs, type, val, base=0x4000):
|
||||
x = struct.pack(type, val)
|
||||
for a in addrs:
|
||||
dat[base+a:base+a+len(x)] = x
|
||||
return dat
|
||||
|
||||
if __name__ == "__main__":
|
||||
ane = ANE()
|
||||
ane = ANE()
|
||||
|
||||
tin = ANETensor(16)
|
||||
tout = ANETensor(16)
|
||||
tin = ANETensor(16)
|
||||
tout = ANETensor(16)
|
||||
|
||||
tind = tin.data()
|
||||
toutd = tout.data()
|
||||
tind = tin.data()
|
||||
toutd = tout.data()
|
||||
|
||||
tind[0:4] = [-1,1,-2,2]
|
||||
print("** before **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
tind[0:4] = [-1, 1, -2, 2]
|
||||
print("** before **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
|
||||
dat = open("../ops/relu.hwx", "rb").read()
|
||||
md = dat[0x4000:0x4300]
|
||||
dd = ane.unpack(md)
|
||||
mdf = ane.pack(dd, md)
|
||||
assert(md == mdf)
|
||||
|
||||
comp = ane.compile(dat)
|
||||
ret = ane.run(comp, tin, tout)
|
||||
print("** after **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
dat = open("../ops/relu.hwx", "rb").read()
|
||||
md = dat[0x4000:0x4300]
|
||||
dd = ane.unpack(md)
|
||||
mdf = ane.pack(dd, md)
|
||||
assert md == mdf
|
||||
|
||||
comp = ane.compile(dat)
|
||||
ret = ane.run(comp, tin, tout)
|
||||
print("** after **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
|
|
|
@ -2,63 +2,64 @@
|
|||
import time
|
||||
from ane import ANE, ANETensor
|
||||
|
||||
|
||||
def benchmark(ane):
|
||||
tin = ANETensor(512*0x20)
|
||||
tout = ANETensor(512*0x20)
|
||||
dat = open("../ops/gemm.hwx", "rb").read()
|
||||
for k,v in ane.debug(dat[0x4000:0x4300], 16).items():
|
||||
print(k,v)
|
||||
comp = ane.compile(dat)
|
||||
tin = ANETensor(512 * 0x20)
|
||||
tout = ANETensor(512 * 0x20)
|
||||
dat = open("../ops/gemm.hwx", "rb").read()
|
||||
for k, v in ane.debug(dat[0x4000:0x4300], 16).items():
|
||||
print(k, v)
|
||||
comp = ane.compile(dat)
|
||||
|
||||
st = time.time()
|
||||
for i in range(1000):
|
||||
ret = ane.run(comp, tin, tout)
|
||||
et = time.time()
|
||||
ts = (et-st)
|
||||
ops = 1000*512*512*2
|
||||
st = time.time()
|
||||
for i in range(1000):
|
||||
ret = ane.run(comp, tin, tout)
|
||||
et = time.time()
|
||||
ts = et - st
|
||||
ops = 1000 * 512 * 512 * 2
|
||||
|
||||
print("%.2f ms, %.2f gigaops/sec" % (ts*1000, ops*1e-9/ts))
|
||||
print("%.2f ms, %.2f gigaops/sec" % (ts * 1000, ops * 1e-9 / ts))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ane = ANE()
|
||||
ane = ANE()
|
||||
|
||||
# 0x20 per row
|
||||
tin = ANETensor(0x60)
|
||||
tout = ANETensor(0x60)
|
||||
tw = ANETensor(0x60)
|
||||
# 0x20 per row
|
||||
tin = ANETensor(0x60)
|
||||
tout = ANETensor(0x60)
|
||||
tw = ANETensor(0x60)
|
||||
|
||||
tind = tin.data()
|
||||
toutd = tout.data()
|
||||
twd = tw.data()
|
||||
tind = tin.data()
|
||||
toutd = tout.data()
|
||||
twd = tw.data()
|
||||
|
||||
#tind[0:4] = [-1,1,-2,2]
|
||||
tind[0] = 1
|
||||
tind[0x20] = -2
|
||||
tind[0x40] = 3
|
||||
# tind[0:4] = [-1,1,-2,2]
|
||||
tind[0] = 1
|
||||
tind[0x20] = -2
|
||||
tind[0x40] = 3
|
||||
|
||||
# toutd[0] = \
|
||||
# tind[0] * twd[0] + \
|
||||
# tind[0x20] + twd[1] + \
|
||||
# tind[0x40] + twd[2]
|
||||
# toutd[0] = \
|
||||
# tind[0] * twd[0] + \
|
||||
# tind[0x20] + twd[1] + \
|
||||
# tind[0x40] + twd[2]
|
||||
|
||||
twd[0] = 4
|
||||
twd[1] = 0x100
|
||||
twd[0] = 4
|
||||
twd[1] = 0x100
|
||||
|
||||
twd[0x20] = 5
|
||||
twd[0x21] = 5
|
||||
twd[0x22] = 5
|
||||
twd[0x20] = 5
|
||||
twd[0x21] = 5
|
||||
twd[0x22] = 5
|
||||
|
||||
twd[0x40] = 12
|
||||
twd[0x40] = 12
|
||||
|
||||
print("** before **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
print("** before **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
|
||||
#benchmark(ane)
|
||||
#exit(0)
|
||||
# benchmark(ane)
|
||||
# exit(0)
|
||||
|
||||
"""
|
||||
"""
|
||||
dat = list(open("../ops/sum.hwx", "rb").read())
|
||||
dat = bytes(dat)
|
||||
for k,v in ane.debug(dat[0x4000:0x4300], 16).items():
|
||||
|
@ -67,25 +68,25 @@ if __name__ == "__main__":
|
|||
ret = ane.run(comp, tin, tout, tw)
|
||||
"""
|
||||
|
||||
datb = open("../ops/sum.hwx", "rb").read()
|
||||
dat = open("../ops/conv.hwx", "rb").read()
|
||||
dd = ane.unpack(dat[0x4000:0x4300])
|
||||
# use the 3rd arg as the weights
|
||||
dd["aneTD.Header[9].KBase0"] = 6
|
||||
dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00
|
||||
#dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
|
||||
#dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
|
||||
#dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
|
||||
#dd["aneRegs.L2.ResultBase.Addr"] = 0
|
||||
#dd["aneRegs.Common.ChCfg.InFmt"] = 1
|
||||
#dd["aneRegs.TileDMADst.Fmt.ZeroPadFirst"] = 0
|
||||
#dd["aneRegs.TileDMADst.DMAConfig.En"] = 0
|
||||
for k,v in dd.items():
|
||||
print(k,v)
|
||||
dat = datb[:0x4000] + ane.pack(dd, dat[0x4000:0x4300]) + datb[0x4300:]
|
||||
comp = ane.compile(dat)
|
||||
ret = ane.run(comp, tin, tout, tw)
|
||||
datb = open("../ops/sum.hwx", "rb").read()
|
||||
dat = open("../ops/conv.hwx", "rb").read()
|
||||
dd = ane.unpack(dat[0x4000:0x4300])
|
||||
# use the 3rd arg as the weights
|
||||
dd["aneTD.Header[9].KBase0"] = 6
|
||||
dd["aneRegs.NE.PostScale.PostScale"] = 0x3C00
|
||||
# dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
|
||||
# dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
|
||||
# dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
|
||||
# dd["aneRegs.L2.ResultBase.Addr"] = 0
|
||||
# dd["aneRegs.Common.ChCfg.InFmt"] = 1
|
||||
# dd["aneRegs.TileDMADst.Fmt.ZeroPadFirst"] = 0
|
||||
# dd["aneRegs.TileDMADst.DMAConfig.En"] = 0
|
||||
for k, v in dd.items():
|
||||
print(k, v)
|
||||
dat = datb[:0x4000] + ane.pack(dd, dat[0x4000:0x4300]) + datb[0x4300:]
|
||||
comp = ane.compile(dat)
|
||||
ret = ane.run(comp, tin, tout, tw)
|
||||
|
||||
print("** after **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
print("** after **")
|
||||
print(tind)
|
||||
print(toutd)
|
||||
|
|
|
@ -1,39 +1,52 @@
|
|||
from functools import lru_cache
|
||||
from .tensor import Device, Function, register
|
||||
|
||||
|
||||
@lru_cache
|
||||
def compile_wrapper(ane, dat):
|
||||
return ane.compile(dat)
|
||||
return ane.compile(dat)
|
||||
|
||||
|
||||
def roundup(x, v):
|
||||
return x + (v-x)%v
|
||||
return x + (v - x) % v
|
||||
|
||||
|
||||
@lru_cache
|
||||
def compile_relu(ane, sz):
|
||||
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
|
||||
# TODO: make this all nice and once
|
||||
# number of engines? (max 0x100)
|
||||
l2_stride = max(0x100, roundup(sz*2, 0x10))
|
||||
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
|
||||
# 0x1f4, 0x1f8?
|
||||
# 0x214 = L2.ResultBase.Addr
|
||||
dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride)
|
||||
stride = roundup(sz*2, 0x40)
|
||||
dat = ane.filln(dat, {
|
||||
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
|
||||
"InputWidth": sz, "OutputWidth": sz,
|
||||
"InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride,
|
||||
"OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride,
|
||||
})
|
||||
return compile_wrapper(ane, bytes(dat))
|
||||
dat = list(open("accel/ane/ops/relu.hwx", "rb").read())
|
||||
# TODO: make this all nice and once
|
||||
# number of engines? (max 0x100)
|
||||
l2_stride = max(0x100, roundup(sz * 2, 0x10))
|
||||
# 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
|
||||
# 0x1f4, 0x1f8?
|
||||
# 0x214 = L2.ResultBase.Addr
|
||||
dat = ane.fill(dat, [0x1EC, 0x1F0, 0x1F4, 0x1F8, 0x214], "I", l2_stride)
|
||||
stride = roundup(sz * 2, 0x40)
|
||||
dat = ane.filln(
|
||||
dat,
|
||||
{
|
||||
"NeuronType": 0x11, # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
|
||||
"InputWidth": sz,
|
||||
"OutputWidth": sz,
|
||||
"InputRowStride": stride,
|
||||
"InputPlaneStride": stride,
|
||||
"InputDepthStride": stride,
|
||||
"OutputRowStride": stride,
|
||||
"OutputPlaneStride": stride,
|
||||
"OutputDepthStride": stride,
|
||||
},
|
||||
)
|
||||
return compile_wrapper(ane, bytes(dat))
|
||||
|
||||
|
||||
class ReLU(Function):
|
||||
def forward(ctx, input):
|
||||
ret = ctx.ane.tensor(input.shape)
|
||||
ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)
|
||||
return ret
|
||||
def forward(ctx, input):
|
||||
ret = ctx.ane.tensor(input.shape)
|
||||
ctx.ane.run(compile_relu(ctx.ane, input.sz), input, ret)
|
||||
return ret
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
return 0
|
||||
def backward(ctx, grad_output):
|
||||
return 0
|
||||
|
||||
register('relu', ReLU, device=Device.ANE)
|
||||
|
||||
register("relu", ReLU, device=Device.ANE)
|
||||
|
|
|
@ -31,19 +31,20 @@ for x in out.values(): x.realize()
|
|||
"""
|
||||
|
||||
from openvino.runtime import Core
|
||||
|
||||
core = Core()
|
||||
devices = core.available_devices
|
||||
for device in devices:
|
||||
device_name = core.get_property(device, "FULL_DEVICE_NAME")
|
||||
print(f"{device}: {device_name}")
|
||||
device_name = core.get_property(device, "FULL_DEVICE_NAME")
|
||||
print(f"{device}: {device_name}")
|
||||
model = core.read_model(onnx_path)
|
||||
compiled_model = core.compile_model(model, device_name='GPU.0')
|
||||
compiled_model = core.compile_model(model, device_name="GPU.0")
|
||||
print(compiled_model)
|
||||
ireq = compiled_model.create_infer_request()
|
||||
for model_input in compiled_model.inputs:
|
||||
tensor = ireq.get_tensor(model_input)
|
||||
tensor.data[:] = 2
|
||||
print(tensor)
|
||||
tensor = ireq.get_tensor(model_input)
|
||||
tensor.data[:] = 2
|
||||
print(tensor)
|
||||
print("request")
|
||||
ireq.infer()
|
||||
ireq.infer()
|
||||
|
@ -51,7 +52,7 @@ print("did one")
|
|||
|
||||
REPS = 20
|
||||
st = time.perf_counter()
|
||||
for i in range(REPS): ireq.infer()
|
||||
for i in range(REPS):
|
||||
ireq.infer()
|
||||
et = time.perf_counter() - st
|
||||
print(f"{et*1000:.2f} ms {(CNT*N*N*N*REPS*2/et)*1e-9:.2f} GFLOPS")
|
||||
|
||||
|
|
|
@ -7,11 +7,14 @@ from tqdm import trange, tqdm
|
|||
from matplotlib import pyplot as plt
|
||||
|
||||
tests = {}
|
||||
|
||||
|
||||
def register_test(fxn):
|
||||
tests[fxn.__name__] = fxn
|
||||
tests[fxn.__name__] = fxn
|
||||
|
||||
|
||||
def warp_size2(nthread):
|
||||
prg = """__kernel void warp_size2(
|
||||
prg = """__kernel void warp_size2(
|
||||
__global float* src,
|
||||
__global int* dst,
|
||||
const int niter,
|
||||
|
@ -24,20 +27,40 @@ def warp_size2(nthread):
|
|||
}
|
||||
dst[get_local_id(0)] = drain;
|
||||
}"""
|
||||
src_buf = CLBuffer(1, dtypes.float32)
|
||||
dst_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
|
||||
return min([cl([nthread, 1024, 1], [nthread, 1, 1], src_buf, dst_buf, 10, 3, wait=True) for _ in range(5)])*1e9
|
||||
src_buf = CLBuffer(1, dtypes.float32)
|
||||
dst_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("warp_size2", prg, argdtypes=[None, None, np.int32, np.int32])
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl(
|
||||
[nthread, 1024, 1],
|
||||
[nthread, 1, 1],
|
||||
src_buf,
|
||||
dst_buf,
|
||||
10,
|
||||
3,
|
||||
wait=True,
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_warp_size():
|
||||
return [(nthread, warp_size2(nthread)) for nthread in trange(1,256)]
|
||||
return [(nthread, warp_size2(nthread)) for nthread in trange(1, 256)]
|
||||
|
||||
|
||||
def reg_count(nthread, ngrp, nreg):
|
||||
reg_declr = ''.join([f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)])
|
||||
reg_comp = ''.join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||
reg_reduce = ''.join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||
prg = f"""__kernel void reg_count(
|
||||
reg_declr = "".join(
|
||||
[f"float reg_data{i} = (float)niter + {i};\n" for i in range(nreg)]
|
||||
)
|
||||
reg_comp = "".join([f"reg_data{i} *= {(i-1)%nreg};\n" for i in range(nreg)])
|
||||
reg_reduce = "".join([f"out_buf[{i}] = reg_data{i};\n" for i in range(nreg)])
|
||||
prg = f"""__kernel void reg_count(
|
||||
__global float* out_buf,
|
||||
__private const int niter
|
||||
) {{
|
||||
|
@ -49,18 +72,31 @@ def reg_count(nthread, ngrp, nreg):
|
|||
i = i >> 31;
|
||||
{reg_reduce}
|
||||
}}"""
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
||||
return min([cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True) for _ in range(10)])*1e9
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("reg_count", prg, argdtypes=[None, np.int32])
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([nthread, ngrp, 1], [nthread, 1, 1], out_buf, 20, wait=True)
|
||||
for _ in range(10)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_reg_count(nthread=1, ngrp=1):
|
||||
base = reg_count(nthread, ngrp, 1)
|
||||
return [(nreg, (reg_count(nthread, ngrp, nreg)-base)/nreg) for nreg in trange(4, 513, 4)]
|
||||
base = reg_count(nthread, ngrp, 1)
|
||||
return [
|
||||
(nreg, (reg_count(nthread, ngrp, nreg) - base) / nreg)
|
||||
for nreg in trange(4, 513, 4)
|
||||
]
|
||||
|
||||
|
||||
def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
||||
ndata //= NCOMP*4 # ptr size
|
||||
prg = f"""__kernel void buf_cache_hierarchy_pchase(
|
||||
ndata //= NCOMP * 4 # ptr size
|
||||
prg = f"""__kernel void buf_cache_hierarchy_pchase(
|
||||
__global int{str(NCOMP) if NCOMP > 1 else ''}* src,
|
||||
__global int* dst,
|
||||
const int niter
|
||||
|
@ -71,49 +107,76 @@ def buf_cache_hierarchy_pchase(ndata, stride=1, NCOMP=1, steps=65536):
|
|||
}}
|
||||
*dst = idx;
|
||||
}}"""
|
||||
idx_buf = np.zeros(ndata*NCOMP, dtype=np.int32)
|
||||
for i in range(ndata): idx_buf[i*NCOMP] = (i + stride) % ndata
|
||||
in_buf = CLBuffer.fromCPU(idx_buf)
|
||||
out_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
||||
return min([cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True)/steps for _ in range(5)])*1e9
|
||||
idx_buf = np.zeros(ndata * NCOMP, dtype=np.int32)
|
||||
for i in range(ndata):
|
||||
idx_buf[i * NCOMP] = (i + stride) % ndata
|
||||
in_buf = CLBuffer.fromCPU(idx_buf)
|
||||
out_buf = CLBuffer(1, dtypes.int32)
|
||||
cl = CLProgram("buf_cache_hierarchy_pchase", prg, argdtypes=[None, None, np.int32])
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([1, 1, 1], [1, 1, 1], in_buf, out_buf, steps, wait=True) / steps
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_memory_latency():
|
||||
# requires cacheline < 16
|
||||
szs = [int(1.3**x) for x in range(20, 70)]
|
||||
return [(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128*1024)) for ndata in tqdm(szs)]
|
||||
# requires cacheline < 16
|
||||
szs = [int(1.3**x) for x in range(20, 70)]
|
||||
return [
|
||||
(ndata, buf_cache_hierarchy_pchase(ndata, NCOMP=16, steps=128 * 1024))
|
||||
for ndata in tqdm(szs)
|
||||
]
|
||||
|
||||
|
||||
@register_test
|
||||
def test_cacheline_size():
|
||||
# TODO: this buffer must be at least 2x the L1 cache for this test to work
|
||||
return [(stride, buf_cache_hierarchy_pchase(4*65536, stride, steps=65536)) for stride in trange(1,64)]
|
||||
# TODO: this buffer must be at least 2x the L1 cache for this test to work
|
||||
return [
|
||||
(stride, buf_cache_hierarchy_pchase(4 * 65536, stride, steps=65536))
|
||||
for stride in trange(1, 64)
|
||||
]
|
||||
|
||||
|
||||
def cl_read(sz, niter=1):
|
||||
prg = f"""__kernel void copy(
|
||||
prg = f"""__kernel void copy(
|
||||
__global float4* src,
|
||||
__global float* dst) {{
|
||||
int gid = get_global_id(0);
|
||||
if (src[gid].x == 99+get_global_id(1)) *dst = 1;
|
||||
}}"""
|
||||
|
||||
in_buf = CLBuffer(sz//4, dtypes.float32)
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("copy", prg)
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return min([cl([sz//16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True) for _ in range(10)])*1e9
|
||||
in_buf = CLBuffer(sz // 4, dtypes.float32)
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("copy", prg)
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return (
|
||||
min(
|
||||
[
|
||||
cl([sz // 16, niter, 1], [1, 1, 1], in_buf, out_buf, wait=True)
|
||||
for _ in range(10)
|
||||
]
|
||||
)
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_read_bandwidth():
|
||||
szs = list(range(128*1024, 20*1024*1024, 128*1024))
|
||||
NITER = 8
|
||||
base = cl_read(16, niter=NITER)
|
||||
return [(sz, (sz*NITER)/(cl_read(sz, niter=NITER)-base)) for sz in tqdm(szs)]
|
||||
szs = list(range(128 * 1024, 20 * 1024 * 1024, 128 * 1024))
|
||||
NITER = 8
|
||||
base = cl_read(16, niter=NITER)
|
||||
return [(sz, (sz * NITER) / (cl_read(sz, niter=NITER) - base)) for sz in tqdm(szs)]
|
||||
|
||||
|
||||
def gflops(niter=4, nroll=4, ngroups=4096):
|
||||
NCOMP = 8
|
||||
prg = f"""__kernel void gflops(
|
||||
NCOMP = 8
|
||||
prg = f"""__kernel void gflops(
|
||||
__global float* out_buf
|
||||
) {{
|
||||
float{NCOMP} x = (float{NCOMP})({",".join(f"get_local_id(0)+{i}" for i in range(NCOMP))});
|
||||
|
@ -125,30 +188,37 @@ def gflops(niter=4, nroll=4, ngroups=4096):
|
|||
|
||||
out_buf[get_global_id(0) >> 31] = {'+'.join(f"y.s{'0123456789abcdef'[i]}" for i in range(NCOMP))};
|
||||
}}"""
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
|
||||
FLOPS = NCOMP*2*2 * niter * nroll * ngroups * 32
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return FLOPS/(min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])*1e9)
|
||||
out_buf = CLBuffer(1, dtypes.float32)
|
||||
cl = CLProgram("gflops", prg, options="-cl-mad-enable -cl-fast-relaxed-math")
|
||||
FLOPS = NCOMP * 2 * 2 * niter * nroll * ngroups * 32
|
||||
# NOTE: if nay of the niters form a local group, this is wrong
|
||||
return FLOPS / (
|
||||
min([cl([32, ngroups, 1], [32, 1, 1], out_buf, wait=True) for _ in range(10)])
|
||||
* 1e9
|
||||
)
|
||||
|
||||
|
||||
@register_test
|
||||
def test_gflops():
|
||||
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
|
||||
return [(niter, gflops(niter=niter, nroll=32)) for niter in trange(1, 32, 1)]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cache = {}
|
||||
#cache = pickle.load(open("/tmp/cache.pkl", "rb"))
|
||||
#tests = {"test_cacheline_size": tests["test_cacheline_size"]}
|
||||
plt.figure(figsize=(16, 9))
|
||||
for i,(k,test) in enumerate(tests.items()):
|
||||
print(f"running {k}")
|
||||
plt.subplot(2, (len(tests)+1)//2, i+1)
|
||||
plt.title(k)
|
||||
if k == "test_memory_latency": plt.xscale('log')
|
||||
if k not in cache: cache[k] = test()
|
||||
plt.plot(*zip(*cache[k]))
|
||||
#pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
|
||||
cache = {}
|
||||
# cache = pickle.load(open("/tmp/cache.pkl", "rb"))
|
||||
# tests = {"test_cacheline_size": tests["test_cacheline_size"]}
|
||||
plt.figure(figsize=(16, 9))
|
||||
for i, (k, test) in enumerate(tests.items()):
|
||||
print(f"running {k}")
|
||||
plt.subplot(2, (len(tests) + 1) // 2, i + 1)
|
||||
plt.title(k)
|
||||
if k == "test_memory_latency":
|
||||
plt.xscale("log")
|
||||
if k not in cache:
|
||||
cache[k] = test()
|
||||
plt.plot(*zip(*cache[k]))
|
||||
# pickle.dump(cache, open("/tmp/cache.pkl", "wb"))
|
||||
|
||||
plt.tight_layout(pad=0.5)
|
||||
plt.savefig("/tmp/results.png")
|
||||
plt.show()
|
||||
plt.tight_layout(pad=0.5)
|
||||
plt.savefig("/tmp/results.png")
|
||||
plt.show()
|
||||
|
|
|
@ -1,188 +1,427 @@
|
|||
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
|
||||
from typing import (
|
||||
Tuple,
|
||||
List,
|
||||
NamedTuple,
|
||||
Any,
|
||||
Dict,
|
||||
Optional,
|
||||
Union,
|
||||
DefaultDict,
|
||||
cast,
|
||||
)
|
||||
from tinygrad.codegen.linearizer import UOps, MemOp, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps
|
||||
from tinygrad.helpers import DType, dtypes, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
|
||||
from tinygrad.shape.symbolic import (
|
||||
Variable,
|
||||
NumNode,
|
||||
MulNode,
|
||||
DivNode,
|
||||
ModNode,
|
||||
LtNode,
|
||||
SumNode,
|
||||
AndNode,
|
||||
)
|
||||
import functools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
|
||||
_type_to_letter = {dtypes.float32: 'f', dtypes.bool: 'p', dtypes.int32: 'i', dtypes.int64: 'a', dtypes.uint32: 'u', dtypes.uint64: 'b', dtypes.float.vec(4): 'x', dtypes.uint8: 'uc', dtypes.float16: 'h',
|
||||
dtypes.int8: 'c', dtypes.uint16: 'us', dtypes.float64: 'd'}
|
||||
_type_to_letter = {
|
||||
dtypes.float32: "f",
|
||||
dtypes.bool: "p",
|
||||
dtypes.int32: "i",
|
||||
dtypes.int64: "a",
|
||||
dtypes.uint32: "u",
|
||||
dtypes.uint64: "b",
|
||||
dtypes.float.vec(4): "x",
|
||||
dtypes.uint8: "uc",
|
||||
dtypes.float16: "h",
|
||||
dtypes.int8: "c",
|
||||
dtypes.uint16: "us",
|
||||
dtypes.float64: "d",
|
||||
}
|
||||
|
||||
|
||||
class Register(NamedTuple):
|
||||
nm:str
|
||||
dtype:DType
|
||||
scalar:bool
|
||||
off:Optional[int] = None
|
||||
def __repr__(self): return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||
def subregs(self):
|
||||
if self.dtype == dtypes.float.vec(4):
|
||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||
return []
|
||||
nm: str
|
||||
dtype: DType
|
||||
scalar: bool
|
||||
off: Optional[int] = None
|
||||
|
||||
def __repr__(self):
|
||||
return self.nm if self.off is None else f"{self.nm}:{self.off}"
|
||||
|
||||
def subregs(self):
|
||||
if self.dtype == dtypes.float.vec(4):
|
||||
return [Register(self.nm, dtypes.float, False, off=off) for off in range(4)]
|
||||
return []
|
||||
|
||||
|
||||
class AssemblyInstruction(NamedTuple):
|
||||
op: UOps
|
||||
out: Optional[Register]
|
||||
vin: List[Union[Register, int, float]]
|
||||
arg: Any = None
|
||||
op: UOps
|
||||
out: Optional[Register]
|
||||
vin: List[Union[Register, int, float]]
|
||||
arg: Any = None
|
||||
|
||||
|
||||
# warp size of 32, s registers are shared across the warp, v are 32-wide vectors
|
||||
class AssemblyLanguage:
|
||||
supports_load3: bool = False
|
||||
sin_is_sin2pi: bool = False
|
||||
no_div: bool = False
|
||||
#TODO: these should be global vars
|
||||
cnts:DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
|
||||
tor: Dict[Any, Register] = {}
|
||||
ins: List[AssemblyInstruction] = []
|
||||
supports_load3: bool = False
|
||||
sin_is_sin2pi: bool = False
|
||||
no_div: bool = False
|
||||
# TODO: these should be global vars
|
||||
cnts: DefaultDict[Tuple[DType, bool], int] = defaultdict(int)
|
||||
tor: Dict[Any, Register] = {}
|
||||
ins: List[AssemblyInstruction] = []
|
||||
|
||||
def type_to_letter(self,x): return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
||||
self.tor[tok] = ret = Register(f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}", dtype, scalar)
|
||||
if dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||
self.cnts[(dtype, scalar)] += 1
|
||||
return ret
|
||||
def type_to_letter(self, x):
|
||||
return _type_to_letter[x[0]].upper() if x[1] else _type_to_letter[x[0]]
|
||||
|
||||
def render_numnode(self, b) -> Register:
|
||||
key = ("num", b)
|
||||
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
|
||||
return self.tor[key]
|
||||
def newreg(self, tok, dtype=dtypes.float32, scalar=False) -> Register:
|
||||
self.tor[tok] = ret = Register(
|
||||
f"%{self.type_to_letter((dtype, scalar))}{self.cnts[(dtype, scalar)]}",
|
||||
dtype,
|
||||
scalar,
|
||||
)
|
||||
if dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
self.tor[tok] = Register(ret.nm, dtypes.float, ret.scalar, off)
|
||||
self.cnts[(dtype, scalar)] += 1
|
||||
return ret
|
||||
|
||||
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
|
||||
key = (op, a, b)
|
||||
if key not in self.tor:
|
||||
#if not isinstance(b, Register): b = render_numnode(b)
|
||||
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
|
||||
return self.tor[key]
|
||||
def render_numnode(self, b) -> Register:
|
||||
key = ("num", b)
|
||||
if key not in self.tor:
|
||||
self.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b
|
||||
)
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
def render_cast(self, a:Register, new_dtype:DType) -> Register:
|
||||
if a.dtype == new_dtype: return a
|
||||
key = (a, new_dtype)
|
||||
if key not in self.tor:
|
||||
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
|
||||
return self.tor[key]
|
||||
def render_alu(
|
||||
self, op, a: Register, b: Union[Register, int, float], dtype=dtypes.int32
|
||||
) -> Register:
|
||||
key = (op, a, b)
|
||||
if key not in self.tor:
|
||||
# if not isinstance(b, Register): b = render_numnode(b)
|
||||
self.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
self.newreg(
|
||||
key,
|
||||
dtype=dtype,
|
||||
scalar=a.scalar and (not isinstance(b, Register) or b.scalar),
|
||||
),
|
||||
[a, b],
|
||||
op,
|
||||
)
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MUL, self.a.render(ops, ctx), self.b),
|
||||
DivNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.DIV, self.a.render(ops, ctx), self.b),
|
||||
ModNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.MOD, self.a.render(ops, ctx), self.b),
|
||||
LtNode: lambda self, ops, ctx: ctx.render_alu(BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.render_alu(BinaryOps.MUL, a, b.render(ops,ctx), dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
def render_cast(self, a: Register, new_dtype: DType) -> Register:
|
||||
if a.dtype == new_dtype:
|
||||
return a
|
||||
key = (a, new_dtype)
|
||||
if key not in self.tor:
|
||||
self.ins.append(
|
||||
AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a])
|
||||
)
|
||||
return self.tor[key]
|
||||
|
||||
def addr_w_offset(self, args):
|
||||
assert isinstance(args, MemOp)
|
||||
idx = args.idx*args.memory_dtype.itemsize
|
||||
off = 0 # TODO: should this be None?
|
||||
if isinstance(idx, SumNode):
|
||||
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
||||
if nums and nums[0] < 4096 and (idx-nums[0]).min >= 0: # TODO: different for each GPU?
|
||||
idx -= nums[0]
|
||||
off = cast(int, nums[0])
|
||||
reg = idx.render(self.render_ops, self)
|
||||
if self.supports_load3:
|
||||
if reg.scalar:
|
||||
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
|
||||
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
|
||||
reg = new_reg
|
||||
return self.tor[args.name], reg, off
|
||||
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
|
||||
return reg, None, off
|
||||
render_ops: Any = {
|
||||
Variable: lambda self, ops, ctx: ctx.tor[self],
|
||||
NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.MUL, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
DivNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.DIV, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
ModNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.MOD, self.a.render(ops, ctx), self.b
|
||||
),
|
||||
LtNode: lambda self, ops, ctx: ctx.render_alu(
|
||||
BinaryOps.CMPLT, self.a.render(ops, ctx), self.b, dtype=dtypes.bool
|
||||
),
|
||||
SumNode: lambda self, ops, ctx: functools.reduce(
|
||||
lambda a, b: ctx.render_alu(BinaryOps.ADD, a, b.render(ops, ctx)),
|
||||
self.nodes[1:],
|
||||
self.nodes[0].render(ops, ctx),
|
||||
),
|
||||
AndNode: lambda self, ops, ctx: functools.reduce(
|
||||
lambda a, b: ctx.render_alu(
|
||||
BinaryOps.MUL, a, b.render(ops, ctx), dtype=dtypes.bool
|
||||
),
|
||||
self.nodes[1:],
|
||||
self.nodes[0].render(ops, ctx),
|
||||
),
|
||||
}
|
||||
|
||||
def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
#TODO: Do not use clear()
|
||||
lang.ins.clear()
|
||||
lang.tor.clear()
|
||||
lang.cnts.clear()
|
||||
buf_to_dtype = {args[0]:args[1] for uop,_,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
||||
global_size, local_size = [], []
|
||||
skipload_branch = 0
|
||||
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
||||
elif uop == UOps.LOOP:
|
||||
if args[1] == "global":
|
||||
for i,var in enumerate(args[0]):
|
||||
global_size.append(var.max+1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
||||
elif args[1] == "local":
|
||||
for i,var in enumerate(args[0]):
|
||||
local_size.append(var.max+1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
||||
else:
|
||||
for var in args[0]:
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] not in ["global", "local", "global+local"]:
|
||||
for var in reversed(args[0]):
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
||||
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
||||
elif args[1] == "global+local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||
elif args[1] == 'local':
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
||||
elif uop == UOps.CAST:
|
||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||
out = lang.newreg(u, dtype)
|
||||
for i,sr in enumerate(out.subregs()):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
||||
elif uop == UOps.ALU:
|
||||
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
||||
# this is the only thing that can violate SSA
|
||||
if args in [BinaryOps.CMPLT]:
|
||||
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
||||
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||
elif args == BinaryOps.DIV and lang.no_div:
|
||||
tmp = lang.newreg((u, "rcp"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
||||
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||
tmp = lang.newreg((u, "2pi"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||
else:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
reg = lang.newreg(u, dtype=dtype)
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
||||
elif uop == UOps.SPECIAL:
|
||||
lang.tor[u] = lang.tor[args]
|
||||
elif uop == UOps.CONST:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
||||
elif uop == UOps.LOAD:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
|
||||
if args.valid.min == 0:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||
if args.valid.max == 1:
|
||||
pred = args.valid.render(lang.render_ops, lang)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
if args.valid.max == 1:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
if args is None:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
||||
else:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
def addr_w_offset(self, args):
|
||||
assert isinstance(args, MemOp)
|
||||
idx = args.idx * args.memory_dtype.itemsize
|
||||
off = 0 # TODO: should this be None?
|
||||
if isinstance(idx, SumNode):
|
||||
nums = [n.b for n in idx.nodes if isinstance(n, NumNode)]
|
||||
if (
|
||||
nums and nums[0] < 4096 and (idx - nums[0]).min >= 0
|
||||
): # TODO: different for each GPU?
|
||||
idx -= nums[0]
|
||||
off = cast(int, nums[0])
|
||||
reg = idx.render(self.render_ops, self)
|
||||
if self.supports_load3:
|
||||
if reg.scalar:
|
||||
new_reg = self.newreg((reg.nm, "vec"), dtype=reg.dtype)
|
||||
self.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP)
|
||||
)
|
||||
reg = new_reg
|
||||
return self.tor[args.name], reg, off
|
||||
reg = self.render_alu(
|
||||
BinaryOps.ADD,
|
||||
self.render_cast(reg, dtypes.uint64),
|
||||
self.tor[args.name],
|
||||
dtype=dtypes.uint64,
|
||||
)
|
||||
return reg, None, off
|
||||
|
||||
if DEBUG >= 4:
|
||||
for tins in lang.ins: print(tins)
|
||||
return global_size, local_size
|
||||
|
||||
def uops_to_asmstyle(lang, function_name: str, uops: List[UOp]):
|
||||
# TODO: Do not use clear()
|
||||
lang.ins.clear()
|
||||
lang.tor.clear()
|
||||
lang.cnts.clear()
|
||||
buf_to_dtype = {
|
||||
args[0]: args[1] for uop, _, _, args, _ in uops if uop == UOps.DEFINE_GLOBAL
|
||||
}
|
||||
global_size, local_size = [], []
|
||||
skipload_branch = 0
|
||||
lang.ins += [
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf
|
||||
)
|
||||
for buf in buf_to_dtype
|
||||
]
|
||||
for u in uops:
|
||||
uop, dtype, vin, args, _ = u
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
lang.newreg(args[0], dtype=dtypes.uint64),
|
||||
[args[0]],
|
||||
UnaryOps.NOOP,
|
||||
)
|
||||
)
|
||||
elif uop == UOps.LOOP:
|
||||
if args[1] == "global":
|
||||
for i, var in enumerate(args[0]):
|
||||
global_size.append(var.max + 1)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL,
|
||||
lang.newreg(var, dtype=dtypes.int32),
|
||||
[],
|
||||
f"gid{len(args[0])-1-i}",
|
||||
)
|
||||
)
|
||||
elif args[1] == "local":
|
||||
for i, var in enumerate(args[0]):
|
||||
local_size.append(var.max + 1)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.SPECIAL,
|
||||
lang.newreg(var, dtype=dtypes.int32),
|
||||
[],
|
||||
f"lid{len(args[0])-1-i}",
|
||||
)
|
||||
)
|
||||
else:
|
||||
for var in args[0]:
|
||||
if not isinstance(
|
||||
var, NumNode
|
||||
): # TODO: why is this coming through?
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD,
|
||||
lang.newreg(var, dtype=dtypes.int32, scalar=True),
|
||||
[],
|
||||
0,
|
||||
)
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LABEL, None, [], "$loop_" + var.expr
|
||||
)
|
||||
)
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] not in ["global", "local", "global+local"]:
|
||||
for var in reversed(args[0]):
|
||||
if not isinstance(
|
||||
var, NumNode
|
||||
): # TODO: why is this coming through?
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
lang.tor[var],
|
||||
[lang.tor[var], 1],
|
||||
BinaryOps.ADD,
|
||||
)
|
||||
)
|
||||
pred = lang.render_alu(
|
||||
BinaryOps.CMPLT, lang.tor[var], var.max + 1, dtypes.bool
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.COND_BRANCH,
|
||||
None,
|
||||
[pred],
|
||||
("$loop_" + var.expr, True),
|
||||
)
|
||||
)
|
||||
elif args[1] == "global+local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ENDLOOP,
|
||||
None,
|
||||
[lang.tor[var]],
|
||||
(var.max + 1, f"gid{i}"),
|
||||
)
|
||||
)
|
||||
elif args[1] == "local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ENDLOOP,
|
||||
None,
|
||||
[lang.tor[var]],
|
||||
(var.max + 1, f"lid{i}"),
|
||||
)
|
||||
)
|
||||
elif uop == UOps.CAST:
|
||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||
out = lang.newreg(u, dtype)
|
||||
for i, sr in enumerate(out.subregs()):
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP)
|
||||
)
|
||||
elif uop == UOps.ALU:
|
||||
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
||||
# this is the only thing that can violate SSA
|
||||
if args in [BinaryOps.CMPLT]:
|
||||
pred_reg = lang.newreg((u, "pred"), dtype=dtypes.bool)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args
|
||||
)
|
||||
)
|
||||
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||
elif args == BinaryOps.DIV and lang.no_div:
|
||||
tmp = lang.newreg((u, "rcp"))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP
|
||||
)
|
||||
)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL
|
||||
)
|
||||
)
|
||||
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||
tmp = lang.newreg((u, "2pi"))
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU,
|
||||
tmp,
|
||||
[lang.tor[vin[0]], 1 / (math.pi * 2)],
|
||||
BinaryOps.MUL,
|
||||
)
|
||||
)
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||
else:
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args)
|
||||
)
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
reg = lang.newreg(u, dtype=dtype)
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
||||
elif uop == UOps.SPECIAL:
|
||||
lang.tor[u] = lang.tor[args]
|
||||
elif uop == UOps.CONST:
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
reg = lang.newreg(
|
||||
u,
|
||||
dtype=dtype,
|
||||
scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)),
|
||||
)
|
||||
if args.valid.min == 0:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||
if args.valid.max == 1:
|
||||
pred = args.valid.render(lang.render_ops, lang)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.COND_BRANCH,
|
||||
None,
|
||||
[pred],
|
||||
(f"$skipload_{skipload_branch}", False),
|
||||
)
|
||||
)
|
||||
if args.valid.max == 1:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LOAD,
|
||||
reg,
|
||||
[idx] + ([treg] if treg is not None else []),
|
||||
(
|
||||
off,
|
||||
"global" if not args.local else "shared",
|
||||
args.memory_dtype
|
||||
if args.memory_dtype != dtypes.float
|
||||
else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.LABEL, None, [], f"$skipload_{skipload_branch}"
|
||||
)
|
||||
)
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
if args is None:
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP
|
||||
)
|
||||
)
|
||||
else:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
lang.ins.append(
|
||||
AssemblyInstruction(
|
||||
UOps.STORE,
|
||||
None,
|
||||
[idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []),
|
||||
(
|
||||
off,
|
||||
"global" if not args.local else "shared",
|
||||
args.memory_dtype
|
||||
if args.memory_dtype != dtypes.float
|
||||
else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if DEBUG >= 4:
|
||||
for tins in lang.ins:
|
||||
print(tins)
|
||||
return global_size, local_size
|
||||
|
|
|
@ -6,171 +6,268 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
|||
from tinygrad.helpers import dtypes, CI
|
||||
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
||||
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
|
||||
def float_to_hex(x):
|
||||
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||
|
||||
|
||||
def compute_offsets(total):
|
||||
quotient, remainder = divmod(total, 4096)
|
||||
return [4096]*quotient + [remainder] if remainder else [4096]*quotient
|
||||
quotient, remainder = divmod(total, 4096)
|
||||
return [4096] * quotient + [remainder] if remainder else [4096] * quotient
|
||||
|
||||
#NOTE: Darwin needs names to start with a "_"
|
||||
def get_name(name): return ('_' if system() == 'Darwin' else '') + name
|
||||
|
||||
class ARM64Language(AssemblyLanguage): pass
|
||||
# NOTE: Darwin needs names to start with a "_"
|
||||
def get_name(name):
|
||||
return ("_" if system() == "Darwin" else "") + name
|
||||
|
||||
|
||||
class ARM64Language(AssemblyLanguage):
|
||||
pass
|
||||
|
||||
|
||||
def specialize_to_arm64(fn_nm, asm):
|
||||
var_size = 16
|
||||
prev_uop:Optional[UOps] = None
|
||||
ins = []
|
||||
x_regs = ['x' + str(i) for i in reversed(range(12))]
|
||||
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
|
||||
type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
|
||||
TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}
|
||||
var_size = 16
|
||||
prev_uop: Optional[UOps] = None
|
||||
ins = []
|
||||
x_regs = ["x" + str(i) for i in reversed(range(12))]
|
||||
s_regs = ["s" + str(i) for i in reversed(range(3, 32)) if i <= 7 or i >= 16]
|
||||
type_to_reg = {
|
||||
dtypes.double: "d",
|
||||
dtypes.half: "h",
|
||||
dtypes.float32: "s",
|
||||
dtypes.bool: "w",
|
||||
dtypes.int8: "w",
|
||||
dtypes.int32: "w",
|
||||
dtypes.int64: "x",
|
||||
dtypes.uint8: "w",
|
||||
dtypes.uint32: "w",
|
||||
dtypes.uint64: "x",
|
||||
}
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
BinaryOps.DIV: "div",
|
||||
BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "",
|
||||
BinaryOps.CMPLT: "subs",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "bl " + get_name("sinf"),
|
||||
UnaryOps.LOG2: "bl " + get_name("log2f"),
|
||||
UnaryOps.EXP2: "bl " + get_name("exp2f"),
|
||||
UnaryOps.SQRT: "bl " + get_name("sqrtf"),
|
||||
TernaryOps.MULACC: "madd",
|
||||
TernaryOps.WHERE: "fcsel",
|
||||
}
|
||||
|
||||
def mov_imm(value, reg):
|
||||
# Manually move value into reg if value can't fit
|
||||
if value.__class__ is not float and abs(value) > abs(65535):
|
||||
ins.append(f"movz w15, #{value & 0xffff}")
|
||||
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||
ins.append(f"sxtw {reg}, w15")
|
||||
elif reg[0] == 's':
|
||||
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
||||
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
||||
ins.append("str x15, [sp, 16]")
|
||||
ins.append(f"ldr {reg}, [sp, 16]")
|
||||
else:
|
||||
ins.append(f"mov {reg}, #{value}")
|
||||
|
||||
# Get variables intervals
|
||||
live_range:Dict[str, List[int]] = {}
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
|
||||
live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||
|
||||
mem_vars:Dict[str, int] = {}
|
||||
rtor:Dict[str, str] = {}
|
||||
def allocate_regs(mvars):
|
||||
nonlocal var_size
|
||||
for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
|
||||
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||
#NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
||||
if not available_regs:
|
||||
# ARM needs the stack 16-byte aligned
|
||||
var_size += 16
|
||||
available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
|
||||
mem_vars[v.nm] = var_size
|
||||
rtor[v.nm] = available_regs.pop()
|
||||
|
||||
temp_floats = ['s0', 's1', 's2']
|
||||
temp_ints = ['x12', 'x13', 'x16']
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
# Clear regs out of interval
|
||||
for var, reg in list(rtor.items()):
|
||||
available_regs = s_regs if reg[0] == 's' else x_regs
|
||||
if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
|
||||
available_regs.append(rtor.pop(var))
|
||||
# Assign a registers to the variables using live ranges.
|
||||
allocate_regs([out] + vin)
|
||||
# Assign temp regs to vin and load them before direct use
|
||||
for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
|
||||
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
||||
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
||||
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||
|
||||
if uop == UOps.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
# data 8 to n into the stack
|
||||
if int(arg[4:]) >= 8:
|
||||
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
|
||||
ins.append(f"mov {rtor[out.nm]}, x15")
|
||||
else:
|
||||
ins.append(f"mov {rtor[out.nm]}, #0")
|
||||
ins.append(f"loop_{arg}:")
|
||||
elif uop == UOps.CAST:
|
||||
if arg == BinaryOps.CMPLT:
|
||||
if rtor[out.nm][0] == 's':
|
||||
mov_imm(0.0, 's0')
|
||||
mov_imm(1.0, 's1')
|
||||
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
||||
if rtor[out.nm][0] == 'x':
|
||||
mov_imm(0, 'x14')
|
||||
mov_imm(1, 'x15')
|
||||
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
||||
else:
|
||||
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||
elif uop == UOps.ALU:
|
||||
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
elif arg == TernaryOps.WHERE:
|
||||
ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
|
||||
ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
|
||||
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||
#NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
||||
if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||
def mov_imm(value, reg):
|
||||
# Manually move value into reg if value can't fit
|
||||
if value.__class__ is not float and abs(value) > abs(65535):
|
||||
ins.append(f"movz w15, #{value & 0xffff}")
|
||||
ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
|
||||
ins.append(f"sxtw {reg}, w15")
|
||||
elif reg[0] == "s":
|
||||
ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
|
||||
ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
|
||||
ins.append("str x15, [sp, 16]")
|
||||
ins.append(f"ldr {reg}, [sp, 16]")
|
||||
else:
|
||||
save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
|
||||
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
||||
# Save the registers before they are cleared by func call
|
||||
for i,k in enumerate(save_regs,1):
|
||||
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
|
||||
ins.append("stp x29, x30, [sp, #0]!")
|
||||
ins.append("mov x29, sp")
|
||||
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
|
||||
ins.append(alu[arg])
|
||||
ins.append(f"fmov {rtor[out.nm]}, s0")
|
||||
ins.append("mov sp, x29")
|
||||
ins.append("ldp x29, x30, [sp], #0")
|
||||
for i,k in enumerate(save_regs,1):
|
||||
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
||||
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||
elif arg == BinaryOps.CMPLT:
|
||||
ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
|
||||
elif arg == BinaryOps.MOD:
|
||||
rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
|
||||
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
|
||||
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
||||
else:
|
||||
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
mov_imm(arg, rtor[out.nm])
|
||||
else:
|
||||
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
elif uop == UOps.STORE:
|
||||
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
|
||||
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg[1:]}:")
|
||||
elif uop == UOps.ENDLOOP:
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"b.lt loop_{arg[1]}")
|
||||
prev_uop = uop
|
||||
# store regs into memory if needed
|
||||
if out is not None and out.nm in mem_vars:
|
||||
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
||||
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
||||
return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
|
||||
ins.append(f"mov {reg}, #{value}")
|
||||
|
||||
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
|
||||
lang = ARM64Language()
|
||||
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
|
||||
return specialize_to_arm64(fn_nm, lang.ins), global_size[::-1], local_size[::-1], True
|
||||
# Get variables intervals
|
||||
live_range: Dict[str, List[int]] = {}
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
for var in [v for v in [out] + vin if v is not None and v.__class__ is not int]:
|
||||
live_range[var.nm] = (
|
||||
[i, i] if var.nm not in live_range else [live_range[var.nm][0], i]
|
||||
)
|
||||
|
||||
mem_vars: Dict[str, int] = {}
|
||||
rtor: Dict[str, str] = {}
|
||||
|
||||
def allocate_regs(mvars):
|
||||
nonlocal var_size
|
||||
for v in [
|
||||
v
|
||||
for v in mvars
|
||||
if v is not None and v.__class__ is not int and v.nm not in rtor
|
||||
]:
|
||||
available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
|
||||
# NOTE: Very simple spill, everything that don't fit in regs goes to mem
|
||||
if not available_regs:
|
||||
# ARM needs the stack 16-byte aligned
|
||||
var_size += 16
|
||||
available_regs.append("s0" if dtypes.is_float(out[1]) else "x12")
|
||||
mem_vars[v.nm] = var_size
|
||||
rtor[v.nm] = available_regs.pop()
|
||||
|
||||
temp_floats = ["s0", "s1", "s2"]
|
||||
temp_ints = ["x12", "x13", "x16"]
|
||||
for i, (uop, out, vin, arg) in enumerate(asm):
|
||||
# Clear regs out of interval
|
||||
for var, reg in list(rtor.items()):
|
||||
available_regs = s_regs if reg[0] == "s" else x_regs
|
||||
if var[1] not in "B" and var not in mem_vars and i > live_range[var][1]:
|
||||
available_regs.append(rtor.pop(var))
|
||||
# Assign a registers to the variables using live ranges.
|
||||
allocate_regs([out] + vin)
|
||||
# Assign temp regs to vin and load them before direct use
|
||||
for i, v in enumerate(
|
||||
[v for v in vin if v.__class__ is not int and v.nm in mem_vars]
|
||||
):
|
||||
rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
|
||||
# ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
|
||||
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||
|
||||
if uop == UOps.SPECIAL:
|
||||
if arg.startswith("data"):
|
||||
# data 8 to n into the stack
|
||||
if int(arg[4:]) >= 8:
|
||||
ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
|
||||
ins.append(f"mov {rtor[out.nm]}, x15")
|
||||
else:
|
||||
ins.append(f"mov {rtor[out.nm]}, #0")
|
||||
ins.append(f"loop_{arg}:")
|
||||
elif uop == UOps.CAST:
|
||||
if arg == BinaryOps.CMPLT:
|
||||
if rtor[out.nm][0] == "s":
|
||||
mov_imm(0.0, "s0")
|
||||
mov_imm(1.0, "s1")
|
||||
ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
|
||||
if rtor[out.nm][0] == "x":
|
||||
mov_imm(0, "x14")
|
||||
mov_imm(1, "x15")
|
||||
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
||||
else:
|
||||
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||
elif uop == UOps.ALU:
|
||||
if len(vin) == 2 and vin[1].__class__ is int:
|
||||
mov_imm(vin[1], "x15")
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(
|
||||
f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
)
|
||||
elif arg == TernaryOps.WHERE:
|
||||
ins.append(
|
||||
f"fcmp {rtor[vin[0].nm]}, #0.0"
|
||||
if rtor[vin[0].nm][0] == "s"
|
||||
else f"cmp {rtor[vin[0].nm]}, #0"
|
||||
)
|
||||
ins.append(
|
||||
f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne"
|
||||
)
|
||||
elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
|
||||
# NOTE: Not a real instruction, use to emulate a ext call in unicorn
|
||||
if CI:
|
||||
ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
|
||||
else:
|
||||
save_regs = [
|
||||
k for k in rtor.keys() if k != out.nm and k not in mem_vars
|
||||
]
|
||||
ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
|
||||
# Save the registers before they are cleared by func call
|
||||
for i, k in enumerate(save_regs, 1):
|
||||
ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
|
||||
ins.append("stp x29, x30, [sp, #0]!")
|
||||
ins.append("mov x29, sp")
|
||||
ins.append(f"fmov s0, {rtor[vin[0].nm]}")
|
||||
ins.append(alu[arg])
|
||||
ins.append(f"fmov {rtor[out.nm]}, s0")
|
||||
ins.append("mov sp, x29")
|
||||
ins.append("ldp x29, x30, [sp], #0")
|
||||
for i, k in enumerate(save_regs, 1):
|
||||
ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
|
||||
ins.append(f"add sp, sp, #{len(save_regs)*16}")
|
||||
elif arg == BinaryOps.CMPLT:
|
||||
ins.append(
|
||||
f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
if not dtypes.is_float(vin[0][1])
|
||||
else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}"
|
||||
)
|
||||
elif arg == BinaryOps.MOD:
|
||||
rhs = "x15" if vin[1].__class__ is int else rtor[vin[1].nm]
|
||||
ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
|
||||
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
||||
else:
|
||||
ins.append(
|
||||
f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
mov_imm(arg, rtor[out.nm])
|
||||
else:
|
||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_in = (
|
||||
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||
if arg[2] is not None
|
||||
else rtor[out.nm]
|
||||
)
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(
|
||||
f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]"
|
||||
)
|
||||
if arg[2] is not None:
|
||||
ins.append(
|
||||
f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}"
|
||||
)
|
||||
elif uop == UOps.STORE:
|
||||
# NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_out = (
|
||||
type_to_reg[arg[2]] + ("0" if dtypes.is_float(arg[2]) else "12")
|
||||
if arg[2] is not None
|
||||
else rtor[vin[1].nm]
|
||||
)
|
||||
if arg[2] is not None:
|
||||
ins.append(
|
||||
f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}"
|
||||
)
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
# TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
|
||||
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg[1:]}:")
|
||||
elif uop == UOps.ENDLOOP:
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"b.lt loop_{arg[1]}")
|
||||
prev_uop = uop
|
||||
# store regs into memory if needed
|
||||
if out is not None and out.nm in mem_vars:
|
||||
ins.append(f"mov x15, {mem_vars[out.nm]}")
|
||||
ins.append(f"str {rtor[out.nm]}, [sp, x15]")
|
||||
return "\n".join(
|
||||
[
|
||||
f"//varsize {var_size}",
|
||||
".arch armv8-a",
|
||||
".text",
|
||||
f".global {get_name(fn_nm)}",
|
||||
".p2align 2",
|
||||
f"{get_name(fn_nm)}:",
|
||||
"mov x17, sp",
|
||||
]
|
||||
+ [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||
+ ins
|
||||
+ [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)]
|
||||
+ ["ret", "\n"]
|
||||
)
|
||||
|
||||
|
||||
def uops_to_arm64_asm(
|
||||
fn_nm: str, uops: List[UOp]
|
||||
) -> Tuple[str, List[int], List[int], bool]:
|
||||
lang = ARM64Language()
|
||||
global_size, local_size = uops_to_asmstyle(lang, fn_nm, uops)
|
||||
return (
|
||||
specialize_to_arm64(fn_nm, lang.ins),
|
||||
global_size[::-1],
|
||||
local_size[::-1],
|
||||
True,
|
||||
)
|
||||
|
|
|
@ -6,100 +6,211 @@ from tinygrad.helpers import dtypes
|
|||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
||||
from tinygrad.runtime.ops_cuda import arch
|
||||
|
||||
dtype_to_nvtype = {dtypes.float32: "f32", dtypes.float16: "f16", dtypes.int64: "s64", dtypes.int32: "s32", dtypes.int8: "s8", dtypes.bool: "pred", dtypes.uint64: "u64", dtypes.uint32: "u32", dtypes.uint16: "u16", dtypes.uint8: "u8", "bits16": "b16", dtypes.float64: "f64"}
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
dtype_to_nvtype = {
|
||||
dtypes.float32: "f32",
|
||||
dtypes.float16: "f16",
|
||||
dtypes.int64: "s64",
|
||||
dtypes.int32: "s32",
|
||||
dtypes.int8: "s8",
|
||||
dtypes.bool: "pred",
|
||||
dtypes.uint64: "u64",
|
||||
dtypes.uint32: "u32",
|
||||
dtypes.uint16: "u16",
|
||||
dtypes.uint8: "u8",
|
||||
"bits16": "b16",
|
||||
dtypes.float64: "f64",
|
||||
}
|
||||
|
||||
|
||||
def float_to_hex(x):
|
||||
return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
|
||||
|
||||
|
||||
def ptx_needs_cast(dest_dtype, src_dtype):
|
||||
return (
|
||||
dtypes.is_float(dest_dtype)
|
||||
and dtypes.is_int(src_dtype)
|
||||
or dtypes.is_int(dest_dtype)
|
||||
and dtypes.is_float(src_dtype)
|
||||
or (
|
||||
dtypes.is_float(src_dtype)
|
||||
and dtypes.is_float(dest_dtype)
|
||||
and dest_dtype.itemsize != src_dtype.itemsize
|
||||
)
|
||||
)
|
||||
|
||||
def ptx_needs_cast(dest_dtype, src_dtype): return dtypes.is_float(dest_dtype) and dtypes.is_int(src_dtype) or dtypes.is_int(dest_dtype) and dtypes.is_float(src_dtype) or (dtypes.is_float(src_dtype) and dtypes.is_float(dest_dtype) and dest_dtype.itemsize != src_dtype.itemsize)
|
||||
|
||||
def render_cast(ins, inp, out):
|
||||
if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)):
|
||||
ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};")
|
||||
elif out.dtype == dtypes.bool:
|
||||
if inp.dtype == dtypes.bool:
|
||||
ins.append(f"mov.pred {out}, {inp};")
|
||||
if inp.dtype == dtypes.bool and (
|
||||
dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)
|
||||
):
|
||||
ins.append(
|
||||
f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};"
|
||||
)
|
||||
elif out.dtype == dtypes.bool:
|
||||
if inp.dtype == dtypes.bool:
|
||||
ins.append(f"mov.pred {out}, {inp};")
|
||||
else:
|
||||
ins.append(
|
||||
f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};"
|
||||
)
|
||||
else:
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};")
|
||||
else:
|
||||
round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else ''
|
||||
ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};")
|
||||
round_mod = (
|
||||
".rzi"
|
||||
if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype)
|
||||
else ".rz"
|
||||
if dtypes.is_float(out.dtype)
|
||||
and (
|
||||
dtypes.is_int(inp.dtype)
|
||||
or dtypes.is_float(inp.dtype)
|
||||
and inp.dtype.itemsize > out.dtype.itemsize
|
||||
)
|
||||
else ""
|
||||
)
|
||||
ins.append(
|
||||
f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};"
|
||||
)
|
||||
|
||||
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#
|
||||
|
||||
|
||||
class PTXLanguage(AssemblyLanguage):
|
||||
supports_constant_folding: bool = True
|
||||
supports_constant_folding: bool = True
|
||||
|
||||
|
||||
def specialize_to_ptx(lang, function_name):
|
||||
param_cnt = 0
|
||||
ins = []
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "rem", BinaryOps.CMPLT: "setp.lt", UnaryOps.SQRT: "sqrt.approx",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
|
||||
for uop, out, vin, arg in lang.ins:
|
||||
if uop == UOps.ENDLOOP:
|
||||
ins.append("bar.sync 0;")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
param_cnt += 1
|
||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
|
||||
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
||||
elif arg.startswith('gid'):
|
||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||
elif arg.startswith('lid'):
|
||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||
elif uop == UOps.ALU:
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
|
||||
else:
|
||||
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
|
||||
if arg == TernaryOps.WHERE:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
reg = vin[0]
|
||||
else:
|
||||
reg = lang.newreg((vin[0], 'bool'), dtypes.bool)
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
|
||||
vin = vin[1:] + [reg]
|
||||
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
|
||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||
dt = ('u16', dtypes.uint16) if arg[2] == dtypes.bool == out.dtype else ('u8', dtypes.uint8) if arg[2] == dtypes.bool else ('b16', dtypes.float16) if arg[2] == dtypes.half else (dtype_to_nvtype[arg[2]], arg[2])
|
||||
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
||||
ins.append(f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
render_cast(ins, reg, out)
|
||||
else:
|
||||
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
elif uop == UOps.STORE:
|
||||
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
|
||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
|
||||
render_cast(ins, vin[1], prereg)
|
||||
else: prereg = vin[1]
|
||||
reg = lang.newreg((prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]), dtype=dtypes.uint16 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2])
|
||||
render_cast(ins, prereg, reg)
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
|
||||
else:
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
|
||||
elif uop == UOps.CAST:
|
||||
render_cast(ins, vin[0], out)
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||
param_cnt = 0
|
||||
ins = []
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
BinaryOps.DIV: "div",
|
||||
BinaryOps.MAX: "max",
|
||||
BinaryOps.MOD: "rem",
|
||||
BinaryOps.CMPLT: "setp.lt",
|
||||
UnaryOps.SQRT: "sqrt.approx",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.NEG: "neg",
|
||||
UnaryOps.SIN: "sin.approx",
|
||||
UnaryOps.LOG2: "lg2.approx",
|
||||
UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn",
|
||||
TernaryOps.WHERE: "selp",
|
||||
}
|
||||
for uop, out, vin, arg in lang.ins:
|
||||
if uop == UOps.ENDLOOP:
|
||||
ins.append("bar.sync 0;")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith("data"):
|
||||
param_cnt += 1
|
||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||
# TODO: we sometimes want this to be local, nvcc converts to global most of the time, not sure when we would need to?
|
||||
# ins.append(f"cvta.to.global.u64 {out}, {out};")
|
||||
elif arg.startswith("gid"):
|
||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||
elif arg.startswith("lid"):
|
||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||
elif uop == UOps.ALU:
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
|
||||
else:
|
||||
otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype
|
||||
if arg == TernaryOps.WHERE:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
reg = vin[0]
|
||||
else:
|
||||
reg = lang.newreg((vin[0], "bool"), dtypes.bool)
|
||||
ins.append(
|
||||
f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};"
|
||||
)
|
||||
vin = vin[1:] + [reg]
|
||||
ins.append(
|
||||
f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
ins.append(
|
||||
f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};"
|
||||
)
|
||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||
dt = (
|
||||
("u16", dtypes.uint16)
|
||||
if arg[2] == dtypes.bool == out.dtype
|
||||
else ("u8", dtypes.uint8)
|
||||
if arg[2] == dtypes.bool
|
||||
else ("b16", dtypes.float16)
|
||||
if arg[2] == dtypes.half
|
||||
else (dtype_to_nvtype[arg[2]], arg[2])
|
||||
)
|
||||
reg = lang.newreg((out, dt[0]), dtype=dt[1])
|
||||
ins.append(
|
||||
f"ld.{arg[1]}.{dt[0]} {reg}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||
)
|
||||
render_cast(ins, reg, out)
|
||||
else:
|
||||
ins.append(
|
||||
f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];"
|
||||
)
|
||||
elif uop == UOps.STORE:
|
||||
if (
|
||||
ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype)
|
||||
or arg[2] == dtypes.bool
|
||||
):
|
||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||
prereg = lang.newreg((vin[1], "bool"), dtype=dtypes.bool)
|
||||
render_cast(ins, vin[1], prereg)
|
||||
else:
|
||||
prereg = vin[1]
|
||||
reg = lang.newreg(
|
||||
(prereg, dtypes.uint16 if arg[2] == dtypes.bool else arg[2]),
|
||||
dtype=dtypes.uint16
|
||||
if arg[2] == dtypes.bool
|
||||
else dtypes.float
|
||||
if arg[2] is None
|
||||
else arg[2],
|
||||
)
|
||||
render_cast(ins, prereg, reg)
|
||||
ins.append(
|
||||
f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};"
|
||||
)
|
||||
else:
|
||||
ins.append(
|
||||
f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};"
|
||||
)
|
||||
elif uop == UOps.CAST:
|
||||
render_cast(ins, vin[0], out)
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||
|
||||
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
|
||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{"]
|
||||
for arg in [(dtype, lang.type_to_letter(dtype), c) for dtype,c in lang.cnts.items()]: ins_prefix.append(f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",)
|
||||
ins = ins_prefix + ins
|
||||
ins += ["ret;", "}"]
|
||||
return '\n'.join(ins)
|
||||
ins_prefix = [
|
||||
".version 7.8",
|
||||
".target " + arch(),
|
||||
".address_size 64",
|
||||
f".visible .entry {function_name}({', '.join(f'.param .u64 data{i}' for i in range(param_cnt))}) {{",
|
||||
]
|
||||
for arg in [
|
||||
(dtype, lang.type_to_letter(dtype), c) for dtype, c in lang.cnts.items()
|
||||
]:
|
||||
ins_prefix.append(
|
||||
f".reg .{dtype_to_nvtype[arg[0][0]]} %{arg[1]}<{arg[2]}>;",
|
||||
)
|
||||
ins = ins_prefix + ins
|
||||
ins += ["ret;", "}"]
|
||||
return "\n".join(ins)
|
||||
|
||||
def uops_to_ptx_asm(function_name:str, uops:List[UOp]):
|
||||
lang = PTXLanguage()
|
||||
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
||||
return specialize_to_ptx(lang, function_name), global_size[::-1], local_size[::-1], True
|
||||
|
||||
def uops_to_ptx_asm(function_name: str, uops: List[UOp]):
|
||||
lang = PTXLanguage()
|
||||
global_size, local_size = uops_to_asmstyle(lang, function_name, uops)
|
||||
return (
|
||||
specialize_to_ptx(lang, function_name),
|
||||
global_size[::-1],
|
||||
local_size[::-1],
|
||||
True,
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
|
|||
|
||||
# ugh, is this really needed?
|
||||
from extra.helpers import enable_early_exec
|
||||
|
||||
early_exec = enable_early_exec()
|
||||
|
||||
boilerplate_start = """
|
||||
|
@ -24,180 +25,359 @@ code_start = """.end_amdhsa_kernel
|
|||
code:
|
||||
"""
|
||||
|
||||
|
||||
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
|
||||
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
|
||||
# RDNA3 is actually a SIMD machine!
|
||||
class RDNACodegen(AssemblyCodegen):
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
supports_load3: bool = True
|
||||
sin_is_sin2pi: bool = True
|
||||
no_div: bool = True
|
||||
supports_float4: bool = True
|
||||
supports_float4_alu: bool = True
|
||||
supports_load3: bool = True
|
||||
sin_is_sin2pi: bool = True
|
||||
no_div: bool = True
|
||||
|
||||
def specialize(self, asm) -> Tuple[str, str]:
|
||||
args = []
|
||||
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
|
||||
ins = []
|
||||
def specialize(self, asm) -> Tuple[str, str]:
|
||||
args = []
|
||||
for i, b in enumerate(self.bufs):
|
||||
args.append(
|
||||
{
|
||||
".address_space": "global",
|
||||
".name": f"buf_{i}",
|
||||
".offset": i * 8,
|
||||
".size": 8,
|
||||
".type_name": b.dtype.name + "*",
|
||||
".value_kind": "global_buffer",
|
||||
}
|
||||
)
|
||||
ins = []
|
||||
|
||||
v_cnt = 3 # v[0:2] is local_xyz
|
||||
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
||||
v_cnt = 3 # v[0:2] is local_xyz
|
||||
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
||||
|
||||
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
|
||||
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
|
||||
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
|
||||
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
|
||||
BinaryOps.CMPLT: "cmp_lt"}
|
||||
dtype_to_rdnatype = {
|
||||
dtypes.float32: "f32",
|
||||
dtypes.int64: "i64",
|
||||
dtypes.int32: "i32",
|
||||
dtypes.uint64: "u64",
|
||||
dtypes.bool: "i32",
|
||||
}
|
||||
alu = {
|
||||
BinaryOps.ADD: "add",
|
||||
BinaryOps.SUB: "sub",
|
||||
BinaryOps.MUL: "mul",
|
||||
TernaryOps.MULACC: "fma",
|
||||
BinaryOps.MAX: "max",
|
||||
UnaryOps.RECIP: "rcp",
|
||||
UnaryOps.NOOP: "mov",
|
||||
UnaryOps.SIN: "sin",
|
||||
UnaryOps.LOG2: "log",
|
||||
UnaryOps.EXP2: "exp",
|
||||
BinaryOps.CMPLT: "cmp_lt",
|
||||
}
|
||||
|
||||
pend_regs:Set[Register] = set()
|
||||
rtor:Dict[Register, str] = {}
|
||||
def reg_in(x):
|
||||
nonlocal pend_regs
|
||||
#print("reg_in", x, rtor[x], pend_regs)
|
||||
if x in pend_regs:
|
||||
#print("clear")
|
||||
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
|
||||
pend_regs.clear()
|
||||
return rtor[x]
|
||||
def reg_out(x):
|
||||
return rtor[x]
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
if arg[0][1]:
|
||||
s_cnt += s_cnt % align
|
||||
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
|
||||
s_cnt += align
|
||||
pend_regs: Set[Register] = set()
|
||||
rtor: Dict[Register, str] = {}
|
||||
|
||||
def reg_in(x):
|
||||
nonlocal pend_regs
|
||||
# print("reg_in", x, rtor[x], pend_regs)
|
||||
if x in pend_regs:
|
||||
# print("clear")
|
||||
ins.append("s_waitcnt lgkmcnt(0), vmcnt(0)")
|
||||
pend_regs.clear()
|
||||
return rtor[x]
|
||||
|
||||
def reg_out(x):
|
||||
return rtor[x]
|
||||
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] in [
|
||||
dtypes.uint32,
|
||||
dtypes.uint64,
|
||||
dtypes.int64,
|
||||
dtypes.int32,
|
||||
dtypes.float32,
|
||||
dtypes.float.vec(4),
|
||||
]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
if arg[0][1]:
|
||||
s_cnt += s_cnt % align
|
||||
reg_name = (
|
||||
f"s[{s_cnt}:{s_cnt + align - 1}]"
|
||||
if align > 1
|
||||
else f"s{s_cnt}"
|
||||
)
|
||||
s_cnt += align
|
||||
else:
|
||||
v_cnt += v_cnt % align
|
||||
reg_name = (
|
||||
f"v[{v_cnt}:{v_cnt + align - 1}]"
|
||||
if align > 1
|
||||
else f"v{v_cnt}"
|
||||
)
|
||||
v_cnt += align
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
|
||||
if arg[0][0] == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
reg_name = (
|
||||
f"s{s_cnt-align+off}"
|
||||
if arg[0][1]
|
||||
else f"v{v_cnt-align+off}"
|
||||
)
|
||||
rtor[
|
||||
Register(
|
||||
f"%{arg[1]}{i}", dtypes.float, False, off=off
|
||||
)
|
||||
] = reg_name
|
||||
elif arg[0][0] == dtypes.bool:
|
||||
for i in range(arg[2]):
|
||||
reg_name = (
|
||||
"scc" if arg[0][1] else "vcc_lo"
|
||||
) # `_lo` suffix since we're running wavefront_size=32
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"DEFINE_REGISTER not implemented for arg: ", arg
|
||||
)
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith("buf"):
|
||||
i = int(arg[3:])
|
||||
ins.append(f"s_load_b64 {reg_out(out)}, s[0:1], {i*8}")
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs():
|
||||
pend_regs.add(r)
|
||||
elif arg.startswith("gid"):
|
||||
ins.append(f"v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}")
|
||||
# the docs lied, this is actually y
|
||||
if int(arg[3]) == 2:
|
||||
ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
||||
if int(arg[3]) == 1:
|
||||
ins.append("v_bfe_u32 v1, v0, 10, 10")
|
||||
elif int(arg[3]) == 0:
|
||||
ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
||||
# get local size
|
||||
offset = len(args) * 8
|
||||
args.append(
|
||||
{
|
||||
".offset": offset,
|
||||
".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}",
|
||||
".size": 8,
|
||||
}
|
||||
)
|
||||
ins.append(f"s_load_b32 s{2+int(arg[3])}, s[0:1], {offset}")
|
||||
ins.append("s_waitcnt vmcnt(0) lgkmcnt(0)")
|
||||
pend_regs.clear()
|
||||
ins.append(
|
||||
f"v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}"
|
||||
)
|
||||
ins.append(
|
||||
f"v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}"
|
||||
)
|
||||
elif uop == UOps.CONST:
|
||||
if arg == float("inf"):
|
||||
arg = "0x7f800000"
|
||||
elif arg == float("-inf"):
|
||||
arg = "0xff800000"
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}"
|
||||
)
|
||||
else:
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}"
|
||||
)
|
||||
elif uop == UOps.ALU:
|
||||
if arg in [BinaryOps.CMPLT]:
|
||||
ins.append(
|
||||
f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||
)
|
||||
else:
|
||||
alu_arg = alu[arg]
|
||||
if arg == TernaryOps.MULACC and out == vin[2]:
|
||||
alu_arg = "fmac"
|
||||
vin = vin[0:2]
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for rr in zip(
|
||||
*[
|
||||
x.subregs()
|
||||
if x.dtype == dtypes.float.vec(4)
|
||||
else [x, x, x, x]
|
||||
for x in [out] + vin
|
||||
]
|
||||
):
|
||||
ins.append(
|
||||
f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}"
|
||||
)
|
||||
else:
|
||||
ins.append(
|
||||
f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}"
|
||||
)
|
||||
elif uop == UOps.LOAD:
|
||||
if out.scalar:
|
||||
# swap arg order
|
||||
ins.append(
|
||||
f"s_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}"
|
||||
)
|
||||
else:
|
||||
ins.append(
|
||||
f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||
)
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs():
|
||||
pend_regs.add(r)
|
||||
elif uop == UOps.STORE:
|
||||
ins.append(
|
||||
f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}'
|
||||
)
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
|
||||
elif uop == UOps.CAST:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
if out.dtype == dtypes.float32:
|
||||
ins.append(
|
||||
f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
||||
else:
|
||||
v_cnt += v_cnt % align
|
||||
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
|
||||
v_cnt += align
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
raise NotImplementedError(uop)
|
||||
|
||||
if arg[0][0] == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
||||
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
||||
elif arg[0][0] == dtypes.bool:
|
||||
for i in range(arg[2]):
|
||||
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
else:
|
||||
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
||||
elif uop == UOps.SPECIAL:
|
||||
if arg.startswith('buf'):
|
||||
i = int(arg[3:])
|
||||
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
elif arg.startswith('gid'):
|
||||
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
|
||||
# the docs lied, this is actually y
|
||||
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
||||
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
|
||||
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
||||
# get local size
|
||||
offset = len(args)*8
|
||||
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
|
||||
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
|
||||
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
|
||||
pend_regs.clear()
|
||||
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
|
||||
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
|
||||
elif uop == UOps.CONST:
|
||||
if arg == float('inf'): arg = "0x7f800000"
|
||||
elif arg == float('-inf'): arg = "0xff800000"
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for off in range(4):
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
||||
elif uop == UOps.ALU:
|
||||
if arg in [BinaryOps.CMPLT]:
|
||||
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
else:
|
||||
alu_arg = alu[arg]
|
||||
if arg == TernaryOps.MULACC and out == vin[2]:
|
||||
alu_arg = "fmac"
|
||||
vin = vin[0:2]
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
|
||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
elif uop == UOps.LOAD:
|
||||
if out.scalar:
|
||||
# swap arg order
|
||||
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
||||
else:
|
||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
elif uop == UOps.STORE:
|
||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
elif uop == UOps.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
|
||||
elif uop == UOps.CAST:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
if out.dtype == dtypes.float32:
|
||||
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
||||
else:
|
||||
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
||||
else:
|
||||
raise NotImplementedError(uop)
|
||||
ins += ["s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", "s_code_end"]
|
||||
|
||||
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
|
||||
# dual alu group
|
||||
seen = set()
|
||||
new_ins = []
|
||||
for i, tins in enumerate(ins):
|
||||
if tins in seen:
|
||||
continue
|
||||
if tins.startswith("v_fmac_f32"):
|
||||
for gins in reversed(ins[i + 1 :]):
|
||||
if gins in seen:
|
||||
continue
|
||||
if gins.startswith("v_fmac_f32"):
|
||||
r0 = [int(x[1:].strip(",")) for x in tins.split(" ")[1:]]
|
||||
r1 = [int(x[1:].strip(",")) for x in gins.split(" ")[1:]]
|
||||
if r0[0] % 2 == r1[0] % 2:
|
||||
continue
|
||||
if r0[1] % 2 == r1[1] % 2:
|
||||
continue
|
||||
if r0[2] % 2 == r1[2] % 2:
|
||||
continue
|
||||
new_ins.append(
|
||||
tins.replace("v_", "v_dual_")
|
||||
+ " :: "
|
||||
+ gins.replace("v_", "v_dual_")
|
||||
)
|
||||
seen.add(tins)
|
||||
seen.add(gins)
|
||||
break
|
||||
if tins not in seen:
|
||||
new_ins.append(tins)
|
||||
ins = new_ins
|
||||
|
||||
# dual alu group
|
||||
seen = set()
|
||||
new_ins = []
|
||||
for i,tins in enumerate(ins):
|
||||
if tins in seen: continue
|
||||
if tins.startswith("v_fmac_f32"):
|
||||
for gins in reversed(ins[i+1:]):
|
||||
if gins in seen: continue
|
||||
if gins.startswith("v_fmac_f32"):
|
||||
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
|
||||
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
|
||||
if r0[0]%2 == r1[0]%2: continue
|
||||
if r0[1]%2 == r1[1]%2: continue
|
||||
if r0[2]%2 == r1[2]%2: continue
|
||||
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
|
||||
seen.add(tins)
|
||||
seen.add(gins)
|
||||
break
|
||||
if tins not in seen:
|
||||
new_ins.append(tins)
|
||||
ins = new_ins
|
||||
return "code", self.assemble(args, ins, v_cnt, s_cnt)
|
||||
|
||||
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
|
||||
def assemble(self, args, ins, v_cnt, s_cnt):
|
||||
kernel_desc = {
|
||||
".amdhsa_group_segment_fixed_size": 0,
|
||||
".amdhsa_private_segment_fixed_size": 0,
|
||||
".amdhsa_kernarg_size": 0,
|
||||
".amdhsa_next_free_vgpr": v_cnt, # this matters!
|
||||
".amdhsa_reserve_vcc": 0,
|
||||
".amdhsa_reserve_xnack_mask": 0,
|
||||
".amdhsa_next_free_sgpr": s_cnt,
|
||||
".amdhsa_float_round_mode_32": 0,
|
||||
".amdhsa_float_round_mode_16_64": 0,
|
||||
".amdhsa_float_denorm_mode_32": 3,
|
||||
".amdhsa_float_denorm_mode_16_64": 3,
|
||||
".amdhsa_dx10_clamp": 1,
|
||||
".amdhsa_ieee_mode": 1,
|
||||
".amdhsa_fp16_overflow": 0,
|
||||
".amdhsa_workgroup_processor_mode": 1,
|
||||
".amdhsa_memory_ordered": 1,
|
||||
".amdhsa_forward_progress": 0,
|
||||
".amdhsa_enable_private_segment": 0,
|
||||
".amdhsa_system_sgpr_workgroup_id_x": 1,
|
||||
".amdhsa_system_sgpr_workgroup_id_y": 1,
|
||||
".amdhsa_system_sgpr_workgroup_id_z": 1,
|
||||
".amdhsa_system_sgpr_workgroup_info": 0,
|
||||
".amdhsa_system_vgpr_workitem_id": 2, # is amdhsa_system_vgpr_workitem_id real?
|
||||
".amdhsa_exception_fp_ieee_invalid_op": 0,
|
||||
".amdhsa_exception_fp_denorm_src": 0,
|
||||
".amdhsa_exception_fp_ieee_div_zero": 0,
|
||||
".amdhsa_exception_fp_ieee_overflow": 0,
|
||||
".amdhsa_exception_fp_ieee_underflow": 0,
|
||||
".amdhsa_exception_fp_ieee_inexact": 0,
|
||||
".amdhsa_exception_int_div_zero": 0,
|
||||
".amdhsa_user_sgpr_dispatch_ptr": 0,
|
||||
".amdhsa_user_sgpr_queue_ptr": 0,
|
||||
".amdhsa_user_sgpr_kernarg_segment_ptr": 1,
|
||||
".amdhsa_user_sgpr_dispatch_id": 0,
|
||||
".amdhsa_user_sgpr_private_segment_size": 0,
|
||||
".amdhsa_wavefront_size32": 1,
|
||||
".amdhsa_uses_dynamic_stack": 0,
|
||||
}
|
||||
|
||||
def assemble(self, args, ins, v_cnt, s_cnt):
|
||||
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
|
||||
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
|
||||
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
|
||||
'.amdhsa_next_free_sgpr': s_cnt,
|
||||
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
|
||||
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
|
||||
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
|
||||
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
|
||||
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
|
||||
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
|
||||
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
|
||||
metadata = {
|
||||
"amdhsa.kernels": [
|
||||
{
|
||||
".args": args,
|
||||
".group_segment_fixed_size": 0,
|
||||
".kernarg_segment_align": 8,
|
||||
".kernarg_segment_size": args[-1][".offset"] + args[-1][".size"],
|
||||
".language": "OpenCL C",
|
||||
".language_version": [1, 2],
|
||||
".max_flat_workgroup_size": 256,
|
||||
".name": "code",
|
||||
".private_segment_fixed_size": 0,
|
||||
".sgpr_count": s_cnt,
|
||||
".sgpr_spill_count": 0,
|
||||
".symbol": "code.kd",
|
||||
".uses_dynamic_stack": False,
|
||||
".vgpr_count": v_cnt,
|
||||
".vgpr_spill_count": 0,
|
||||
".wavefront_size": 32,
|
||||
}
|
||||
],
|
||||
"amdhsa.target": "amdgcn-amd-amdhsa--gfx1100",
|
||||
"amdhsa.version": [1, 2],
|
||||
}
|
||||
|
||||
metadata = {'amdhsa.kernels': [{'.args': args,
|
||||
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
|
||||
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
|
||||
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
|
||||
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
|
||||
'.wavefront_size': 32}],
|
||||
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
|
||||
|
||||
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
|
||||
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
|
||||
return asm
|
||||
code = (
|
||||
boilerplate_start
|
||||
+ "\n"
|
||||
+ "\n".join("%s %d" % x for x in kernel_desc.items())
|
||||
+ "\n"
|
||||
+ code_start
|
||||
+ "\n".join(ins)
|
||||
+ "\n.amdgpu_metadata\n"
|
||||
+ yaml.dump(metadata)
|
||||
+ ".end_amdgpu_metadata"
|
||||
)
|
||||
obj = early_exec(
|
||||
(
|
||||
[
|
||||
ROCM_LLVM_PATH / "llvm-mc",
|
||||
"--arch=amdgcn",
|
||||
"--mcpu=gfx1100",
|
||||
"--triple=amdgcn-amd-amdhsa",
|
||||
"--filetype=obj",
|
||||
"-",
|
||||
],
|
||||
code.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
asm = early_exec(
|
||||
(
|
||||
[ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"],
|
||||
obj,
|
||||
)
|
||||
)
|
||||
return asm
|
||||
|
|
|
@ -3,8 +3,10 @@ import numpy as np
|
|||
from tinygrad.runtime.ops_cuda import CUDAProgram, RawCUDABuffer
|
||||
|
||||
if __name__ == "__main__":
|
||||
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
|
||||
prg = CUDAProgram("test", """
|
||||
test = RawCUDABuffer.fromCPU(np.zeros(10, np.float32))
|
||||
prg = CUDAProgram(
|
||||
"test",
|
||||
"""
|
||||
.version 7.8
|
||||
.target sm_86
|
||||
.address_size 64
|
||||
|
@ -17,7 +19,8 @@ if __name__ == "__main__":
|
|||
mov.u32 %r1, 0x40000000; // 2.0 in float
|
||||
st.global.u32 [%rd2], %r1;
|
||||
ret;
|
||||
}""", binary=True)
|
||||
prg([1], [1], test)
|
||||
print(test.toCPU())
|
||||
|
||||
}""",
|
||||
binary=True,
|
||||
)
|
||||
prg([1], [1], test)
|
||||
print(test.toCPU())
|
||||
|
|
|
@ -3,6 +3,7 @@ import pathlib
|
|||
from hexdump import hexdump
|
||||
from tinygrad.helpers import colored
|
||||
from extra.helpers import enable_early_exec
|
||||
|
||||
early_exec = enable_early_exec()
|
||||
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
|
||||
|
@ -14,13 +15,13 @@ DUAL_ALU = True
|
|||
F32 = True
|
||||
|
||||
if ENABLE_NON_ASM:
|
||||
buf = CLBuffer.fromCPU(np.zeros(10, np.float32))
|
||||
prg_empty = CLProgram("code", "__kernel void code(__global float *a) { a[0] = 1; }")
|
||||
asm_real = prg_empty.binary()
|
||||
with open("/tmp/cc.elf", "wb") as f:
|
||||
f.write(asm_real)
|
||||
prg_empty([1], [1], buf, wait=True)
|
||||
print(buf.toCPU())
|
||||
buf = CLBuffer.fromCPU(np.zeros(10, np.float32))
|
||||
prg_empty = CLProgram("code", "__kernel void code(__global float *a) { a[0] = 1; }")
|
||||
asm_real = prg_empty.binary()
|
||||
with open("/tmp/cc.elf", "wb") as f:
|
||||
f.write(asm_real)
|
||||
prg_empty([1], [1], buf, wait=True)
|
||||
print(buf.toCPU())
|
||||
|
||||
print(colored("creating CLBuffer", "green"))
|
||||
buf = CLBuffer.fromCPU(np.zeros(10, np.float32))
|
||||
|
@ -30,51 +31,71 @@ gen = []
|
|||
FLOPS = 0
|
||||
MAX_REG = 251
|
||||
for j in range(1):
|
||||
if WMMA:
|
||||
KY, KX = 4, 4
|
||||
for y in range(KY):
|
||||
for x in range(KX):
|
||||
c = (y*KX+x)*8
|
||||
a = (KY*KX*8) + y*8
|
||||
b = (KY*KX*8) + (KY*8) + x*8
|
||||
gen.append(f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]")
|
||||
FLOPS += 16*8*2
|
||||
else:
|
||||
for i in range(0, MAX_REG, 6):
|
||||
if DUAL_ALU:
|
||||
if F32:
|
||||
gen.append(f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||
FLOPS += 4
|
||||
else:
|
||||
gen.append(f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}")
|
||||
FLOPS += 8
|
||||
else:
|
||||
assert F32
|
||||
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
|
||||
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||
code = code.replace("// FLOPS", '\n'.join(gen))
|
||||
if WMMA:
|
||||
KY, KX = 4, 4
|
||||
for y in range(KY):
|
||||
for x in range(KX):
|
||||
c = (y * KX + x) * 8
|
||||
a = (KY * KX * 8) + y * 8
|
||||
b = (KY * KX * 8) + (KY * 8) + x * 8
|
||||
gen.append(
|
||||
f"v_wmma_f32_16x16x16_f16 v[{c}:{c+7}], v[{a}:{a+7}], v[{b}:{b+7}], v[{c}:{c+7}]"
|
||||
)
|
||||
FLOPS += 16 * 8 * 2
|
||||
else:
|
||||
for i in range(0, MAX_REG, 6):
|
||||
if DUAL_ALU:
|
||||
if F32:
|
||||
gen.append(
|
||||
f"v_dual_fmac_f32 v{i+0}, v{i+1}, v{i+2} :: v_dual_fmac_f32 v{i+3}, v{i+4}, v{i+5}"
|
||||
)
|
||||
FLOPS += 4
|
||||
else:
|
||||
gen.append(
|
||||
f"v_dual_dot2acc_f32_f16 v{i+0}, v{i+1}, v{i+2} :: v_dual_dot2acc_f32_f16 v{i+3}, v{i+4}, v{i+5}"
|
||||
)
|
||||
FLOPS += 8
|
||||
else:
|
||||
assert F32
|
||||
gen.append(f"v_fmac_f32 v{i+0}, v{i+1}, v{i+2}")
|
||||
gen.append(f"v_fmac_f32 v{i+3}, v{i+4}, v{i+5}")
|
||||
code = code.replace("// FLOPS", "\n".join(gen))
|
||||
print(code)
|
||||
|
||||
|
||||
# fix: COMGR failed to get code object ISA name. set triple to 'amdgcn-amd-amdhsa'
|
||||
|
||||
object = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
||||
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object))
|
||||
object = early_exec(
|
||||
(
|
||||
[
|
||||
ROCM_LLVM_PATH / "llvm-mc",
|
||||
"--arch=amdgcn",
|
||||
"--mcpu=gfx1100",
|
||||
"--triple=amdgcn-amd-amdhsa",
|
||||
"--filetype=obj",
|
||||
"-",
|
||||
],
|
||||
code.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
asm = early_exec(
|
||||
([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], object)
|
||||
)
|
||||
|
||||
with open("/tmp/cc2.o", "wb") as f:
|
||||
f.write(object)
|
||||
f.write(object)
|
||||
with open("/tmp/cc2.elf", "wb") as f:
|
||||
f.write(asm)
|
||||
f.write(asm)
|
||||
|
||||
print(colored("creating CLProgram", "green"))
|
||||
prg = CLProgram("code", asm)
|
||||
|
||||
print(colored("running program", "green"))
|
||||
G = 512
|
||||
FLOPS *= 100000*G*G # loop * global_size
|
||||
FLOPS *= 100000 * G * G # loop * global_size
|
||||
for i in range(3):
|
||||
tm = prg(buf, global_size=[G//256, G, 1], local_size=[256, 1, 1], wait=True)
|
||||
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
|
||||
tm = prg(buf, global_size=[G // 256, G, 1], local_size=[256, 1, 1], wait=True)
|
||||
print(f"ran in {tm*1e3:.2f} ms, {FLOPS/(tm*1e9):.2f} GFLOPS")
|
||||
|
||||
print(colored("transferring buffer", "green"))
|
||||
print(buf.toCPU())
|
||||
|
|
|
@ -2,41 +2,49 @@ import numpy as np
|
|||
from PIL import Image
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
cwd = Path.cwd()
|
||||
sys.path.append(cwd.as_posix())
|
||||
sys.path.append((cwd / 'test').as_posix())
|
||||
sys.path.append((cwd / "test").as_posix())
|
||||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
|
||||
def augment_img(X, rotate=10, px=3):
|
||||
Xaug = np.zeros_like(X)
|
||||
for i in trange(len(X)):
|
||||
im = Image.fromarray(X[i])
|
||||
im = im.rotate(np.random.randint(-rotate,rotate), resample=Image.BICUBIC)
|
||||
w, h = X.shape[1:]
|
||||
#upper left, lower left, lower right, upper right
|
||||
quad = np.random.randint(-px,px,size=(8)) + np.array([0,0,0,h,w,h,w,0])
|
||||
im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
|
||||
Xaug[i] = im
|
||||
return Xaug
|
||||
Xaug = np.zeros_like(X)
|
||||
for i in trange(len(X)):
|
||||
im = Image.fromarray(X[i])
|
||||
im = im.rotate(np.random.randint(-rotate, rotate), resample=Image.BICUBIC)
|
||||
w, h = X.shape[1:]
|
||||
# upper left, lower left, lower right, upper right
|
||||
quad = np.random.randint(-px, px, size=(8)) + np.array([0, 0, 0, h, w, h, w, 0])
|
||||
im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC)
|
||||
Xaug[i] = im
|
||||
return Xaug
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import matplotlib.pyplot as plt
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X = np.vstack([X_train[:1]]*10+[X_train[1:2]]*10)
|
||||
fig, a = plt.subplots(2,len(X))
|
||||
Xaug = augment_img(X)
|
||||
for i in range(len(X)):
|
||||
a[0][i].imshow(X[i], cmap='gray')
|
||||
a[1][i].imshow(Xaug[i],cmap='gray')
|
||||
a[0][i].axis('off')
|
||||
a[1][i].axis('off')
|
||||
plt.show()
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
#create some nice gifs for doc?!
|
||||
for i in range(10):
|
||||
im = Image.fromarray(X_train[7353+i])
|
||||
im_aug = [Image.fromarray(x) for x in augment_img(np.array([X_train[7353+i]]*100))]
|
||||
im.save(f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0)
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
|
||||
X = np.vstack([X_train[:1]] * 10 + [X_train[1:2]] * 10)
|
||||
fig, a = plt.subplots(2, len(X))
|
||||
Xaug = augment_img(X)
|
||||
for i in range(len(X)):
|
||||
a[0][i].imshow(X[i], cmap="gray")
|
||||
a[1][i].imshow(Xaug[i], cmap="gray")
|
||||
a[0][i].axis("off")
|
||||
a[1][i].axis("off")
|
||||
plt.show()
|
||||
|
||||
# create some nice gifs for doc?!
|
||||
for i in range(10):
|
||||
im = Image.fromarray(X_train[7353 + i])
|
||||
im_aug = [
|
||||
Image.fromarray(x) for x in augment_img(np.array([X_train[7353 + i]] * 100))
|
||||
]
|
||||
im.save(
|
||||
f"aug{i}.gif", save_all=True, append_images=im_aug, duration=100, loop=0
|
||||
)
|
||||
|
|
|
@ -37,4 +37,4 @@ lin.apply_opt(Opt(op=OptOps.PADTO, axis=1, amt=32))
|
|||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
|
||||
run_linearizer(lin)
|
||||
run_linearizer(lin)
|
||||
|
|
|
@ -3,41 +3,82 @@ import numpy as np
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, fetch
|
||||
|
||||
|
||||
def fetch_mnist(tensors=False):
|
||||
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
||||
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
|
||||
X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
|
||||
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
|
||||
else: return X_train, Y_train, X_test, Y_test
|
||||
parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
|
||||
BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/" # http://yann.lecun.com/exdb/mnist/ lacks https
|
||||
X_train = (
|
||||
parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:]
|
||||
.reshape((-1, 28 * 28))
|
||||
.astype(np.float32)
|
||||
)
|
||||
Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = (
|
||||
parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:]
|
||||
.reshape((-1, 28 * 28))
|
||||
.astype(np.float32)
|
||||
)
|
||||
Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
if tensors:
|
||||
return (
|
||||
Tensor(X_train).reshape(-1, 1, 28, 28),
|
||||
Tensor(Y_train),
|
||||
Tensor(X_test).reshape(-1, 1, 28, 28),
|
||||
Tensor(Y_test),
|
||||
)
|
||||
else:
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
||||
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
|
||||
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
|
||||
|
||||
|
||||
def fetch_cifar():
|
||||
X_train = Tensor.empty(50000, 3*32*32, device=f'disk:/tmp/cifar_train_x', dtype=dtypes.uint8)
|
||||
Y_train = Tensor.empty(50000, device=f'disk:/tmp/cifar_train_y', dtype=dtypes.int64)
|
||||
X_test = Tensor.empty(10000, 3*32*32, device=f'disk:/tmp/cifar_test_x', dtype=dtypes.uint8)
|
||||
Y_test = Tensor.empty(10000, device=f'disk:/tmp/cifar_test_y', dtype=dtypes.int64)
|
||||
X_train = Tensor.empty(
|
||||
50000, 3 * 32 * 32, device=f"disk:/tmp/cifar_train_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y_train = Tensor.empty(50000, device=f"disk:/tmp/cifar_train_y", dtype=dtypes.int64)
|
||||
X_test = Tensor.empty(
|
||||
10000, 3 * 32 * 32, device=f"disk:/tmp/cifar_test_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y_test = Tensor.empty(10000, device=f"disk:/tmp/cifar_test_y", dtype=dtypes.int64)
|
||||
|
||||
if not os.path.isfile("/tmp/cifar_extracted"):
|
||||
def _load_disk_tensor(X, Y, db_list):
|
||||
idx = 0
|
||||
for db in db_list:
|
||||
x, y = db[b'data'], np.array(db[b'labels'])
|
||||
assert x.shape[0] == y.shape[0]
|
||||
X[idx:idx+x.shape[0]].assign(x)
|
||||
Y[idx:idx+x.shape[0]].assign(y)
|
||||
idx += x.shape[0]
|
||||
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
||||
if not os.path.isfile("/tmp/cifar_extracted"):
|
||||
|
||||
print("downloading and extracting CIFAR...")
|
||||
fn = fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')
|
||||
tt = tarfile.open(fn, mode='r:gz')
|
||||
_load_disk_tensor(X_train, Y_train, [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)])
|
||||
_load_disk_tensor(X_test, Y_test, [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")])
|
||||
open("/tmp/cifar_extracted", "wb").close()
|
||||
def _load_disk_tensor(X, Y, db_list):
|
||||
idx = 0
|
||||
for db in db_list:
|
||||
x, y = db[b"data"], np.array(db[b"labels"])
|
||||
assert x.shape[0] == y.shape[0]
|
||||
X[idx : idx + x.shape[0]].assign(x)
|
||||
Y[idx : idx + x.shape[0]].assign(y)
|
||||
idx += x.shape[0]
|
||||
assert idx == X.shape[0] and X.shape[0] == Y.shape[0]
|
||||
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
print("downloading and extracting CIFAR...")
|
||||
fn = fetch("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")
|
||||
tt = tarfile.open(fn, mode="r:gz")
|
||||
_load_disk_tensor(
|
||||
X_train,
|
||||
Y_train,
|
||||
[
|
||||
pickle.load(
|
||||
tt.extractfile(f"cifar-10-batches-py/data_batch_{i}"),
|
||||
encoding="bytes",
|
||||
)
|
||||
for i in range(1, 6)
|
||||
],
|
||||
)
|
||||
_load_disk_tensor(
|
||||
X_test,
|
||||
Y_test,
|
||||
[
|
||||
pickle.load(
|
||||
tt.extractfile("cifar-10-batches-py/test_batch"), encoding="bytes"
|
||||
)
|
||||
],
|
||||
)
|
||||
open("/tmp/cifar_extracted", "wb").close()
|
||||
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
|
|
@ -8,192 +8,207 @@ from examples.mask_rcnn import Masker
|
|||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
iou = _mask.iou
|
||||
merge = _mask.merge
|
||||
iou = _mask.iou
|
||||
merge = _mask.merge
|
||||
frPyObjects = _mask.frPyObjects
|
||||
|
||||
BASEDIR = pathlib.Path(__file__).parent / "COCO"
|
||||
BASEDIR.mkdir(exist_ok=True)
|
||||
|
||||
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
|
||||
|
||||
def create_dict(key_row, val_row, rows):
|
||||
return {row[key_row]: row[val_row] for row in rows}
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'val2017').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/zips/val2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
if not pathlib.Path(BASEDIR / "val2017").is_dir():
|
||||
fn = fetch("http://images.cocodataset.org/zips/val2017.zip")
|
||||
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
|
||||
if not pathlib.Path(BASEDIR/'annotations').is_dir():
|
||||
fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip')
|
||||
with zipfile.ZipFile(fn, 'r') as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
if not pathlib.Path(BASEDIR / "annotations").is_dir():
|
||||
fn = fetch("http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
|
||||
with zipfile.ZipFile(fn, "r") as zip_ref:
|
||||
zip_ref.extractall(BASEDIR)
|
||||
fn.unlink()
|
||||
|
||||
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
|
||||
annotations_raw = json.loads(f.read())
|
||||
images = annotations_raw['images']
|
||||
categories = annotations_raw['categories']
|
||||
annotations = annotations_raw['annotations']
|
||||
file_name_to_id = create_dict('file_name', 'id', images)
|
||||
id_to_width = create_dict('id', 'width', images)
|
||||
id_to_height = create_dict('id', 'height', images)
|
||||
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
|
||||
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
|
||||
with open(BASEDIR / "annotations/instances_val2017.json", "r") as f:
|
||||
annotations_raw = json.loads(f.read())
|
||||
images = annotations_raw["images"]
|
||||
categories = annotations_raw["categories"]
|
||||
annotations = annotations_raw["annotations"]
|
||||
file_name_to_id = create_dict("file_name", "id", images)
|
||||
id_to_width = create_dict("id", "width", images)
|
||||
id_to_height = create_dict("id", "height", images)
|
||||
json_category_id_to_contiguous_id = {v["id"]: i + 1 for i, v in enumerate(categories)}
|
||||
contiguous_category_id_to_json_id = {
|
||||
v: k for k, v in json_category_id_to_contiguous_id.items()
|
||||
}
|
||||
|
||||
|
||||
def encode(bimask):
|
||||
if len(bimask.shape) == 3:
|
||||
return _mask.encode(bimask)
|
||||
elif len(bimask.shape) == 2:
|
||||
h, w = bimask.shape
|
||||
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
|
||||
if len(bimask.shape) == 3:
|
||||
return _mask.encode(bimask)
|
||||
elif len(bimask.shape) == 2:
|
||||
h, w = bimask.shape
|
||||
return _mask.encode(bimask.reshape((h, w, 1), order="F"))[0]
|
||||
|
||||
|
||||
def decode(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.decode(rleObjs)
|
||||
else:
|
||||
return _mask.decode([rleObjs])[:,:,0]
|
||||
if type(rleObjs) == list:
|
||||
return _mask.decode(rleObjs)
|
||||
else:
|
||||
return _mask.decode([rleObjs])[:, :, 0]
|
||||
|
||||
|
||||
def area(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.area(rleObjs)
|
||||
else:
|
||||
return _mask.area([rleObjs])[0]
|
||||
if type(rleObjs) == list:
|
||||
return _mask.area(rleObjs)
|
||||
else:
|
||||
return _mask.area([rleObjs])[0]
|
||||
|
||||
|
||||
def toBbox(rleObjs):
|
||||
if type(rleObjs) == list:
|
||||
return _mask.toBbox(rleObjs)
|
||||
else:
|
||||
return _mask.toBbox([rleObjs])[0]
|
||||
if type(rleObjs) == list:
|
||||
return _mask.toBbox(rleObjs)
|
||||
else:
|
||||
return _mask.toBbox([rleObjs])[0]
|
||||
|
||||
|
||||
def convert_prediction_to_coco_bbox(file_name, prediction):
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
prediction = prediction.convert("xywh")
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
prediction = prediction.convert("xywh")
|
||||
|
||||
boxes = prediction.bbox.numpy().tolist()
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
boxes = prediction.bbox.numpy().tolist()
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"bbox": box,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, box in enumerate(boxes)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"bbox": box,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, box in enumerate(boxes)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
masker = Masker(threshold=0.5, padding=1)
|
||||
|
||||
|
||||
def convert_prediction_to_coco_mask(file_name, prediction):
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
coco_results = []
|
||||
try:
|
||||
original_id = file_name_to_id[file_name]
|
||||
if len(prediction) == 0:
|
||||
return coco_results
|
||||
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
masks = prediction.get_field("mask")
|
||||
image_width = id_to_width[original_id]
|
||||
image_height = id_to_height[original_id]
|
||||
prediction = prediction.resize((image_width, image_height))
|
||||
masks = prediction.get_field("mask")
|
||||
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
scores = prediction.get_field("scores").numpy().tolist()
|
||||
labels = prediction.get_field("labels").numpy().tolist()
|
||||
|
||||
masks = masker([masks], [prediction])[0].numpy()
|
||||
masks = masker([masks], [prediction])[0].numpy()
|
||||
|
||||
rles = [
|
||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
|
||||
for mask in masks
|
||||
]
|
||||
for rle in rles:
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
rles = [
|
||||
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] for mask in masks
|
||||
]
|
||||
for rle in rles:
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"segmentation": rle,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, rle in enumerate(rles)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
|
||||
|
||||
coco_results.extend(
|
||||
[
|
||||
{
|
||||
"image_id": original_id,
|
||||
"category_id": mapped_labels[k],
|
||||
"segmentation": rle,
|
||||
"score": scores[k],
|
||||
}
|
||||
for k, rle in enumerate(rles)
|
||||
]
|
||||
)
|
||||
except Exception as e:
|
||||
print(file_name, e)
|
||||
return coco_results
|
||||
|
||||
|
||||
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
|
||||
path = pathlib.Path(json_result_file)
|
||||
if rm and path.exists(): path.unlink()
|
||||
with open(path, "a") as f:
|
||||
for s in coco_results:
|
||||
f.write(json.dumps(s))
|
||||
f.write('\n')
|
||||
path = pathlib.Path(json_result_file)
|
||||
if rm and path.exists():
|
||||
path.unlink()
|
||||
with open(path, "a") as f:
|
||||
for s in coco_results:
|
||||
f.write(json.dumps(s))
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def remove_dup(l):
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in l if not (x in seen or seen_add(x))]
|
||||
seen = set()
|
||||
seen_add = seen.add
|
||||
return [x for x in l if not (x in seen or seen_add(x))]
|
||||
|
||||
|
||||
class NpEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return super(NpEncoder, self).default(obj)
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
if isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
if isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
return super(NpEncoder, self).default(obj)
|
||||
|
||||
|
||||
def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
|
||||
coco_results = []
|
||||
with open(json_result_file, "r") as f:
|
||||
for line in f:
|
||||
coco_results.append(json.loads(line))
|
||||
coco_results = []
|
||||
with open(json_result_file, "r") as f:
|
||||
for line in f:
|
||||
coco_results.append(json.loads(line))
|
||||
|
||||
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
|
||||
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
||||
unique_list = [json.loads(s) for s in set_of_json]
|
||||
coco_gt = COCO(str(BASEDIR / "annotations/instances_val2017.json"))
|
||||
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
|
||||
unique_list = [json.loads(s) for s in set_of_json]
|
||||
|
||||
with open(f'{json_result_file}.flattend', "w") as f:
|
||||
json.dump(unique_list, f)
|
||||
with open(f"{json_result_file}.flattend", "w") as f:
|
||||
json.dump(unique_list, f)
|
||||
|
||||
coco_dt = coco_gt.loadRes(str(f"{json_result_file}.flattend"))
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
return coco_eval
|
||||
|
||||
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
|
||||
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
|
||||
coco_eval.evaluate()
|
||||
coco_eval.accumulate()
|
||||
coco_eval.summarize()
|
||||
return coco_eval
|
||||
|
||||
def iterate(files, bs=1):
|
||||
batch = []
|
||||
for file in files:
|
||||
batch.append(file)
|
||||
if len(batch) >= bs: yield batch; batch = []
|
||||
if len(batch) > 0: yield batch; batch = []
|
||||
batch = []
|
||||
for file in files:
|
||||
batch.append(file)
|
||||
if len(batch) >= bs:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
batch = []
|
||||
|
|
|
@ -7,47 +7,56 @@ import functools, pathlib
|
|||
|
||||
BASEDIR = pathlib.Path(__file__).parent / "imagenet"
|
||||
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
||||
cir = {v[0]: int(k) for k,v in ci.items()}
|
||||
cir = {v[0]: int(k) for k, v in ci.items()}
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_train_files():
|
||||
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
|
||||
return [(BASEDIR / "train" / x) for x in train_files]
|
||||
train_files = open(BASEDIR / "train_files").read().strip().split("\n")
|
||||
return [(BASEDIR / "train" / x) for x in train_files]
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_val_files():
|
||||
val_files = glob.glob(str(BASEDIR / "val/*/*"))
|
||||
return val_files
|
||||
val_files = glob.glob(str(BASEDIR / "val/*/*"))
|
||||
return val_files
|
||||
|
||||
#rrc = transforms.RandomResizedCrop(224)
|
||||
|
||||
# rrc = transforms.RandomResizedCrop(224)
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
def image_load(fn):
|
||||
img = Image.open(fn).convert('RGB')
|
||||
img = F.resize(img, 256, Image.BILINEAR)
|
||||
img = F.center_crop(img, 224)
|
||||
ret = np.array(img)
|
||||
return ret
|
||||
img = Image.open(fn).convert("RGB")
|
||||
img = F.resize(img, 256, Image.BILINEAR)
|
||||
img = F.center_crop(img, 224)
|
||||
ret = np.array(img)
|
||||
return ret
|
||||
|
||||
|
||||
def iterate(bs=32, val=True, shuffle=True):
|
||||
files = get_val_files() if val else get_train_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle: random.shuffle(order)
|
||||
from multiprocessing import Pool
|
||||
p = Pool(16)
|
||||
for i in range(0, len(files), bs):
|
||||
X = p.map(image_load, [files[i] for i in order[i:i+bs]])
|
||||
Y = [cir[files[i].split("/")[-2]] for i in order[i:i+bs]]
|
||||
yield (np.array(X), np.array(Y))
|
||||
files = get_val_files() if val else get_train_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle:
|
||||
random.shuffle(order)
|
||||
from multiprocessing import Pool
|
||||
|
||||
p = Pool(16)
|
||||
for i in range(0, len(files), bs):
|
||||
X = p.map(image_load, [files[i] for i in order[i : i + bs]])
|
||||
Y = [cir[files[i].split("/")[-2]] for i in order[i : i + bs]]
|
||||
yield (np.array(X), np.array(Y))
|
||||
|
||||
|
||||
def fetch_batch(bs, val=False):
|
||||
files = get_val_files() if val else get_train_files()
|
||||
samp = np.random.randint(0, len(files), size=(bs))
|
||||
files = [files[i] for i in samp]
|
||||
X = [image_load(x) for x in files]
|
||||
Y = [cir[x.split("/")[0]] for x in files]
|
||||
return np.array(X), np.array(Y)
|
||||
files = get_val_files() if val else get_train_files()
|
||||
samp = np.random.randint(0, len(files), size=(bs))
|
||||
files = [files[i] for i in samp]
|
||||
X = [image_load(x) for x in files]
|
||||
Y = [cir[x.split("/")[0]] for x in files]
|
||||
return np.array(X), np.array(Y)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
X,Y = fetch_batch(64)
|
||||
print(X.shape, Y)
|
||||
|
||||
X, Y = fetch_batch(64)
|
||||
print(X.shape, Y)
|
||||
|
|
|
@ -4,48 +4,92 @@ from pathlib import Path
|
|||
from tqdm import tqdm
|
||||
import tarfile, os
|
||||
|
||||
|
||||
def imagenet_extract(file, path, small=False):
|
||||
with tarfile.open(name=file) as tar:
|
||||
if small: # Show progressbar only for big files
|
||||
for member in tar.getmembers(): tar.extract(path=path, member=member)
|
||||
else:
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
|
||||
tar.close()
|
||||
with tarfile.open(name=file) as tar:
|
||||
if small: # Show progressbar only for big files
|
||||
for member in tar.getmembers():
|
||||
tar.extract(path=path, member=member)
|
||||
else:
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())):
|
||||
tar.extract(path=path, member=member)
|
||||
tar.close()
|
||||
|
||||
|
||||
def imagenet_prepare_val():
|
||||
# Read in the labels file
|
||||
with open(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt", 'r') as f:
|
||||
labels = f.read().splitlines()
|
||||
f.close()
|
||||
# Get a list of images
|
||||
images = os.listdir(Path(__file__).parent / "imagenet" / "val")
|
||||
images.sort()
|
||||
# Create folders and move files into those
|
||||
for co,dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
|
||||
os.replace(Path(__file__).parent / "imagenet" / "val" / images[co], Path(__file__).parent / "imagenet" / "val" / dir / images[co])
|
||||
os.remove(Path(__file__).parent / "imagenet" / "imagenet_2012_validation_synset_labels.txt")
|
||||
# Read in the labels file
|
||||
with open(
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt",
|
||||
"r",
|
||||
) as f:
|
||||
labels = f.read().splitlines()
|
||||
f.close()
|
||||
# Get a list of images
|
||||
images = os.listdir(Path(__file__).parent / "imagenet" / "val")
|
||||
images.sort()
|
||||
# Create folders and move files into those
|
||||
for co, dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val" / dir, exist_ok=True)
|
||||
os.replace(
|
||||
Path(__file__).parent / "imagenet" / "val" / images[co],
|
||||
Path(__file__).parent / "imagenet" / "val" / dir / images[co],
|
||||
)
|
||||
os.remove(
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt"
|
||||
)
|
||||
|
||||
|
||||
def imagenet_prepare_train():
|
||||
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
|
||||
for co,tarf in enumerate(images):
|
||||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
|
||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True)
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "train" / tarf, Path(__file__).parent/ "imagenet" / "train" / images[co], small=True)
|
||||
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
|
||||
images = os.listdir(Path(__file__).parent / "imagenet" / "train")
|
||||
for co, tarf in enumerate(images):
|
||||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||
if Path(Path(__file__).parent / "imagenet" / "train" / images[co]).is_file():
|
||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||
os.makedirs(
|
||||
Path(__file__).parent / "imagenet" / "train" / images[co], exist_ok=True
|
||||
)
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "train" / tarf,
|
||||
Path(__file__).parent / "imagenet" / "train" / images[co],
|
||||
small=True,
|
||||
)
|
||||
os.remove(Path(__file__).parent / "imagenet" / "train" / tarf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
|
||||
fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json")
|
||||
fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt")
|
||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val")
|
||||
imagenet_prepare_val()
|
||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||
fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB!
|
||||
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train")
|
||||
imagenet_prepare_train()
|
||||
os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True)
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json",
|
||||
Path(__file__).parent / "imagenet" / "imagenet_class_index.json",
|
||||
)
|
||||
fetch(
|
||||
"https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt",
|
||||
Path(__file__).parent
|
||||
/ "imagenet"
|
||||
/ "imagenet_2012_validation_synset_labels.txt",
|
||||
)
|
||||
fetch(
|
||||
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar",
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||
) # 7GB
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar",
|
||||
Path(__file__).parent / "imagenet" / "val",
|
||||
)
|
||||
imagenet_prepare_val()
|
||||
if os.getenv("IMGNET_TRAIN", None) is not None:
|
||||
fetch(
|
||||
"https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar",
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||
) # 138GB!
|
||||
imagenet_extract(
|
||||
Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar",
|
||||
Path(__file__).parent / "imagenet" / "train",
|
||||
)
|
||||
imagenet_prepare_train()
|
||||
|
|
|
@ -23,109 +23,199 @@ mv kits extra/datasets
|
|||
```
|
||||
"""
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_val_files():
|
||||
data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
|
||||
return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
|
||||
data = fetch(
|
||||
"https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt"
|
||||
).read_text()
|
||||
return sorted(
|
||||
[x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")]
|
||||
)
|
||||
|
||||
|
||||
def load_pair(file_path):
|
||||
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
|
||||
image_spacings = image.header["pixdim"][1:4].tolist()
|
||||
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
|
||||
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
|
||||
return image, label, image_spacings
|
||||
image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(
|
||||
file_path / "segmentation.nii.gz"
|
||||
)
|
||||
image_spacings = image.header["pixdim"][1:4].tolist()
|
||||
image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(
|
||||
np.uint8
|
||||
)
|
||||
image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
|
||||
return image, label, image_spacings
|
||||
|
||||
|
||||
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
|
||||
if image_spacings != target_spacing:
|
||||
spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
|
||||
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
|
||||
image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
|
||||
label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
|
||||
image = np.squeeze(image.numpy(), axis=0)
|
||||
label = np.squeeze(label.numpy(), axis=0)
|
||||
return image, label
|
||||
if image_spacings != target_spacing:
|
||||
spc_arr, targ_arr, shp_arr = (
|
||||
np.array(image_spacings),
|
||||
np.array(target_spacing),
|
||||
np.array(image.shape[1:]),
|
||||
)
|
||||
new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
|
||||
image = F.interpolate(
|
||||
torch.from_numpy(np.expand_dims(image, axis=0)),
|
||||
size=new_shape,
|
||||
mode="trilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
label = F.interpolate(
|
||||
torch.from_numpy(np.expand_dims(label, axis=0)),
|
||||
size=new_shape,
|
||||
mode="nearest",
|
||||
)
|
||||
image = np.squeeze(image.numpy(), axis=0)
|
||||
label = np.squeeze(label.numpy(), axis=0)
|
||||
return image, label
|
||||
|
||||
|
||||
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
|
||||
image = np.clip(image, min_clip, max_clip)
|
||||
image = (image - mean) / std
|
||||
return image
|
||||
image = np.clip(image, min_clip, max_clip)
|
||||
image = (image - mean) / std
|
||||
return image
|
||||
|
||||
|
||||
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
|
||||
current_shape = image.shape[1:]
|
||||
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
|
||||
paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
|
||||
image = np.pad(image, paddings, mode="edge")
|
||||
label = np.pad(label, paddings, mode="edge")
|
||||
return image, label
|
||||
current_shape = image.shape[1:]
|
||||
bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
|
||||
paddings = [(0, 0)] + [
|
||||
(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)
|
||||
]
|
||||
image = np.pad(image, paddings, mode="edge")
|
||||
label = np.pad(label, paddings, mode="edge")
|
||||
return image, label
|
||||
|
||||
|
||||
def preprocess(file_path):
|
||||
image, label, image_spacings = load_pair(file_path)
|
||||
image, label = resample3d(image, label, image_spacings)
|
||||
image = normal_intensity(image.copy())
|
||||
image, label = pad_to_min_shape(image, label)
|
||||
return image, label
|
||||
image, label, image_spacings = load_pair(file_path)
|
||||
image, label = resample3d(image, label, image_spacings)
|
||||
image = normal_intensity(image.copy())
|
||||
image, label = pad_to_min_shape(image, label)
|
||||
return image, label
|
||||
|
||||
|
||||
def iterate(val=True, shuffle=False):
|
||||
if not val: raise NotImplementedError
|
||||
files = get_val_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle: random.shuffle(order)
|
||||
for file in files:
|
||||
X, Y = preprocess(file)
|
||||
X = np.expand_dims(X, axis=0)
|
||||
yield (X, Y)
|
||||
if not val:
|
||||
raise NotImplementedError
|
||||
files = get_val_files()
|
||||
order = list(range(0, len(files)))
|
||||
if shuffle:
|
||||
random.shuffle(order)
|
||||
for file in files:
|
||||
X, Y = preprocess(file)
|
||||
X = np.expand_dims(X, axis=0)
|
||||
yield (X, Y)
|
||||
|
||||
|
||||
def gaussian_kernel(n, std):
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
|
||||
gaussian_3d = gaussian_3d.reshape(n, n, n)
|
||||
gaussian_3d = np.cbrt(gaussian_3d)
|
||||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
gaussian_1d = signal.gaussian(n, std)
|
||||
gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
|
||||
gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
|
||||
gaussian_3d = gaussian_3d.reshape(n, n, n)
|
||||
gaussian_3d = np.cbrt(gaussian_3d)
|
||||
gaussian_3d /= gaussian_3d.max()
|
||||
return gaussian_3d
|
||||
|
||||
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
|
||||
bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
|
||||
bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
|
||||
paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
|
||||
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
|
||||
|
||||
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
|
||||
from tinygrad.jit import TinyJit
|
||||
mdl_run = TinyJit(lambda x: model(x).realize())
|
||||
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
|
||||
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
|
||||
bounds = [image_shape[i] % strides[i] for i in range(dim)]
|
||||
bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
|
||||
inputs = inputs[
|
||||
...,
|
||||
bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
|
||||
bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
|
||||
bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
|
||||
]
|
||||
labels = labels[
|
||||
...,
|
||||
bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
|
||||
bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
|
||||
bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
|
||||
]
|
||||
inputs, paddings = pad_input(inputs, roi_shape, strides)
|
||||
padded_shape = inputs.shape[2:]
|
||||
size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
|
||||
result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
|
||||
norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
|
||||
norm_patch = np.expand_dims(norm_patch, axis=0)
|
||||
for i in range(0, strides[0] * size[0], strides[0]):
|
||||
for j in range(0, strides[1] * size[1], strides[1]):
|
||||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i,j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
|
||||
result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
|
||||
norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
|
||||
result /= norm_map
|
||||
result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
|
||||
return result, labels
|
||||
def pad_input(
|
||||
volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3
|
||||
):
|
||||
bounds = [
|
||||
(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)
|
||||
]
|
||||
bounds = [
|
||||
bounds[i]
|
||||
if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i]
|
||||
else bounds[i] + strides[i]
|
||||
for i in range(dim)
|
||||
]
|
||||
paddings = [
|
||||
bounds[2] // 2,
|
||||
bounds[2] - bounds[2] // 2,
|
||||
bounds[1] // 2,
|
||||
bounds[1] - bounds[1] // 2,
|
||||
bounds[0] // 2,
|
||||
bounds[0] - bounds[0] // 2,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
]
|
||||
return (
|
||||
F.pad(
|
||||
torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val
|
||||
).numpy(),
|
||||
paddings,
|
||||
)
|
||||
|
||||
|
||||
def sliding_window_inference(
|
||||
model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5
|
||||
):
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
mdl_run = TinyJit(lambda x: model(x).realize())
|
||||
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
|
||||
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
|
||||
bounds = [image_shape[i] % strides[i] for i in range(dim)]
|
||||
bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
|
||||
inputs = inputs[
|
||||
...,
|
||||
bounds[0] // 2 : image_shape[0] - (bounds[0] - bounds[0] // 2),
|
||||
bounds[1] // 2 : image_shape[1] - (bounds[1] - bounds[1] // 2),
|
||||
bounds[2] // 2 : image_shape[2] - (bounds[2] - bounds[2] // 2),
|
||||
]
|
||||
labels = labels[
|
||||
...,
|
||||
bounds[0] // 2 : image_shape[0] - (bounds[0] - bounds[0] // 2),
|
||||
bounds[1] // 2 : image_shape[1] - (bounds[1] - bounds[1] // 2),
|
||||
bounds[2] // 2 : image_shape[2] - (bounds[2] - bounds[2] // 2),
|
||||
]
|
||||
inputs, paddings = pad_input(inputs, roi_shape, strides)
|
||||
padded_shape = inputs.shape[2:]
|
||||
size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
|
||||
result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
|
||||
norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
|
||||
norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
|
||||
norm_patch = np.expand_dims(norm_patch, axis=0)
|
||||
for i in range(0, strides[0] * size[0], strides[0]):
|
||||
for j in range(0, strides[1] * size[1], strides[1]):
|
||||
for k in range(0, strides[2] * size[2], strides[2]):
|
||||
out = mdl_run(
|
||||
Tensor(
|
||||
inputs[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
]
|
||||
)
|
||||
).numpy()
|
||||
result[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
] += (
|
||||
out * norm_patch
|
||||
)
|
||||
norm_map[
|
||||
...,
|
||||
i : roi_shape[0] + i,
|
||||
j : roi_shape[1] + j,
|
||||
k : roi_shape[2] + k,
|
||||
] += norm_patch
|
||||
result /= norm_map
|
||||
result = result[
|
||||
...,
|
||||
paddings[4] : image_shape[0] + paddings[4],
|
||||
paddings[2] : image_shape[1] + paddings[2],
|
||||
paddings[0] : image_shape[2] + paddings[0],
|
||||
]
|
||||
return result, labels
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for X, Y in iterate():
|
||||
print(X.shape, Y.shape)
|
||||
for X, Y in iterate():
|
||||
print(X.shape, Y.shape)
|
||||
|
|
|
@ -17,66 +17,88 @@ Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recog
|
|||
"""
|
||||
BASEDIR = pathlib.Path(__file__).parent / "librispeech"
|
||||
with open(BASEDIR / "dev-clean-wav.json") as f:
|
||||
ci = json.load(f)
|
||||
ci = json.load(f)
|
||||
|
||||
FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
|
||||
FILTER_BANK = np.expand_dims(
|
||||
librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0
|
||||
)
|
||||
WINDOW = librosa.filters.get_window("hann", 320)
|
||||
|
||||
|
||||
def feature_extract(x, x_lens):
|
||||
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
|
||||
x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)
|
||||
|
||||
# pre-emphasis
|
||||
x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1)
|
||||
# pre-emphasis
|
||||
x = np.concatenate(
|
||||
(np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1
|
||||
)
|
||||
|
||||
# stft
|
||||
x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")
|
||||
x = np.stack((x.real, x.imag), axis=-1)
|
||||
# stft
|
||||
x = librosa.stft(
|
||||
x,
|
||||
n_fft=512,
|
||||
window=WINDOW,
|
||||
hop_length=160,
|
||||
win_length=320,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
)
|
||||
x = np.stack((x.real, x.imag), axis=-1)
|
||||
|
||||
# power spectrum
|
||||
x = (x**2).sum(-1)
|
||||
# power spectrum
|
||||
x = (x**2).sum(-1)
|
||||
|
||||
# mel filter bank
|
||||
x = np.matmul(FILTER_BANK, x)
|
||||
# mel filter bank
|
||||
x = np.matmul(FILTER_BANK, x)
|
||||
|
||||
# log
|
||||
x = np.log(x + 1e-20)
|
||||
# log
|
||||
x = np.log(x + 1e-20)
|
||||
|
||||
# feature splice
|
||||
seq = [x]
|
||||
for i in range(1, 3):
|
||||
tmp = np.zeros_like(x)
|
||||
tmp[:, :, :-i] = x[:, :, i:]
|
||||
seq.append(tmp)
|
||||
features = np.concatenate(seq, axis=1)[:, :, ::3]
|
||||
# feature splice
|
||||
seq = [x]
|
||||
for i in range(1, 3):
|
||||
tmp = np.zeros_like(x)
|
||||
tmp[:, :, :-i] = x[:, :, i:]
|
||||
seq.append(tmp)
|
||||
features = np.concatenate(seq, axis=1)[:, :, ::3]
|
||||
|
||||
# normalize
|
||||
features_mean = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
|
||||
features_std = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
|
||||
for i in range(features.shape[0]):
|
||||
features_mean[i, :] = features[i, :, :x_lens[i]].mean(axis=1)
|
||||
features_std[i, :] = features[i, :, :x_lens[i]].std(axis=1, ddof=1)
|
||||
features_std += 1e-5
|
||||
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2)
|
||||
# normalize
|
||||
features_mean = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
|
||||
features_std = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
|
||||
for i in range(features.shape[0]):
|
||||
features_mean[i, :] = features[i, :, : x_lens[i]].mean(axis=1)
|
||||
features_std[i, :] = features[i, :, : x_lens[i]].std(axis=1, ddof=1)
|
||||
features_std += 1e-5
|
||||
features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(
|
||||
features_std, 2
|
||||
)
|
||||
|
||||
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
|
||||
|
||||
return features.transpose(2, 0, 1), x_lens.astype(np.float32)
|
||||
|
||||
def load_wav(file):
|
||||
sample = soundfile.read(file)[0].astype(np.float32)
|
||||
return sample, sample.shape[0]
|
||||
sample = soundfile.read(file)[0].astype(np.float32)
|
||||
return sample, sample.shape[0]
|
||||
|
||||
|
||||
def iterate(bs=1, start=0):
|
||||
print(f"there are {len(ci)} samples in the dataset")
|
||||
for i in range(start, len(ci), bs):
|
||||
samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]])
|
||||
samples = list(samples)
|
||||
# pad to same length
|
||||
max_len = max(sample_lens)
|
||||
for j in range(len(samples)):
|
||||
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
|
||||
samples, sample_lens = np.array(samples), np.array(sample_lens)
|
||||
print(f"there are {len(ci)} samples in the dataset")
|
||||
for i in range(start, len(ci), bs):
|
||||
samples, sample_lens = zip(
|
||||
*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]]
|
||||
)
|
||||
samples = list(samples)
|
||||
# pad to same length
|
||||
max_len = max(sample_lens)
|
||||
for j in range(len(samples)):
|
||||
samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
|
||||
samples, sample_lens = np.array(samples), np.array(sample_lens)
|
||||
|
||||
yield feature_extract(samples, sample_lens), np.array(
|
||||
[v["transcript"] for v in ci[i : i + bs]]
|
||||
)
|
||||
|
||||
yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]])
|
||||
|
||||
if __name__ == "__main__":
|
||||
X, Y = next(iterate())
|
||||
print(X[0].shape, Y.shape)
|
||||
X, Y = next(iterate())
|
||||
print(X[0].shape, Y.shape)
|
||||
|
|
|
@ -12,153 +12,467 @@ import concurrent.futures
|
|||
|
||||
BASEDIR = pathlib.Path(__file__).parent / "open-images-v6-mlperf"
|
||||
BUCKET_NAME = "open-images-dataset"
|
||||
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
||||
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
|
||||
'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
|
||||
'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
|
||||
'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
|
||||
'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
|
||||
'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
|
||||
'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
|
||||
'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
|
||||
'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
|
||||
'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
|
||||
'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
|
||||
'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
|
||||
'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
|
||||
'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
|
||||
'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
|
||||
'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
|
||||
'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
|
||||
'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
|
||||
'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
|
||||
'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
|
||||
'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
|
||||
'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
|
||||
'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
|
||||
'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
|
||||
'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
|
||||
'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
|
||||
'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
|
||||
'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
|
||||
'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
|
||||
'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
|
||||
'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
|
||||
'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
|
||||
'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
|
||||
'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
|
||||
'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
|
||||
'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
|
||||
'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
|
||||
'Zebra', 'Zucchini',
|
||||
BBOX_ANNOTATIONS_URL = (
|
||||
"https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||
)
|
||||
MAP_CLASSES_URL = (
|
||||
"https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
||||
)
|
||||
MLPERF_CLASSES = [
|
||||
"Airplane",
|
||||
"Antelope",
|
||||
"Apple",
|
||||
"Backpack",
|
||||
"Balloon",
|
||||
"Banana",
|
||||
"Barrel",
|
||||
"Baseball bat",
|
||||
"Baseball glove",
|
||||
"Bee",
|
||||
"Beer",
|
||||
"Bench",
|
||||
"Bicycle",
|
||||
"Bicycle helmet",
|
||||
"Bicycle wheel",
|
||||
"Billboard",
|
||||
"Book",
|
||||
"Bookcase",
|
||||
"Boot",
|
||||
"Bottle",
|
||||
"Bowl",
|
||||
"Bowling equipment",
|
||||
"Box",
|
||||
"Boy",
|
||||
"Brassiere",
|
||||
"Bread",
|
||||
"Broccoli",
|
||||
"Bronze sculpture",
|
||||
"Bull",
|
||||
"Bus",
|
||||
"Bust",
|
||||
"Butterfly",
|
||||
"Cabinetry",
|
||||
"Cake",
|
||||
"Camel",
|
||||
"Camera",
|
||||
"Candle",
|
||||
"Candy",
|
||||
"Cannon",
|
||||
"Canoe",
|
||||
"Carrot",
|
||||
"Cart",
|
||||
"Castle",
|
||||
"Cat",
|
||||
"Cattle",
|
||||
"Cello",
|
||||
"Chair",
|
||||
"Cheese",
|
||||
"Chest of drawers",
|
||||
"Chicken",
|
||||
"Christmas tree",
|
||||
"Coat",
|
||||
"Cocktail",
|
||||
"Coffee",
|
||||
"Coffee cup",
|
||||
"Coffee table",
|
||||
"Coin",
|
||||
"Common sunflower",
|
||||
"Computer keyboard",
|
||||
"Computer monitor",
|
||||
"Convenience store",
|
||||
"Cookie",
|
||||
"Countertop",
|
||||
"Cowboy hat",
|
||||
"Crab",
|
||||
"Crocodile",
|
||||
"Cucumber",
|
||||
"Cupboard",
|
||||
"Curtain",
|
||||
"Deer",
|
||||
"Desk",
|
||||
"Dinosaur",
|
||||
"Dog",
|
||||
"Doll",
|
||||
"Dolphin",
|
||||
"Door",
|
||||
"Dragonfly",
|
||||
"Drawer",
|
||||
"Dress",
|
||||
"Drum",
|
||||
"Duck",
|
||||
"Eagle",
|
||||
"Earrings",
|
||||
"Egg (Food)",
|
||||
"Elephant",
|
||||
"Falcon",
|
||||
"Fedora",
|
||||
"Flag",
|
||||
"Flowerpot",
|
||||
"Football",
|
||||
"Football helmet",
|
||||
"Fork",
|
||||
"Fountain",
|
||||
"French fries",
|
||||
"French horn",
|
||||
"Frog",
|
||||
"Giraffe",
|
||||
"Girl",
|
||||
"Glasses",
|
||||
"Goat",
|
||||
"Goggles",
|
||||
"Goldfish",
|
||||
"Gondola",
|
||||
"Goose",
|
||||
"Grape",
|
||||
"Grapefruit",
|
||||
"Guitar",
|
||||
"Hamburger",
|
||||
"Handbag",
|
||||
"Harbor seal",
|
||||
"Headphones",
|
||||
"Helicopter",
|
||||
"High heels",
|
||||
"Hiking equipment",
|
||||
"Horse",
|
||||
"House",
|
||||
"Houseplant",
|
||||
"Human arm",
|
||||
"Human beard",
|
||||
"Human body",
|
||||
"Human ear",
|
||||
"Human eye",
|
||||
"Human face",
|
||||
"Human foot",
|
||||
"Human hair",
|
||||
"Human hand",
|
||||
"Human head",
|
||||
"Human leg",
|
||||
"Human mouth",
|
||||
"Human nose",
|
||||
"Ice cream",
|
||||
"Jacket",
|
||||
"Jeans",
|
||||
"Jellyfish",
|
||||
"Juice",
|
||||
"Kitchen & dining room table",
|
||||
"Kite",
|
||||
"Lamp",
|
||||
"Lantern",
|
||||
"Laptop",
|
||||
"Lavender (Plant)",
|
||||
"Lemon",
|
||||
"Light bulb",
|
||||
"Lighthouse",
|
||||
"Lily",
|
||||
"Lion",
|
||||
"Lipstick",
|
||||
"Lizard",
|
||||
"Man",
|
||||
"Maple",
|
||||
"Microphone",
|
||||
"Mirror",
|
||||
"Mixing bowl",
|
||||
"Mobile phone",
|
||||
"Monkey",
|
||||
"Motorcycle",
|
||||
"Muffin",
|
||||
"Mug",
|
||||
"Mule",
|
||||
"Mushroom",
|
||||
"Musical keyboard",
|
||||
"Necklace",
|
||||
"Nightstand",
|
||||
"Office building",
|
||||
"Orange",
|
||||
"Owl",
|
||||
"Oyster",
|
||||
"Paddle",
|
||||
"Palm tree",
|
||||
"Parachute",
|
||||
"Parrot",
|
||||
"Pen",
|
||||
"Penguin",
|
||||
"Personal flotation device",
|
||||
"Piano",
|
||||
"Picture frame",
|
||||
"Pig",
|
||||
"Pillow",
|
||||
"Pizza",
|
||||
"Plate",
|
||||
"Platter",
|
||||
"Porch",
|
||||
"Poster",
|
||||
"Pumpkin",
|
||||
"Rabbit",
|
||||
"Rifle",
|
||||
"Roller skates",
|
||||
"Rose",
|
||||
"Salad",
|
||||
"Sandal",
|
||||
"Saucer",
|
||||
"Saxophone",
|
||||
"Scarf",
|
||||
"Sea lion",
|
||||
"Sea turtle",
|
||||
"Sheep",
|
||||
"Shelf",
|
||||
"Shirt",
|
||||
"Shorts",
|
||||
"Shrimp",
|
||||
"Sink",
|
||||
"Skateboard",
|
||||
"Ski",
|
||||
"Skull",
|
||||
"Skyscraper",
|
||||
"Snake",
|
||||
"Sock",
|
||||
"Sofa bed",
|
||||
"Sparrow",
|
||||
"Spider",
|
||||
"Spoon",
|
||||
"Sports uniform",
|
||||
"Squirrel",
|
||||
"Stairs",
|
||||
"Stool",
|
||||
"Strawberry",
|
||||
"Street light",
|
||||
"Studio couch",
|
||||
"Suit",
|
||||
"Sun hat",
|
||||
"Sunglasses",
|
||||
"Surfboard",
|
||||
"Sushi",
|
||||
"Swan",
|
||||
"Swimming pool",
|
||||
"Swimwear",
|
||||
"Tank",
|
||||
"Tap",
|
||||
"Taxi",
|
||||
"Tea",
|
||||
"Teddy bear",
|
||||
"Television",
|
||||
"Tent",
|
||||
"Tie",
|
||||
"Tiger",
|
||||
"Tin can",
|
||||
"Tire",
|
||||
"Toilet",
|
||||
"Tomato",
|
||||
"Tortoise",
|
||||
"Tower",
|
||||
"Traffic light",
|
||||
"Train",
|
||||
"Tripod",
|
||||
"Truck",
|
||||
"Trumpet",
|
||||
"Umbrella",
|
||||
"Van",
|
||||
"Vase",
|
||||
"Vehicle registration plate",
|
||||
"Violin",
|
||||
"Wall clock",
|
||||
"Waste container",
|
||||
"Watch",
|
||||
"Whale",
|
||||
"Wheel",
|
||||
"Wheelchair",
|
||||
"Whiteboard",
|
||||
"Window",
|
||||
"Wine",
|
||||
"Wine glass",
|
||||
"Woman",
|
||||
"Zebra",
|
||||
"Zucchini",
|
||||
]
|
||||
|
||||
|
||||
def openimages():
|
||||
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
|
||||
if not ann_file.is_file():
|
||||
fetch_openimages(ann_file)
|
||||
return ann_file
|
||||
ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
|
||||
if not ann_file.is_file():
|
||||
fetch_openimages(ann_file)
|
||||
return ann_file
|
||||
|
||||
|
||||
# this slows down the conversion a lot!
|
||||
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
|
||||
def extract_dims(path): return Image.open(path).size[::-1]
|
||||
def extract_dims(path):
|
||||
return Image.open(path).size[::-1]
|
||||
|
||||
def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
|
||||
categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
|
||||
class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
|
||||
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
|
||||
annotations = annotations.merge(class_map, on="LabelName", how="inner")
|
||||
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
|
||||
annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")
|
||||
|
||||
# Images
|
||||
imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
|
||||
for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows())
|
||||
]
|
||||
def export_to_coco(
|
||||
class_map,
|
||||
annotations,
|
||||
image_list,
|
||||
dataset_path,
|
||||
output_path,
|
||||
classes=MLPERF_CLASSES,
|
||||
):
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
|
||||
categories_map = pd.DataFrame(
|
||||
[(i, c) for i, c in enumerate(classes)],
|
||||
columns=["category_id", "category_name"],
|
||||
)
|
||||
class_map = class_map.merge(
|
||||
categories_map, left_on="DisplayName", right_on="category_name", how="inner"
|
||||
)
|
||||
annotations = annotations[np.isin(annotations["ImageID"], image_list)]
|
||||
annotations = annotations.merge(class_map, on="LabelName", how="inner")
|
||||
annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
|
||||
annotations[["height", "width"]] = annotations.apply(
|
||||
lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"),
|
||||
axis=1,
|
||||
result_type="expand",
|
||||
)
|
||||
|
||||
# Annotations
|
||||
annots = []
|
||||
for i, row in annotations.iterrows():
|
||||
xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
|
||||
x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
|
||||
coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
|
||||
coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
|
||||
coco_annot["iscrowd"] = int(row["IsGroupOf"])
|
||||
annots.append(coco_annot)
|
||||
# Images
|
||||
imgs = [
|
||||
{
|
||||
"id": int(id + 1),
|
||||
"file_name": f"{image_id}.jpg",
|
||||
"height": row["height"],
|
||||
"width": row["width"],
|
||||
"license": None,
|
||||
"coco_url": None,
|
||||
}
|
||||
for (id, image_id), row in (
|
||||
annotations.groupby(["image_id", "ImageID"]).first().iterrows()
|
||||
)
|
||||
]
|
||||
|
||||
# Annotations
|
||||
annots = []
|
||||
for i, row in annotations.iterrows():
|
||||
xmin, ymin, xmax, ymax, img_w, img_h = [
|
||||
row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]
|
||||
]
|
||||
x, y, w, h = (
|
||||
xmin * img_w,
|
||||
ymin * img_h,
|
||||
(xmax - xmin) * img_w,
|
||||
(ymax - ymin) * img_h,
|
||||
)
|
||||
coco_annot = {
|
||||
"id": int(i) + 1,
|
||||
"image_id": int(row["image_id"] + 1),
|
||||
"category_id": int(row["category_id"]),
|
||||
"bbox": [x, y, w, h],
|
||||
"area": w * h,
|
||||
}
|
||||
coco_annot.update(
|
||||
{
|
||||
k: row[k]
|
||||
for k in [
|
||||
"IsOccluded",
|
||||
"IsInside",
|
||||
"IsDepiction",
|
||||
"IsTruncated",
|
||||
"IsGroupOf",
|
||||
]
|
||||
}
|
||||
)
|
||||
coco_annot["iscrowd"] = int(row["IsGroupOf"])
|
||||
annots.append(coco_annot)
|
||||
|
||||
info = {"dataset": "openimages_mlperf", "version": "v6"}
|
||||
coco_annotations = {
|
||||
"info": info,
|
||||
"licenses": [],
|
||||
"categories": cats,
|
||||
"images": imgs,
|
||||
"annotations": annots,
|
||||
}
|
||||
with open(output_path, "w") as fp:
|
||||
json.dump(coco_annotations, fp)
|
||||
|
||||
info = {"dataset": "openimages_mlperf", "version": "v6"}
|
||||
coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
|
||||
with open(output_path, "w") as fp:
|
||||
json.dump(coco_annotations, fp)
|
||||
|
||||
def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
|
||||
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
|
||||
image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
|
||||
return image_ids
|
||||
labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
|
||||
image_ids = annotations[np.isin(annotations["LabelName"], labels)][
|
||||
"ImageID"
|
||||
].unique()
|
||||
return image_ids
|
||||
|
||||
|
||||
def download_image(bucket, image_id, data_dir):
|
||||
try:
|
||||
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")
|
||||
try:
|
||||
bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
sys.exit(
|
||||
f"ERROR when downloading image `validation/{image_id}`: {str(exception)}"
|
||||
)
|
||||
|
||||
|
||||
def fetch_openimages(output_fn):
|
||||
bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)
|
||||
bucket = boto3.resource(
|
||||
"s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)
|
||||
).Bucket(BUCKET_NAME)
|
||||
|
||||
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
|
||||
annotations_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
|
||||
annotations_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
|
||||
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
|
||||
annotations = pd.read_csv(annotations_fn)
|
||||
annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split("/")[-1]
|
||||
fetch(BBOX_ANNOTATIONS_URL, annotations_fn)
|
||||
annotations = pd.read_csv(annotations_fn)
|
||||
|
||||
classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
|
||||
fetch(MAP_CLASSES_URL, classmap_fn)
|
||||
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
|
||||
classmap_fn = annotations_dir / MAP_CLASSES_URL.split("/")[-1]
|
||||
fetch(MAP_CLASSES_URL, classmap_fn)
|
||||
class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])
|
||||
|
||||
image_list = get_image_list(class_map, annotations)
|
||||
image_list = get_image_list(class_map, annotations)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
|
||||
for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
|
||||
t.set_description(f"Downloading images")
|
||||
future.result()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(download_image, bucket, image_id, data_dir)
|
||||
for image_id in image_list
|
||||
]
|
||||
for future in (
|
||||
t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))
|
||||
):
|
||||
t.set_description(f"Downloading images")
|
||||
future.result()
|
||||
|
||||
print("Converting annotations to COCO format...")
|
||||
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
|
||||
|
||||
print("Converting annotations to COCO format...")
|
||||
export_to_coco(class_map, annotations, image_list, data_dir, output_fn)
|
||||
|
||||
def image_load(fn):
|
||||
img_folder = BASEDIR / "validation/data"
|
||||
img = Image.open(img_folder / fn).convert('RGB')
|
||||
import torchvision.transforms.functional as F
|
||||
ret = F.resize(img, size=(800, 800))
|
||||
ret = np.array(ret)
|
||||
return ret, img.size[::-1]
|
||||
img_folder = BASEDIR / "validation/data"
|
||||
img = Image.open(img_folder / fn).convert("RGB")
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
ret = F.resize(img, size=(800, 800))
|
||||
ret = np.array(ret)
|
||||
return ret, img.size[::-1]
|
||||
|
||||
|
||||
def prepare_target(annotations, img_id, img_size):
|
||||
boxes = [annot["bbox"] for annot in annotations]
|
||||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1])
|
||||
boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0])
|
||||
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
||||
boxes = boxes[keep]
|
||||
classes = [annot["category_id"] for annot in annotations]
|
||||
classes = np.array(classes, dtype=np.int64)
|
||||
classes = classes[keep]
|
||||
return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}
|
||||
boxes = [annot["bbox"] for annot in annotations]
|
||||
boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1])
|
||||
boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0])
|
||||
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
|
||||
boxes = boxes[keep]
|
||||
classes = [annot["category_id"] for annot in annotations]
|
||||
classes = np.array(classes, dtype=np.int64)
|
||||
classes = classes[keep]
|
||||
return {
|
||||
"boxes": boxes,
|
||||
"labels": classes,
|
||||
"image_id": img_id,
|
||||
"image_size": img_size,
|
||||
}
|
||||
|
||||
|
||||
def iterate(coco, bs=8):
|
||||
image_ids = sorted(coco.imgs.keys())
|
||||
for i in range(0, len(image_ids), bs):
|
||||
X, targets = [], []
|
||||
for img_id in image_ids[i:i+bs]:
|
||||
x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"])
|
||||
X.append(x)
|
||||
annotations = coco.loadAnns(coco.getAnnIds(img_id))
|
||||
targets.append(prepare_target(annotations, img_id, original_size))
|
||||
yield np.array(X), targets
|
||||
image_ids = sorted(coco.imgs.keys())
|
||||
for i in range(0, len(image_ids), bs):
|
||||
X, targets = [], []
|
||||
for img_id in image_ids[i : i + bs]:
|
||||
x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"])
|
||||
X.append(x)
|
||||
annotations = coco.loadAnns(coco.getAnnIds(img_id))
|
||||
targets.append(prepare_target(annotations, img_id, original_size))
|
||||
yield np.array(X), targets
|
||||
|
|
|
@ -3,20 +3,25 @@ from tinygrad.tensor import Tensor
|
|||
from extra.datasets.imagenet import iterate, get_val_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
#sz = len(get_val_files())
|
||||
sz = 32*100
|
||||
X,Y = None, None
|
||||
# sz = len(get_val_files())
|
||||
sz = 32 * 100
|
||||
X, Y = None, None
|
||||
|
||||
idx = 0
|
||||
for x,y in iterate(shuffle=False):
|
||||
print(x.shape, y.shape, x.dtype, y.dtype)
|
||||
assert x.shape[0] == y.shape[0]
|
||||
bs = x.shape[0]
|
||||
if X is None:
|
||||
X = Tensor.empty(sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8)
|
||||
Y = Tensor.empty(sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64)
|
||||
print(X.shape, Y.shape)
|
||||
X[idx:idx+bs].assign(x)
|
||||
Y[idx:idx+bs].assign(y)
|
||||
idx += bs
|
||||
if idx >= sz: break
|
||||
idx = 0
|
||||
for x, y in iterate(shuffle=False):
|
||||
print(x.shape, y.shape, x.dtype, y.dtype)
|
||||
assert x.shape[0] == y.shape[0]
|
||||
bs = x.shape[0]
|
||||
if X is None:
|
||||
X = Tensor.empty(
|
||||
sz, *x.shape[1:], device="disk:/tmp/imagenet_x", dtype=dtypes.uint8
|
||||
)
|
||||
Y = Tensor.empty(
|
||||
sz, *y.shape[1:], device="disk:/tmp/imagenet_y", dtype=dtypes.int64
|
||||
)
|
||||
print(X.shape, Y.shape)
|
||||
X[idx : idx + bs].assign(x)
|
||||
Y[idx : idx + bs].assign(y)
|
||||
idx += bs
|
||||
if idx >= sz:
|
||||
break
|
||||
|
|
|
@ -6,143 +6,164 @@ import numpy as np
|
|||
from tinygrad.helpers import fetch
|
||||
|
||||
BASEDIR = Path(__file__).parent / "squad"
|
||||
|
||||
|
||||
def init_dataset():
|
||||
os.makedirs(BASEDIR, exist_ok=True)
|
||||
fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
|
||||
with open(BASEDIR / "dev-v1.1.json") as f:
|
||||
data = json.load(f)["data"]
|
||||
os.makedirs(BASEDIR, exist_ok=True)
|
||||
fetch(
|
||||
"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
|
||||
BASEDIR / "dev-v1.1.json",
|
||||
)
|
||||
with open(BASEDIR / "dev-v1.1.json") as f:
|
||||
data = json.load(f)["data"]
|
||||
|
||||
examples = []
|
||||
for article in data:
|
||||
for paragraph in article["paragraphs"]:
|
||||
text = paragraph["context"]
|
||||
doc_tokens = []
|
||||
prev_is_whitespace = True
|
||||
for c in text:
|
||||
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
||||
prev_is_whitespace = True
|
||||
else:
|
||||
if prev_is_whitespace:
|
||||
doc_tokens.append(c)
|
||||
else:
|
||||
doc_tokens[-1] += c
|
||||
prev_is_whitespace = False
|
||||
examples = []
|
||||
for article in data:
|
||||
for paragraph in article["paragraphs"]:
|
||||
text = paragraph["context"]
|
||||
doc_tokens = []
|
||||
prev_is_whitespace = True
|
||||
for c in text:
|
||||
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
||||
prev_is_whitespace = True
|
||||
else:
|
||||
if prev_is_whitespace:
|
||||
doc_tokens.append(c)
|
||||
else:
|
||||
doc_tokens[-1] += c
|
||||
prev_is_whitespace = False
|
||||
|
||||
for qa in paragraph["qas"]:
|
||||
qa_id = qa["id"]
|
||||
q_text = qa["question"]
|
||||
for qa in paragraph["qas"]:
|
||||
qa_id = qa["id"]
|
||||
q_text = qa["question"]
|
||||
|
||||
examples.append(
|
||||
{
|
||||
"id": qa_id,
|
||||
"question": q_text,
|
||||
"context": doc_tokens,
|
||||
"answers": list(map(lambda x: x["text"], qa["answers"])),
|
||||
}
|
||||
)
|
||||
return examples
|
||||
|
||||
examples.append({
|
||||
"id": qa_id,
|
||||
"question": q_text,
|
||||
"context": doc_tokens,
|
||||
"answers": list(map(lambda x: x["text"], qa["answers"]))
|
||||
})
|
||||
return examples
|
||||
|
||||
def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
best_score, best_span_index = None, None
|
||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||
end = doc_start + doc_length - 1
|
||||
if position < doc_start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = di
|
||||
return cur_span_index == best_span_index
|
||||
best_score, best_span_index = None, None
|
||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||
end = doc_start + doc_length - 1
|
||||
if position < doc_start:
|
||||
continue
|
||||
if position > end:
|
||||
continue
|
||||
num_left_context = position - doc_start
|
||||
num_right_context = end - position
|
||||
score = min(num_left_context, num_right_context) + 0.01 * doc_length
|
||||
if best_score is None or score > best_score:
|
||||
best_score = score
|
||||
best_span_index = di
|
||||
return cur_span_index == best_span_index
|
||||
|
||||
|
||||
def convert_example_to_features(example, tokenizer):
|
||||
query_tokens = tokenizer.tokenize(example["question"])
|
||||
query_tokens = tokenizer.tokenize(example["question"])
|
||||
|
||||
if len(query_tokens) > 64:
|
||||
query_tokens = query_tokens[:64]
|
||||
if len(query_tokens) > 64:
|
||||
query_tokens = query_tokens[:64]
|
||||
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for i, token in enumerate(example["context"]):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
tok_to_orig_index = []
|
||||
orig_to_tok_index = []
|
||||
all_doc_tokens = []
|
||||
for i, token in enumerate(example["context"]):
|
||||
orig_to_tok_index.append(len(all_doc_tokens))
|
||||
sub_tokens = tokenizer.tokenize(token)
|
||||
for sub_token in sub_tokens:
|
||||
tok_to_orig_index.append(i)
|
||||
all_doc_tokens.append(sub_token)
|
||||
|
||||
max_tokens_for_doc = 384 - len(query_tokens) - 3
|
||||
max_tokens_for_doc = 384 - len(query_tokens) - 3
|
||||
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
length = min(length, max_tokens_for_doc)
|
||||
doc_spans.append((start_offset, length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, 128)
|
||||
doc_spans = []
|
||||
start_offset = 0
|
||||
while start_offset < len(all_doc_tokens):
|
||||
length = len(all_doc_tokens) - start_offset
|
||||
length = min(length, max_tokens_for_doc)
|
||||
doc_spans.append((start_offset, length))
|
||||
if start_offset + length == len(all_doc_tokens):
|
||||
break
|
||||
start_offset += min(length, 128)
|
||||
|
||||
outputs = []
|
||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
outputs = []
|
||||
for di, (doc_start, doc_length) in enumerate(doc_spans):
|
||||
tokens = []
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
|
||||
for i in range(doc_length):
|
||||
split_token_index = doc_start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
token_is_max_context[len(tokens)] = _check_is_max_context(doc_spans, di, split_token_index)
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
for i in range(doc_length):
|
||||
split_token_index = doc_start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
token_is_max_context[len(tokens)] = _check_is_max_context(
|
||||
doc_spans, di, split_token_index
|
||||
)
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
input_mask = [1] * len(input_ids)
|
||||
|
||||
while len(input_ids) < 384:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
while len(input_ids) < 384:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
|
||||
assert len(input_ids) == 384
|
||||
assert len(input_mask) == 384
|
||||
assert len(segment_ids) == 384
|
||||
assert len(input_ids) == 384
|
||||
assert len(input_mask) == 384
|
||||
assert len(segment_ids) == 384
|
||||
|
||||
outputs.append({
|
||||
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
|
||||
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
|
||||
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
|
||||
"token_to_orig_map": token_to_orig_map,
|
||||
"token_is_max_context": token_is_max_context,
|
||||
"tokens": tokens,
|
||||
})
|
||||
outputs.append(
|
||||
{
|
||||
"input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
|
||||
"input_mask": np.expand_dims(np.array(input_mask), 0).astype(
|
||||
np.float32
|
||||
),
|
||||
"segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(
|
||||
np.float32
|
||||
),
|
||||
"token_to_orig_map": token_to_orig_map,
|
||||
"token_is_max_context": token_is_max_context,
|
||||
"tokens": tokens,
|
||||
}
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
return outputs
|
||||
|
||||
def iterate(tokenizer, start=0):
|
||||
examples = init_dataset()
|
||||
print(f"there are {len(examples)} pairs in the dataset")
|
||||
examples = init_dataset()
|
||||
print(f"there are {len(examples)} pairs in the dataset")
|
||||
|
||||
for i in range(start, len(examples)):
|
||||
example = examples[i]
|
||||
features = convert_example_to_features(example, tokenizer)
|
||||
# we need to yield all features here as the f1 score is the maximum over all features
|
||||
yield features, example
|
||||
|
||||
for i in range(start, len(examples)):
|
||||
example = examples[i]
|
||||
features = convert_example_to_features(example, tokenizer)
|
||||
# we need to yield all features here as the f1 score is the maximum over all features
|
||||
yield features, example
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = BertTokenizer(str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt"))
|
||||
tokenizer = BertTokenizer(
|
||||
str(Path(__file__).parents[2] / "weights" / "bert_vocab.txt")
|
||||
)
|
||||
|
||||
X, Y = next(iterate(tokenizer))
|
||||
print(" ".join(X[0]["tokens"]))
|
||||
print(X[0]["input_ids"].shape, Y)
|
||||
X, Y = next(iterate(tokenizer))
|
||||
print(" ".join(X[0]["tokens"]))
|
||||
print(X[0]["input_ids"].shape, Y)
|
||||
|
|
|
@ -5,56 +5,70 @@ from tinygrad.helpers import DEBUG, getenv
|
|||
import multiprocessing as mp
|
||||
import os
|
||||
|
||||
|
||||
# this needs to be called before everything else if you are using distributed
|
||||
def preinit():
|
||||
os.environ["DELAYED_RUNTIME_INIT"] = "1"
|
||||
mp.set_start_method("spawn")
|
||||
os.environ["DELAYED_RUNTIME_INIT"] = "1"
|
||||
mp.set_start_method("spawn")
|
||||
|
||||
|
||||
# out-of-band communication/synchronization
|
||||
class _OOB:
|
||||
def __init__(self, pipes:List[Tuple[Connection, Connection]]):
|
||||
self.pipes = pipes
|
||||
def __init__(self, pipes: List[Tuple[Connection, Connection]]):
|
||||
self.pipes = pipes
|
||||
|
||||
# send some data to a target rank, blocks until data is received
|
||||
def send(self, data: Any, target_rank: int):
|
||||
self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data)
|
||||
|
||||
# receive some data from a target rank, blocks until data is received
|
||||
def recv(self, target_rank: int) -> Any:
|
||||
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
||||
|
||||
# send some data to a target rank, blocks until data is received
|
||||
def send(self, data:Any, target_rank:int):
|
||||
self.pipes[getenv("RANK") * getenv("WORLD_SIZE") + target_rank][1].send(data)
|
||||
|
||||
# receive some data from a target rank, blocks until data is received
|
||||
def recv(self, target_rank:int) -> Any:
|
||||
return self.pipes[target_rank * getenv("WORLD_SIZE") + getenv("RANK")][0].recv()
|
||||
OOB: Optional[_OOB] = None
|
||||
|
||||
def init_oob(world_size:int):
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
global OOB
|
||||
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
|
||||
def init_oob(world_size: int):
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
global OOB
|
||||
OOB = _OOB([mp.Pipe(False) for _ in range(world_size * world_size)])
|
||||
|
||||
|
||||
# this runs in the spawned process so we can do all the delayed runtime initialization
|
||||
def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()):
|
||||
# setup the rank
|
||||
os.environ["RANK"] = str(rank)
|
||||
def _process_wrap(rank: int, device: str, oob: _OOB, fn: Callable, args=()):
|
||||
# setup the rank
|
||||
os.environ["RANK"] = str(rank)
|
||||
|
||||
# setup out of band communication
|
||||
global OOB
|
||||
OOB = oob
|
||||
# setup out of band communication
|
||||
global OOB
|
||||
OOB = oob
|
||||
|
||||
# do specific runtime initialization for distributed
|
||||
from tinygrad import Device
|
||||
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1])
|
||||
if "GPU" in device:
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
CL.post_init(device_num)
|
||||
elif "HIP" in device:
|
||||
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(device_num)
|
||||
if DEBUG >= 1: print(f"distributed process {rank} initialized runtime for device {device}")
|
||||
# do specific runtime initialization for distributed
|
||||
from tinygrad import Device
|
||||
|
||||
# convert device to be process specific
|
||||
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
|
||||
device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(
|
||||
device.split(":")[-1]
|
||||
)
|
||||
if "GPU" in device:
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
|
||||
CL.post_init(device_num)
|
||||
elif "HIP" in device:
|
||||
os.environ["HIP_DEFAULT_DEVICE"] = os.environ["HIP_VISIBLE_DEVICES"] = str(
|
||||
device_num
|
||||
)
|
||||
if DEBUG >= 1:
|
||||
print(f"distributed process {rank} initialized runtime for device {device}")
|
||||
|
||||
# convert device to be process specific
|
||||
Device.DEFAULT = device.split(":")[0] if "GPU" in device else device
|
||||
|
||||
fn(*args)
|
||||
|
||||
fn(*args)
|
||||
|
||||
# wrapper around mp.Process that initializes the runtime
|
||||
def spawn(rank:int, device:str, fn:Callable, args=()) -> mp.Process:
|
||||
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
|
||||
return p
|
||||
def spawn(rank: int, device: str, fn: Callable, args=()) -> mp.Process:
|
||||
(p := mp.Process(target=_process_wrap, args=(rank, device, OOB, fn, args))).start()
|
||||
return p
|
||||
|
|
|
@ -3,38 +3,41 @@ from tinygrad.helpers import getenv
|
|||
|
||||
from extra.dist import world
|
||||
|
||||
def allreduce(t:Tensor) -> Tensor:
|
||||
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
|
||||
|
||||
# flatten
|
||||
flattened = t.flatten()
|
||||
def allreduce(t: Tensor) -> Tensor:
|
||||
RANK, WORLD_SIZE = getenv("RANK"), getenv("WORLD_SIZE")
|
||||
|
||||
# pad to evenly divide
|
||||
if flattened.shape[0] % WORLD_SIZE != 0:
|
||||
flattened = Tensor.cat(flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE)))
|
||||
# flatten
|
||||
flattened = t.flatten()
|
||||
|
||||
# chunk
|
||||
chunks = flattened.chunk(WORLD_SIZE, dim=0)
|
||||
# pad to evenly divide
|
||||
if flattened.shape[0] % WORLD_SIZE != 0:
|
||||
flattened = Tensor.cat(
|
||||
flattened, Tensor.empty(WORLD_SIZE - (flattened.shape[0] % WORLD_SIZE))
|
||||
)
|
||||
|
||||
next_rank = (RANK + 1) % WORLD_SIZE
|
||||
prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
# chunk
|
||||
chunks = flattened.chunk(WORLD_SIZE, dim=0)
|
||||
|
||||
# scatter reduce
|
||||
current_chunk_index = RANK
|
||||
for _ in range(WORLD_SIZE - 1):
|
||||
world.send(chunks[current_chunk_index], next_rank)
|
||||
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
|
||||
world.recv(recv_buf, prev_rank)
|
||||
chunks[current_chunk_index] += recv_buf
|
||||
next_rank = (RANK + 1) % WORLD_SIZE
|
||||
prev_rank = ((RANK - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
|
||||
# gather
|
||||
current_chunk_index = (RANK + 1) % WORLD_SIZE
|
||||
for _ in range(WORLD_SIZE - 1):
|
||||
world.send(chunks[current_chunk_index], next_rank)
|
||||
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
|
||||
world.recv(recv_buf, prev_rank)
|
||||
chunks[current_chunk_index].assign(recv_buf)
|
||||
# scatter reduce
|
||||
current_chunk_index = RANK
|
||||
for _ in range(WORLD_SIZE - 1):
|
||||
world.send(chunks[current_chunk_index], next_rank)
|
||||
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
|
||||
world.recv(recv_buf, prev_rank)
|
||||
chunks[current_chunk_index] += recv_buf
|
||||
|
||||
return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)
|
||||
# gather
|
||||
current_chunk_index = (RANK + 1) % WORLD_SIZE
|
||||
for _ in range(WORLD_SIZE - 1):
|
||||
world.send(chunks[current_chunk_index], next_rank)
|
||||
current_chunk_index = ((current_chunk_index - 1) + WORLD_SIZE) % WORLD_SIZE
|
||||
recv_buf = Tensor.empty(*chunks[current_chunk_index].shape)
|
||||
world.recv(recv_buf, prev_rank)
|
||||
chunks[current_chunk_index].assign(recv_buf)
|
||||
|
||||
return Tensor.cat(*chunks, dim=0).shrink(((0, t.numel()),)).reshape(*t.shape)
|
||||
|
|
|
@ -4,111 +4,154 @@ from multiprocessing import shared_memory
|
|||
from tinygrad.helpers import DEBUG, colored, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut
|
||||
|
||||
try:
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
|
||||
except: RawHIPBuffer = None
|
||||
import gpuctypes.hip as hip
|
||||
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
|
||||
except:
|
||||
RawHIPBuffer = None
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.tensor import Tensor, Function
|
||||
import numpy as np
|
||||
|
||||
|
||||
# match the function signature of JITRunner so we can put it in the cache
|
||||
def __send_rb(args, variables=None, wait=False, jit=False):
|
||||
x, target_rank, y = args[:3]
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
else:
|
||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else: y.fromCPU(x.toCPU())
|
||||
dist.OOB.send(None, target_rank)
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
x, target_rank, y = args[:3]
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
else:
|
||||
if isinstance(x, RawBufferCopyInOut):
|
||||
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else:
|
||||
y.fromCPU(x.toCPU())
|
||||
dist.OOB.send(None, target_rank)
|
||||
if DEBUG >= 2:
|
||||
print(
|
||||
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} sent {x} to rank {target_rank}"
|
||||
)
|
||||
|
||||
|
||||
def __recv_rb(args, variables=None, wait=False, jit=False):
|
||||
x, target_rank, y = args[:3]
|
||||
dist.OOB.recv(target_rank)
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
x._transfer(y)
|
||||
elif isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
||||
else: x.fromCPU(y.toCPU())
|
||||
if DEBUG >= 2: print(f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}")
|
||||
x, target_rank, y = args[:3]
|
||||
dist.OOB.recv(target_rank)
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
x._transfer(y)
|
||||
elif isinstance(x, RawBuffer):
|
||||
x._copyin(y.toCPU())
|
||||
else:
|
||||
x.fromCPU(y.toCPU())
|
||||
if DEBUG >= 2:
|
||||
print(
|
||||
f"{colored('****', 'magenta' if jit else None)} rank {getenv('RANK')} recv {x} from rank {target_rank}"
|
||||
)
|
||||
|
||||
|
||||
# send a rawbuffer from out rank to the target rank
|
||||
def _send_rb(x:RawBuffer, target_rank:int):
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
# send ipc handle
|
||||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
check(hip.hipIpcGetMemHandle(ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf))
|
||||
dist.OOB.send((handle, x._device), target_rank)
|
||||
def _send_rb(x: RawBuffer, target_rank: int):
|
||||
if RawHIPBuffer and x.__class__ is RawHIPBuffer:
|
||||
# send ipc handle
|
||||
check(hip.hipSetDevice(x._device))
|
||||
check(hip.hipDeviceSynchronize())
|
||||
check(
|
||||
hip.hipIpcGetMemHandle(
|
||||
ctypes.byval(handle := hip.hipIpcMemHandle_t()), x._buf
|
||||
)
|
||||
)
|
||||
dist.OOB.send((handle, x._device), target_rank)
|
||||
|
||||
# jit support
|
||||
x._allocator = None # need to disconnect allocator for sent buffers
|
||||
CacheCollector.add(__send_rb, [x, target_rank, None], {})
|
||||
else:
|
||||
# create shared memory
|
||||
shm_name = (s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)).name
|
||||
s.close()
|
||||
# jit support
|
||||
x._allocator = None # need to disconnect allocator for sent buffers
|
||||
CacheCollector.add(__send_rb, [x, target_rank, None], {})
|
||||
else:
|
||||
# create shared memory
|
||||
shm_name = (
|
||||
s := shared_memory.SharedMemory(create=True, size=x.size * x.dtype.itemsize)
|
||||
).name
|
||||
s.close()
|
||||
|
||||
# copy the buffer into shared memory
|
||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:"+shm_name)
|
||||
# fast path when we can directly copyout
|
||||
if isinstance(x, RawBufferCopyInOut): x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else: y.fromCPU(x.toCPU())
|
||||
# copy the buffer into shared memory
|
||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||
# fast path when we can directly copyout
|
||||
if isinstance(x, RawBufferCopyInOut):
|
||||
x._copyout(np.frombuffer(y._buffer(), dtype=x.dtype.np))
|
||||
else:
|
||||
y.fromCPU(x.toCPU())
|
||||
|
||||
dist.OOB.send(shm_name, target_rank)
|
||||
dist.OOB.send(shm_name, target_rank)
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__send_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2:
|
||||
print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__send_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} sent {x} to rank {target_rank}")
|
||||
|
||||
# receive a rawbuffer from the target rank
|
||||
def _recv_rb(x:RawBuffer, target_rank:int):
|
||||
if RawHIPBuffer and isinstance(x, RawHIPBuffer):
|
||||
# open ipc handle
|
||||
handle, y_device = dist.OOB.recv(target_rank)
|
||||
check(hip.hipSetDevice(y_device))
|
||||
check(hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0))
|
||||
def _recv_rb(x: RawBuffer, target_rank: int):
|
||||
if RawHIPBuffer and isinstance(x, RawHIPBuffer):
|
||||
# open ipc handle
|
||||
handle, y_device = dist.OOB.recv(target_rank)
|
||||
check(hip.hipSetDevice(y_device))
|
||||
check(
|
||||
hip.hipIpcOpenMemHandle(ctypes.byval(ptr := ctypes.c_void_p()), handle, 0)
|
||||
)
|
||||
|
||||
# build a new buffer
|
||||
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
|
||||
x._transfer(y)
|
||||
# build a new buffer
|
||||
y = RawHIPBuffer(x.size, x.dtype, device=str(y_device), buf=ptr, allocator=None)
|
||||
x._transfer(y)
|
||||
|
||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||
else:
|
||||
shm_name = dist.OOB.recv(target_rank)
|
||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:"+shm_name)
|
||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||
else:
|
||||
shm_name = dist.OOB.recv(target_rank)
|
||||
y = RawDiskBuffer(x.size, x.dtype, device="disk:shm:" + shm_name)
|
||||
|
||||
# fast path when we can directly copyin
|
||||
if isinstance(x, RawBuffer): x._copyin(y.toCPU())
|
||||
else: x.fromCPU(y.toCPU())
|
||||
# fast path when we can directly copyin
|
||||
if isinstance(x, RawBuffer):
|
||||
x._copyin(y.toCPU())
|
||||
else:
|
||||
x.fromCPU(y.toCPU())
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2:
|
||||
print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
||||
|
||||
# jit support
|
||||
CacheCollector.add(__recv_rb, [x, target_rank, y], {})
|
||||
if DEBUG >= 2: print(f"**** rank {getenv('RANK')} got {x} from rank {target_rank}")
|
||||
|
||||
# sends a lazybuffer from our rank to the target rank
|
||||
def _send_lb(x:LazyBuffer, target_rank:int) -> None:
|
||||
assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
|
||||
_send_rb(x.realized, target_rank)
|
||||
def _send_lb(x: LazyBuffer, target_rank: int) -> None:
|
||||
assert (
|
||||
x.st.contiguous and x.realized
|
||||
), "sending buffer must be contiguous and realized"
|
||||
_send_rb(x.realized, target_rank)
|
||||
|
||||
|
||||
# receive a lazybuffer from the target rank
|
||||
def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
|
||||
assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
|
||||
_recv_rb(x.realized, target_rank)
|
||||
return x
|
||||
|
||||
class Send(Function):
|
||||
def forward(self, x:LazyBuffer, target_rank:int) -> LazyBuffer:
|
||||
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
|
||||
_send_lb(x, target_rank)
|
||||
def _recv_lb(x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
assert (
|
||||
x.st.contiguous and x.realized
|
||||
), "receiving buffer must be contiguous and realized"
|
||||
_recv_rb(x.realized, target_rank)
|
||||
return x
|
||||
|
||||
class Recv(Function):
|
||||
def forward(self, x:LazyBuffer, target_rank:int) -> LazyBuffer:
|
||||
self.target_rank = target_rank
|
||||
return _recv_lb(x, target_rank)
|
||||
|
||||
def send(x:Tensor, target_rank:int) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
def recv(x:Tensor, target_rank:int) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
class Send(Function):
|
||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
self.target_rank, self.shape, self.dtype = target_rank, x.shape, x.dtype
|
||||
_send_lb(x, target_rank)
|
||||
return x
|
||||
|
||||
|
||||
class Recv(Function):
|
||||
def forward(self, x: LazyBuffer, target_rank: int) -> LazyBuffer:
|
||||
self.target_rank = target_rank
|
||||
return _recv_lb(x, target_rank)
|
||||
|
||||
|
||||
def send(x: Tensor, target_rank: int) -> Tensor:
|
||||
return Send.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
|
||||
|
||||
def recv(x: Tensor, target_rank: int) -> Tensor:
|
||||
return Recv.apply(x.contiguous().realize(), target_rank=target_rank)
|
||||
|
|
|
@ -2,20 +2,25 @@ import sys, sqlite3, pickle
|
|||
from tinygrad.helpers import CACHEDB
|
||||
|
||||
if __name__ == "__main__":
|
||||
fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
|
||||
conn = sqlite3.connect(fn)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
for f in cur.fetchall():
|
||||
table = f[0]
|
||||
cur2 = conn.cursor()
|
||||
cur2.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
cnt = cur2.fetchone()[0]
|
||||
print(f"{table:20s} : {cnt}")
|
||||
fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
|
||||
conn = sqlite3.connect(fn)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
for f in cur.fetchall():
|
||||
table = f[0]
|
||||
cur2 = conn.cursor()
|
||||
cur2.execute(f"SELECT COUNT(*) FROM {table}")
|
||||
cnt = cur2.fetchone()[0]
|
||||
print(f"{table:20s} : {cnt}")
|
||||
|
||||
cur3 = conn.cursor()
|
||||
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
||||
for f in cur3.fetchall():
|
||||
v = pickle.loads(f[-1])
|
||||
print(" ", len(f[0]) if isinstance(f[0], str) else f[0], f[1:-1], str(v)[0:50])
|
||||
#print(f"{len(k):10d}, {sk} -> {v}")
|
||||
cur3 = conn.cursor()
|
||||
cur3.execute(f"SELECT * FROM {table} LIMIT 10")
|
||||
for f in cur3.fetchall():
|
||||
v = pickle.loads(f[-1])
|
||||
print(
|
||||
" ",
|
||||
len(f[0]) if isinstance(f[0], str) else f[0],
|
||||
f[1:-1],
|
||||
str(v)[0:50],
|
||||
)
|
||||
# print(f"{len(k):10d}, {sk} -> {v}")
|
||||
|
|
|
@ -7,77 +7,190 @@ import json
|
|||
|
||||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn = ji.prg
|
||||
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i,arg in enumerate(ji.rawbufs):
|
||||
key = id(arg)
|
||||
if key not in bufs:
|
||||
if key in special_names:
|
||||
bufs[key] = (special_names[key], arg.size*arg.dtype.itemsize, arg.dtype, key)
|
||||
else:
|
||||
bufs[key] = (f"buf_{bufnum}", arg.size*arg.dtype.itemsize, arg.dtype, key)
|
||||
bufnum += 1
|
||||
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
|
||||
cargs.append(bufs[key][0])
|
||||
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
|
||||
|
||||
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
|
||||
def compile_net(
|
||||
run: TinyJit, special_names: Dict[int, str]
|
||||
) -> Tuple[
|
||||
Dict[str, str],
|
||||
List[Tuple[str, List[str], List[int]]],
|
||||
Dict[str, Tuple[int, DType, int]],
|
||||
Dict[str, Tensor],
|
||||
]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
for ji in run.jit_cache:
|
||||
fxn = ji.prg
|
||||
functions[
|
||||
fxn.name
|
||||
] = fxn.prg # NOTE: this assumes all with the same name are the same
|
||||
cargs = []
|
||||
for i, arg in enumerate(ji.rawbufs):
|
||||
key = id(arg)
|
||||
if key not in bufs:
|
||||
if key in special_names:
|
||||
bufs[key] = (
|
||||
special_names[key],
|
||||
arg.size * arg.dtype.itemsize,
|
||||
arg.dtype,
|
||||
key,
|
||||
)
|
||||
else:
|
||||
bufs[key] = (
|
||||
f"buf_{bufnum}",
|
||||
arg.size * arg.dtype.itemsize,
|
||||
arg.dtype,
|
||||
key,
|
||||
)
|
||||
bufnum += 1
|
||||
if i > 0:
|
||||
bufs_to_save[
|
||||
bufs[key][0]
|
||||
] = arg # if first usage of a buffer is not an output, and it's not a special name
|
||||
cargs.append(bufs[key][0])
|
||||
statements.append((fxn.name, cargs, fxn.global_size, fxn.local_size))
|
||||
|
||||
def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
|
||||
assert hasattr(model, "forward") or callable(model), "model needs a forward function"
|
||||
@TinyJit
|
||||
def run(*x):
|
||||
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
||||
assert isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
out = [out] if isinstance(out, Tensor) else out
|
||||
return [o.realize() for o in out]
|
||||
return (
|
||||
functions,
|
||||
statements,
|
||||
{name: (size, dtype, key) for (name, size, dtype, key) in bufs.values()},
|
||||
bufs_to_save,
|
||||
)
|
||||
|
||||
# twice to run the JIT
|
||||
for _ in range(2): the_output = run(*args)
|
||||
special_names = {}
|
||||
|
||||
# hack to put the inputs back
|
||||
for (j,i),idx in run.input_replace.items():
|
||||
realized_input = args[idx].lazydata.realized
|
||||
run.jit_cache[j].rawbufs[i] = realized_input
|
||||
special_names[id(realized_input)] = f'input{idx}'
|
||||
def jit_model(model, *args) -> Tuple[TinyJit, Dict[int, str]]:
|
||||
assert hasattr(model, "forward") or callable(
|
||||
model
|
||||
), "model needs a forward function"
|
||||
|
||||
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
||||
for i, output in enumerate(the_output):
|
||||
special_names[id(output.lazydata.realized)] = f'output{i}'
|
||||
return run, special_names
|
||||
@TinyJit
|
||||
def run(*x):
|
||||
out = model.forward(*x) if hasattr(model, "forward") else model(*x)
|
||||
assert (
|
||||
isinstance(out, tuple) or isinstance(out, list) or isinstance(out, Tensor)
|
||||
), "model output must be a Tensor, tuple, or a list of Tensors for export"
|
||||
out = [out] if isinstance(out, Tensor) else out
|
||||
return [o.realize() for o in out]
|
||||
|
||||
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
|
||||
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
|
||||
cprog = [CLANG_PROGRAM_HEADER]
|
||||
# twice to run the JIT
|
||||
for _ in range(2):
|
||||
the_output = run(*args)
|
||||
special_names = {}
|
||||
|
||||
for name,cl in bufs_to_save.items():
|
||||
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
|
||||
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
|
||||
# hack to put the inputs back
|
||||
for (j, i), idx in run.input_replace.items():
|
||||
realized_input = args[idx].lazydata.realized
|
||||
run.jit_cache[j].rawbufs[i] = realized_input
|
||||
special_names[id(realized_input)] = f"input{idx}"
|
||||
|
||||
inputs = ", ".join([f'float* {input}' for input in input_names])
|
||||
outputs = ", ".join([f'float* {output}' for output in output_names])
|
||||
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
|
||||
cprog += list(functions.values())
|
||||
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
|
||||
return '\n'.join(cprog)
|
||||
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
|
||||
for i, output in enumerate(the_output):
|
||||
special_names[id(output.lazydata.realized)] = f"output{i}"
|
||||
return run, special_names
|
||||
|
||||
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
|
||||
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
|
||||
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
|
||||
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
|
||||
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
|
||||
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
|
||||
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
|
||||
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
|
||||
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
|
||||
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
|
||||
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
|
||||
return f"""
|
||||
|
||||
def export_model_clang(
|
||||
functions: Dict[str, str],
|
||||
statements: Dict[str, Tuple[str, int, int]],
|
||||
bufs: Dict[str, Tuple[str, int, int]],
|
||||
bufs_to_save: Dict[str, Tensor],
|
||||
input_names: List[str],
|
||||
output_names: List[str],
|
||||
) -> str:
|
||||
from tinygrad.runtime.ops_clang import CLANG_PROGRAM_HEADER
|
||||
|
||||
cprog = [CLANG_PROGRAM_HEADER]
|
||||
|
||||
for name, cl in bufs_to_save.items():
|
||||
weight = "".join(["\\x%02X" % x for x in bytes(cl._buf)])
|
||||
cprog.append(f'unsigned char {name}_data[] = "{weight}";')
|
||||
|
||||
inputs = ", ".join([f"float* {input}" for input in input_names])
|
||||
outputs = ", ".join([f"float* {output}" for output in output_names])
|
||||
cprog += [
|
||||
f"float {name}[{len}];"
|
||||
if name not in bufs_to_save
|
||||
else f"float *{name} = (float *){name}_data;"
|
||||
for name, (len, dtype, _key) in bufs.items()
|
||||
if name not in ["input", "outputs"]
|
||||
]
|
||||
cprog += list(functions.values())
|
||||
cprog += (
|
||||
[f"void net({inputs}, {outputs}) {{"]
|
||||
+ [
|
||||
f"{name}({', '.join(args)});"
|
||||
for (name, args, _global_size, _local_size) in statements
|
||||
]
|
||||
+ ["}"]
|
||||
)
|
||||
return "\n".join(cprog)
|
||||
|
||||
|
||||
def export_model_webgpu(
|
||||
functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names
|
||||
) -> Tuple[str, int, int]:
|
||||
kernel_code = "\n\n".join(
|
||||
[
|
||||
f"const {key} = `{code.replace(key, 'main')}`;"
|
||||
for key, code in functions.items()
|
||||
]
|
||||
)
|
||||
kernel_names = ", ".join(
|
||||
[name for (name, _args, _global_size, _local_size) in statements]
|
||||
)
|
||||
kernel_calls = "\n ".join(
|
||||
[
|
||||
f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
|
||||
for i, (_name, args, global_size, _local_size) in enumerate(statements)
|
||||
]
|
||||
)
|
||||
_bufs = "\n ".join(
|
||||
[
|
||||
f"const {name} = "
|
||||
+ (
|
||||
f"createEmptyBuf(device, {size});"
|
||||
if _key not in weight_names
|
||||
else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))"
|
||||
)
|
||||
+ ";"
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
]
|
||||
)
|
||||
gpu_write_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});"
|
||||
for i, input_name in enumerate(input_names)
|
||||
]
|
||||
)
|
||||
input_writers = "\n ".join(
|
||||
[
|
||||
f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set("
|
||||
+ f"_{inp_name});"
|
||||
+ f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);"
|
||||
for i, inp_name in enumerate(input_names)
|
||||
]
|
||||
)
|
||||
gpu_read_bufs = "\n ".join(
|
||||
[
|
||||
f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});"
|
||||
for i, output_name in enumerate(output_names)
|
||||
]
|
||||
)
|
||||
outbuf_copies = "\n ".join(
|
||||
[
|
||||
f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);"
|
||||
for i, output_name in enumerate(output_names)
|
||||
]
|
||||
)
|
||||
output_readers = "\n ".join(
|
||||
[
|
||||
f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();"
|
||||
for i in range(len(output_names))
|
||||
]
|
||||
)
|
||||
output_return = "[{}]".format(
|
||||
",".join([f"resultBuffer{i}" for i in range(len(output_names))])
|
||||
)
|
||||
return (
|
||||
f"""
|
||||
const getTensorMetadata = (safetensorBuffer) => {{
|
||||
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
|
||||
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
|
||||
|
@ -134,46 +247,73 @@ const setupNet = async (device, safetensor) => {{
|
|||
return {output_return};
|
||||
}}
|
||||
}}
|
||||
""" + f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
||||
"""
|
||||
+ f"\n\nconst loadNet = async (device) => {{ return await fetch('net.safetensors').then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}"
|
||||
)
|
||||
|
||||
def export_model(model, target:str, *inputs):
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||
run,special_names = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weight_names = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||
input_names = [name for _,name in special_names.items() if "input" in name]
|
||||
output_names = [name for _,name in special_names.items() if "output" in name]
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
|
||||
elif target == "webgpu":
|
||||
prg = export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names)
|
||||
else:
|
||||
prg = json.dumps({
|
||||
"backend": Device.DEFAULT,
|
||||
"inputs": [{
|
||||
"size": bufs[name][0],
|
||||
"dtype": bufs[name][1].name
|
||||
} for name in input_names],
|
||||
"outputs": [{
|
||||
"size": bufs[name][0],
|
||||
"dtype": bufs[name][1].name
|
||||
} for name in output_names],
|
||||
"functions": functions,
|
||||
"statements": [{
|
||||
"kernel": kernel,
|
||||
"args": args,
|
||||
"global_size": global_size,
|
||||
"local_size": local_size
|
||||
} for (kernel, args, global_size, local_size) in statements],
|
||||
"buffers": {
|
||||
name: {
|
||||
"size": size,
|
||||
"dtype": dtype.name,
|
||||
"id": weight_names[_key] if _key in weight_names else ""
|
||||
} for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"]
|
||||
}
|
||||
})
|
||||
|
||||
return prg, {input:bufs[input][0] for input in input_names}, {output:bufs[output][0] for output in output_names}, state
|
||||
def export_model(model, target: str, *inputs):
|
||||
assert (
|
||||
Device.DEFAULT in EXPORT_SUPPORTED_DEVICE
|
||||
), "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||
run, special_names = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
weight_names = {id(x.lazydata.realized): name for name, x in state.items()}
|
||||
input_names = [name for _, name in special_names.items() if "input" in name]
|
||||
output_names = [name for _, name in special_names.items() if "output" in name]
|
||||
prg = ""
|
||||
if target == "clang":
|
||||
prg = export_model_clang(
|
||||
functions, statements, bufs, bufs_to_save, input_names, output_names
|
||||
)
|
||||
elif target == "webgpu":
|
||||
prg = export_model_webgpu(
|
||||
functions,
|
||||
statements,
|
||||
bufs,
|
||||
bufs_to_save,
|
||||
weight_names,
|
||||
input_names,
|
||||
output_names,
|
||||
)
|
||||
else:
|
||||
prg = json.dumps(
|
||||
{
|
||||
"backend": Device.DEFAULT,
|
||||
"inputs": [
|
||||
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||
for name in input_names
|
||||
],
|
||||
"outputs": [
|
||||
{"size": bufs[name][0], "dtype": bufs[name][1].name}
|
||||
for name in output_names
|
||||
],
|
||||
"functions": functions,
|
||||
"statements": [
|
||||
{
|
||||
"kernel": kernel,
|
||||
"args": args,
|
||||
"global_size": global_size,
|
||||
"local_size": local_size,
|
||||
}
|
||||
for (kernel, args, global_size, local_size) in statements
|
||||
],
|
||||
"buffers": {
|
||||
name: {
|
||||
"size": size,
|
||||
"dtype": dtype.name,
|
||||
"id": weight_names[_key] if _key in weight_names else "",
|
||||
}
|
||||
for name, (size, dtype, _key) in bufs.items()
|
||||
if name not in ["input", "outputs"]
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
prg,
|
||||
{input: bufs[input][0] for input in input_names},
|
||||
{output: bufs[output][0] for output in output_names},
|
||||
state,
|
||||
)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
|
||||
np.set_printoptions(linewidth=160)
|
||||
np.set_printoptions(linewidth=1000, threshold=10000000000, suppress=False)
|
||||
from tinygrad.runtime.ops_llvm import LLVM, LLVMBuffer, int_const
|
||||
|
@ -11,28 +12,71 @@ from llvmlite import ir # type: ignore
|
|||
# https://github.com/corsix/amx/blob/main/Instructions.md
|
||||
# 12 lines for AMX support
|
||||
from functools import partialmethod
|
||||
|
||||
|
||||
class AMX:
|
||||
@staticmethod
|
||||
def nop_op_imm5(op, imm5, builder): builder.asm(ir.FunctionType(ir.VoidType(), []), f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}", "", tuple(), True)
|
||||
@staticmethod
|
||||
def op_gpr(op, builder, gpr): builder.asm(ir.FunctionType(ir.VoidType(), [ir.IntType(64)]), f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0", "r", (gpr,), True)
|
||||
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
|
||||
ldx, ldy, stx, sty = partialmethod(op_gpr, 0), partialmethod(op_gpr, 1), partialmethod(op_gpr, 2), partialmethod(op_gpr, 3)
|
||||
ldz, stz, ldzi, stzi = partialmethod(op_gpr, 4), partialmethod(op_gpr, 5), partialmethod(op_gpr, 6), partialmethod(op_gpr, 7)
|
||||
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
|
||||
fma64, fms64, fma32, fms32 = partialmethod(op_gpr, 10), partialmethod(op_gpr, 11), partialmethod(op_gpr, 12), partialmethod(op_gpr, 13)
|
||||
mac16, fma16, fms16 = partialmethod(op_gpr, 14), partialmethod(op_gpr, 15), partialmethod(op_gpr, 16)
|
||||
vecint, vecfp, matint, matfp, genlut = partialmethod(op_gpr, 18), partialmethod(op_gpr, 19), partialmethod(op_gpr, 20), partialmethod(op_gpr, 21), partialmethod(op_gpr, 22)
|
||||
@staticmethod
|
||||
def nop_op_imm5(op, imm5, builder):
|
||||
builder.asm(
|
||||
ir.FunctionType(ir.VoidType(), []),
|
||||
f".word (0x201000 + ({op} << 5) + {imm5}); amx op {op} imm {imm5}",
|
||||
"",
|
||||
tuple(),
|
||||
True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def op_gpr(op, builder, gpr):
|
||||
builder.asm(
|
||||
ir.FunctionType(ir.VoidType(), [ir.IntType(64)]),
|
||||
f".word (0x201000 + ({op} << 5) + 0$0 - ((0$0 >> 4) * 6)); amx op {op} reg $0",
|
||||
"r",
|
||||
(gpr,),
|
||||
True,
|
||||
)
|
||||
|
||||
set, clr = partialmethod(nop_op_imm5, 17, 0), partialmethod(nop_op_imm5, 17, 1)
|
||||
ldx, ldy, stx, sty = (
|
||||
partialmethod(op_gpr, 0),
|
||||
partialmethod(op_gpr, 1),
|
||||
partialmethod(op_gpr, 2),
|
||||
partialmethod(op_gpr, 3),
|
||||
)
|
||||
ldz, stz, ldzi, stzi = (
|
||||
partialmethod(op_gpr, 4),
|
||||
partialmethod(op_gpr, 5),
|
||||
partialmethod(op_gpr, 6),
|
||||
partialmethod(op_gpr, 7),
|
||||
)
|
||||
extrx, extry = partialmethod(op_gpr, 8), partialmethod(op_gpr, 9)
|
||||
fma64, fms64, fma32, fms32 = (
|
||||
partialmethod(op_gpr, 10),
|
||||
partialmethod(op_gpr, 11),
|
||||
partialmethod(op_gpr, 12),
|
||||
partialmethod(op_gpr, 13),
|
||||
)
|
||||
mac16, fma16, fms16 = (
|
||||
partialmethod(op_gpr, 14),
|
||||
partialmethod(op_gpr, 15),
|
||||
partialmethod(op_gpr, 16),
|
||||
)
|
||||
vecint, vecfp, matint, matfp, genlut = (
|
||||
partialmethod(op_gpr, 18),
|
||||
partialmethod(op_gpr, 19),
|
||||
partialmethod(op_gpr, 20),
|
||||
partialmethod(op_gpr, 21),
|
||||
partialmethod(op_gpr, 22),
|
||||
)
|
||||
|
||||
|
||||
N = 4096
|
||||
#N = 1024
|
||||
#N = 64
|
||||
# N = 1024
|
||||
# N = 64
|
||||
|
||||
#an = np.arange(N*N).reshape(N, N) - 43*64
|
||||
#bn = np.arange(N*N).reshape(N, N)
|
||||
#an = np.ones((N, N)).astype(np.float32)
|
||||
#bn = np.ones((N, N)).astype(np.float32)
|
||||
# an = np.arange(N*N).reshape(N, N) - 43*64
|
||||
# bn = np.arange(N*N).reshape(N, N)
|
||||
# an = np.ones((N, N)).astype(np.float32)
|
||||
# bn = np.ones((N, N)).astype(np.float32)
|
||||
|
||||
# matrix is 64M, max load bandwidth is 57 GB/s
|
||||
# cache line looks like 256 bytes (64 floats)
|
||||
|
@ -49,12 +93,16 @@ cn = (an.T @ bn).T
|
|||
|
||||
a = LLVMBuffer.fromCPU(an)
|
||||
b = LLVMBuffer.fromCPU(bn)
|
||||
#c = LLVMBuffer.fromCPU(np.zeros((N, N)))
|
||||
# c = LLVMBuffer.fromCPU(np.zeros((N, N)))
|
||||
c = LLVMBuffer.fromCPU(np.zeros(256))
|
||||
bufs = [c,a,b]
|
||||
bufs = [c, a, b]
|
||||
|
||||
module = ir.Module(name=__file__)
|
||||
func = ir.Function(module, ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()]*3), name='exec')
|
||||
func = ir.Function(
|
||||
module,
|
||||
ir.FunctionType(ir.IntType(64), [ir.FloatType().as_pointer()] * 3),
|
||||
name="exec",
|
||||
)
|
||||
|
||||
# load all
|
||||
entry = ir.IRBuilder(func.append_basic_block(name="entry"))
|
||||
|
@ -66,25 +114,42 @@ exit = ir.IRBuilder(func.append_basic_block(name="exit"))
|
|||
|
||||
y = loop_1.phi(ir.IntType(64), name="y")
|
||||
y.add_incoming(int_const(0), entry._block)
|
||||
yp = loop_1_exit.add(y, int_const(32*2))
|
||||
yp = loop_1_exit.add(y, int_const(32 * 2))
|
||||
y.add_incoming(yp, loop_1_exit._block)
|
||||
|
||||
prefetch_function = ir.Function(module, ir.FunctionType(ir.VoidType(), [ir.PointerType(ir.FloatType()), ir.IntType(32), ir.IntType(32), ir.IntType(32)]), name="llvm.prefetch")
|
||||
prefetch_function = ir.Function(
|
||||
module,
|
||||
ir.FunctionType(
|
||||
ir.VoidType(),
|
||||
[
|
||||
ir.PointerType(ir.FloatType()),
|
||||
ir.IntType(32),
|
||||
ir.IntType(32),
|
||||
ir.IntType(32),
|
||||
],
|
||||
),
|
||||
name="llvm.prefetch",
|
||||
)
|
||||
|
||||
xptr = y
|
||||
addr = loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||
|
||||
#prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
|
||||
#loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])
|
||||
# prefetch_ptr = loop_1_exit.inttoptr(loop_1_exit.add(addr, int_const(128)), ir.PointerType(ir.FloatType()))
|
||||
# loop_1_exit.call(prefetch_function, [prefetch_ptr, ir.IntType(32)(0), ir.IntType(32)(2), ir.IntType(32)(1)])
|
||||
|
||||
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1<<62), addr))
|
||||
AMX.ldx(loop_1_exit, loop_1_exit.add(int_const(1 << 62), addr))
|
||||
xptr = loop_1_exit.add(xptr, int_const(32))
|
||||
AMX.ldy(loop_1_exit, loop_1_exit.add(int_const(1<<62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))))
|
||||
AMX.ldy(
|
||||
loop_1_exit,
|
||||
loop_1_exit.add(
|
||||
int_const(1 << 62), loop_1_exit.add(xm, loop_1_exit.mul(int_const(4), xptr))
|
||||
),
|
||||
)
|
||||
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16*4)<<10))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 28 | 1 << 20 | (16 * 4) << 10))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16*4)))
|
||||
AMX.fma32(loop_1_exit, int_const(1 << 63 | 1 << 29 | 1 << 20 | (16 * 4)))
|
||||
|
||||
AMX.set(entry)
|
||||
|
||||
|
@ -93,7 +158,9 @@ AMX.clr(exit)
|
|||
|
||||
entry.branch(loop_1._block)
|
||||
loop_1.branch(loop_1_exit._block)
|
||||
loop_1_exit.cbranch(loop_1_exit.icmp_unsigned("==", yp, int_const(N*N)), exit._block, loop_1._block)
|
||||
loop_1_exit.cbranch(
|
||||
loop_1_exit.icmp_unsigned("==", yp, int_const(N * N)), exit._block, loop_1._block
|
||||
)
|
||||
exit.ret(int_const(0))
|
||||
|
||||
cfunc = LLVM().exec(module, bufs, N**2)
|
||||
|
@ -168,21 +235,20 @@ cfunc = LLVM().exec(module, bufs, N**3 * 2)
|
|||
|
||||
times = []
|
||||
for i in range(50):
|
||||
st = time.monotonic()
|
||||
cfunc(*[x._buf for x in bufs])
|
||||
et = time.monotonic() - st
|
||||
times.append(et)
|
||||
st = time.monotonic()
|
||||
cfunc(*[x._buf for x in bufs])
|
||||
et = time.monotonic() - st
|
||||
times.append(et)
|
||||
|
||||
print(f"{min(times)*1000:.2f} ms min time, {np.median(times)*1000:.2f} ms median time")
|
||||
print("%.2f GB/s" % ((N*N*4*1e-9)/min(times)))
|
||||
print("%.2f GB/s" % ((N * N * 4 * 1e-9) / min(times)))
|
||||
|
||||
print(c.toCPU().astype(np.int64)[:sn.shape[0]])
|
||||
print(c.toCPU().astype(np.int64)[: sn.shape[0]])
|
||||
print(sn.astype(np.int64))
|
||||
|
||||
np.testing.assert_allclose(c.toCPU()[:sn.shape[0]], sn, atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(c.toCPU()[: sn.shape[0]], sn, atol=1e-4, rtol=1e-4)
|
||||
|
||||
"""
|
||||
print(cn.astype(np.int64))
|
||||
np.testing.assert_allclose(c.toCPU(), cn, atol=1e-4, rtol=1e-5)
|
||||
"""
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import numpy as np
|
||||
|
||||
os.environ["CUDA"] = "1"
|
||||
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram, compile_cuda
|
||||
|
||||
|
@ -7,21 +8,24 @@ FLOAT16 = True
|
|||
ACC_FLOAT16 = False
|
||||
N = 4096
|
||||
|
||||
na = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32)
|
||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32)
|
||||
na = np.random.default_rng().standard_normal(size=(N, N), dtype=np.float32)
|
||||
nb = np.random.default_rng().standard_normal(size=(N, N), dtype=np.float32)
|
||||
|
||||
if FLOAT16:
|
||||
na = na.astype(np.float16)
|
||||
nb = nb.astype(np.float16)
|
||||
na = na.astype(np.float16)
|
||||
nb = nb.astype(np.float16)
|
||||
|
||||
a = RawCUDABuffer.fromCPU(na)
|
||||
b = RawCUDABuffer.fromCPU(nb)
|
||||
c = RawCUDABuffer.fromCPU(np.ones((N,N),dtype=np.float32))
|
||||
c = RawCUDABuffer.fromCPU(np.ones((N, N), dtype=np.float32))
|
||||
|
||||
FLOPS = N*N*N*2
|
||||
BW = N*N*3*4
|
||||
FLOPS = N * N * N * 2
|
||||
BW = N * N * 3 * 4
|
||||
|
||||
prog = CUDAProgram("wmma_example", compile_cuda(f"""
|
||||
prog = CUDAProgram(
|
||||
"wmma_example",
|
||||
compile_cuda(
|
||||
f"""
|
||||
#include <mma.h>
|
||||
using namespace nvcuda;
|
||||
|
||||
|
@ -88,10 +92,23 @@ __global__ void wmma_example({'half' if FLOAT16 else 'float'} *a, {'half' if FLO
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
"""))
|
||||
"""
|
||||
),
|
||||
)
|
||||
|
||||
global_size, local_size = [(N//16)//4, (N//16)//4], [32, 1, 1]
|
||||
tm = min([prog(a, b, c, global_size=global_size, local_size=local_size, wait=True) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
global_size, local_size = [(N // 16) // 4, (N // 16) // 4], [32, 1, 1]
|
||||
tm = min(
|
||||
[
|
||||
prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)
|
||||
for _ in range(20)
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(na.T.astype(np.float32) @ nb.T.astype(np.float32), c.toCPU().reshape((N,N)).T, atol=1e-2)
|
||||
np.testing.assert_allclose(
|
||||
na.T.astype(np.float32) @ nb.T.astype(np.float32),
|
||||
c.toCPU().reshape((N, N)).T,
|
||||
atol=1e-2,
|
||||
)
|
||||
|
|
|
@ -15,39 +15,50 @@ from tinygrad.helpers import partition, GlobalCounters, Context, getenv, prod, d
|
|||
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
|
||||
from tinygrad.ops import LoadOps, ReduceOps
|
||||
|
||||
def single_kernel():
|
||||
# single kernel
|
||||
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
|
||||
out = CLBuffer(prod(sz1), dtypes.imageh(sz1))
|
||||
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
|
||||
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
|
||||
|
||||
old = CLProgram("r_32_16_16_64_4_4_4", open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read())
|
||||
old_tms = [old([1,1,32], [16,16,1], out, x, w, wait=True)*1e6 for _ in range(5)]
|
||||
print(old_tms, 67.107/min(old_tms)*1e3)
|
||||
exit(0)
|
||||
def single_kernel():
|
||||
# single kernel
|
||||
sz1, sz2, sz3 = (32, 1024, 4), (32, 4096, 4), (16, 256, 4)
|
||||
out = CLBuffer(prod(sz1), dtypes.imageh(sz1))
|
||||
x = CLBuffer(prod(sz2), dtypes.imageh(sz2))
|
||||
w = CLBuffer(prod(sz3), dtypes.imageh(sz3))
|
||||
|
||||
old = CLProgram(
|
||||
"r_32_16_16_64_4_4_4",
|
||||
open(pathlib.Path(__file__).parent / "conv1_reorder.cl").read(),
|
||||
)
|
||||
old_tms = [
|
||||
old([1, 1, 32], [16, 16, 1], out, x, w, wait=True) * 1e6 for _ in range(5)
|
||||
]
|
||||
print(old_tms, 67.107 / min(old_tms) * 1e3)
|
||||
exit(0)
|
||||
|
||||
|
||||
# CONV=0 PYTHONPATH="." LATEDEBUG=5 OPT=99 IMAGE=2 FLOAT16=1 NOLOCALS=1 python3 extra/fastvits/fastvits_speed.py
|
||||
if __name__ == "__main__":
|
||||
#single_kernel()
|
||||
# single_kernel()
|
||||
|
||||
# this is stage 1 in fastvits
|
||||
c1 = Conv2d(256, 64, (1,1), bias=False)
|
||||
c2 = Conv2d(64, 64, (3,3), groups=64, padding=1, bias=False)
|
||||
c3 = Conv2d(64, 64, (7,7), groups=64, padding=3, bias=False)
|
||||
c4 = Conv2d(64, 256, (1,1), bias=False)
|
||||
c5 = Conv2d(256, 64, (1,1), bias=False)
|
||||
# this is stage 1 in fastvits
|
||||
c1 = Conv2d(256, 64, (1, 1), bias=False)
|
||||
c2 = Conv2d(64, 64, (3, 3), groups=64, padding=1, bias=False)
|
||||
c3 = Conv2d(64, 64, (7, 7), groups=64, padding=3, bias=False)
|
||||
c4 = Conv2d(64, 256, (1, 1), bias=False)
|
||||
c5 = Conv2d(256, 64, (1, 1), bias=False)
|
||||
|
||||
# TODO: the elementwise ops shouldn't rerun with normal realize
|
||||
x = Tensor.randn(1, 256, 32, 64)
|
||||
out = x.sequential([c1,c2,c3,c4,c5])
|
||||
schedule = out.lazydata.schedule()
|
||||
# TODO: the elementwise ops shouldn't rerun with normal realize
|
||||
x = Tensor.randn(1, 256, 32, 64)
|
||||
out = x.sequential([c1, c2, c3, c4, c5])
|
||||
schedule = out.lazydata.schedule()
|
||||
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps and any(y.op in ReduceOps for y in x.ast.get_lazyops()))
|
||||
run_schedule(schedule_input)
|
||||
run_schedule(schedule[:getenv("CONV")])
|
||||
print("*** init done ***")
|
||||
schedule, schedule_input = partition(
|
||||
schedule,
|
||||
lambda x: x.ast.op not in LoadOps
|
||||
and any(y.op in ReduceOps for y in x.ast.get_lazyops()),
|
||||
)
|
||||
run_schedule(schedule_input)
|
||||
run_schedule(schedule[: getenv("CONV")])
|
||||
print("*** init done ***")
|
||||
|
||||
GlobalCounters.reset()
|
||||
with Context(DEBUG=getenv("LATEDEBUG", 2), BEAM=getenv("LATEBEAM")):
|
||||
run_schedule(schedule[getenv("CONV"):getenv("CONV")+1])
|
||||
GlobalCounters.reset()
|
||||
with Context(DEBUG=getenv("LATEDEBUG", 2), BEAM=getenv("LATEBEAM")):
|
||||
run_schedule(schedule[getenv("CONV") : getenv("CONV") + 1])
|
||||
|
|
|
@ -1,28 +1,29 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
#os.environ['OMP_NUM_THREADS'] = '1'
|
||||
|
||||
# os.environ['OMP_NUM_THREADS'] = '1'
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
N = 2048
|
||||
if __name__ == "__main__":
|
||||
# N^2
|
||||
A = np.random.randn(N, N).astype(np.float32)
|
||||
# N^2
|
||||
B = np.random.randn(N, N).astype(np.float32)
|
||||
# N^2
|
||||
A = np.random.randn(N, N).astype(np.float32)
|
||||
# N^2
|
||||
B = np.random.randn(N, N).astype(np.float32)
|
||||
|
||||
# 2N compute in N^2 output cells
|
||||
flop = 2*N*N*N
|
||||
#print(f"{flop / 1e9:.2f} GFLOP")
|
||||
# 2N compute in N^2 output cells
|
||||
flop = 2 * N * N * N
|
||||
# print(f"{flop / 1e9:.2f} GFLOP")
|
||||
|
||||
for i in range(4):
|
||||
st = time.monotonic()
|
||||
C = A @ B.T
|
||||
et = time.monotonic()
|
||||
s = et-st
|
||||
print(f"{flop/s * 1e-9:.2f} GFLOP/S, {s*1e3:.2f} ms")
|
||||
for i in range(4):
|
||||
st = time.monotonic()
|
||||
C = A @ B.T
|
||||
et = time.monotonic()
|
||||
s = et - st
|
||||
print(f"{flop/s * 1e-9:.2f} GFLOP/S, {s*1e3:.2f} ms")
|
||||
|
||||
with open("/tmp/matmul", "wb") as f:
|
||||
f.write(A.data)
|
||||
f.write(B.data)
|
||||
f.write(C.data)
|
||||
with open("/tmp/matmul", "wb") as f:
|
||||
f.write(A.data)
|
||||
f.write(B.data)
|
||||
f.write(C.data)
|
||||
|
|
|
@ -62,21 +62,19 @@ from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
|
|||
from tinygrad.helpers import dtypes, prod
|
||||
|
||||
if __name__ == "__main__":
|
||||
out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
|
||||
x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
|
||||
w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4)))
|
||||
b = CLBuffer(1024, dtypes.float)
|
||||
out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1, 128, 4)))
|
||||
x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1, 128, 4)))
|
||||
w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4)))
|
||||
b = CLBuffer(1024, dtypes.float)
|
||||
|
||||
old = CLProgram("re_S256_16_8", old)
|
||||
new = CLProgram("r_256_16_4_8_4", new)
|
||||
old = CLProgram("re_S256_16_8", old)
|
||||
new = CLProgram("r_256_16_4_8_4", new)
|
||||
|
||||
old_tms = []
|
||||
new_tms = []
|
||||
|
||||
for i in range(5):
|
||||
old_tms.append(old([1,1,256], [4,16,1], out, x, w, b, wait=True))
|
||||
new_tms.append(new([256,1,1], [4,16,1], out, x, w, b, wait=True))
|
||||
|
||||
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
|
||||
old_tms = []
|
||||
new_tms = []
|
||||
|
||||
for i in range(5):
|
||||
old_tms.append(old([1, 1, 256], [4, 16, 1], out, x, w, b, wait=True))
|
||||
new_tms.append(new([256, 1, 1], [4, 16, 1], out, x, w, b, wait=True))
|
||||
|
||||
print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
|
||||
|
|
|
@ -18,24 +18,33 @@ from tinygrad.runtime.ops_hip import HIPAllocator, HIPProgram, compile_hip
|
|||
N = getenv("N", 2048)
|
||||
KX = getenv("KX", 4)
|
||||
KY = getenv("KY", 4)
|
||||
assert N%(16*KX) == 0, f"N must be multiple of {16*KX}"
|
||||
assert N%(16*KY) == 0, f"N must be multiple of {16*KY}"
|
||||
FLOPS = N*N*N*2
|
||||
BW = N*N*3*4
|
||||
assert N % (16 * KX) == 0, f"N must be multiple of {16*KX}"
|
||||
assert N % (16 * KY) == 0, f"N must be multiple of {16*KY}"
|
||||
FLOPS = N * N * N * 2
|
||||
BW = N * N * 3 * 4
|
||||
|
||||
# Can HIPAllocator initialized as device=0 by default?
|
||||
device = 0
|
||||
hipallocator = HIPAllocator(device)
|
||||
a = hipallocator.alloc(N*N*4)
|
||||
b = hipallocator.alloc(N*N*2)
|
||||
c = hipallocator.alloc(N*N*2)
|
||||
na = np.empty(N*N, np.float32)
|
||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32).astype(np.float16)
|
||||
a = hipallocator.alloc(N * N * 4)
|
||||
b = hipallocator.alloc(N * N * 2)
|
||||
c = hipallocator.alloc(N * N * 2)
|
||||
na = np.empty(N * N, np.float32)
|
||||
nb = (
|
||||
np.random.default_rng()
|
||||
.standard_normal(size=(N, N), dtype=np.float32)
|
||||
.astype(np.float16)
|
||||
)
|
||||
nc = (
|
||||
np.random.default_rng()
|
||||
.standard_normal(size=(N, N), dtype=np.float32)
|
||||
.astype(np.float16)
|
||||
)
|
||||
hipallocator.copyin(b, bytearray(nb))
|
||||
hipallocator.copyin(c, bytearray(nc))
|
||||
|
||||
lib = compile_hip(f"""
|
||||
lib = compile_hip(
|
||||
f"""
|
||||
#define F32
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
|
@ -92,22 +101,41 @@ extern "C" __global__ void __launch_bounds__ (128, 1) test(float* c, __half* a,
|
|||
}}
|
||||
}}
|
||||
}}
|
||||
}}""")
|
||||
}}"""
|
||||
)
|
||||
|
||||
prog = HIPProgram(device, "test", lib)
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead
|
||||
#print(f"{ret*1e6:.2f} us")
|
||||
return et
|
||||
|
||||
global_size, local_size = [N//(KX*16*2), N//(KY*16*2), 1], [32, 2, 2]
|
||||
print("global/local size", global_size, local_size, f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}")
|
||||
tm = min([timeit(lambda: prog(a, b, c, global_size=global_size, local_size=local_size, wait=True)) for _ in range(1000)])
|
||||
hipallocator.copyout(flat_mv(na.data),a)
|
||||
na = na.reshape(N,N)
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
ret = time.perf_counter() - st # NOTE: et doesn't contain the launch overhead
|
||||
# print(f"{ret*1e6:.2f} us")
|
||||
return et
|
||||
|
||||
|
||||
global_size, local_size = [N // (KX * 16 * 2), N // (KY * 16 * 2), 1], [32, 2, 2]
|
||||
print(
|
||||
"global/local size",
|
||||
global_size,
|
||||
local_size,
|
||||
f"local_size:{prod(local_size)} total_size:{prod(global_size+local_size)}",
|
||||
)
|
||||
tm = min(
|
||||
[
|
||||
timeit(
|
||||
lambda: prog(
|
||||
a, b, c, global_size=global_size, local_size=local_size, wait=True
|
||||
)
|
||||
)
|
||||
for _ in range(1000)
|
||||
]
|
||||
)
|
||||
hipallocator.copyout(flat_mv(na.data), a)
|
||||
na = na.reshape(N, N)
|
||||
comp = nb.astype(np.float32) @ nc.astype(np.float32)
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
np.testing.assert_allclose(na, comp, atol=1e-2, rtol=1e-2)
|
||||
|
|
|
@ -13,15 +13,21 @@ B = jnp.zeros((1, 1, N, N), dtype)
|
|||
A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
|
||||
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
|
||||
|
||||
OPS = DEVICES*BS*N*N*N*2
|
||||
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
|
||||
OPS = DEVICES * BS * N * N * N * 2
|
||||
|
||||
|
||||
def matmul(A, B):
|
||||
return jnp.matmul(A, B, preferred_element_type=jnp.float32)
|
||||
|
||||
|
||||
pmatmul = jax.pmap(matmul)
|
||||
|
||||
MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
||||
MAX_TFLOPS = 123 * DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
||||
for i in range(10):
|
||||
st = time.perf_counter()
|
||||
C = pmatmul(A,B).block_until_ready()
|
||||
et = time.perf_counter()-st
|
||||
tflops = (OPS*1e-12)/et
|
||||
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
|
||||
|
||||
st = time.perf_counter()
|
||||
C = pmatmul(A, B).block_until_ready()
|
||||
et = time.perf_counter() - st
|
||||
tflops = (OPS * 1e-12) / et
|
||||
print(
|
||||
f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}"
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
#os.environ["METAL"] = "1"
|
||||
|
||||
# os.environ["METAL"] = "1"
|
||||
import numpy as np
|
||||
|
||||
BS = 64
|
||||
|
@ -11,39 +12,48 @@ PADDING = 0
|
|||
# TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS.
|
||||
# are winograd convs less flops? it appears so if they are batched
|
||||
# https://www.cse.ust.hk/~weiwa/papers/yan-ppopp20.pdf
|
||||
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
|
||||
FLOPS = BS * K * K * CIN * HW * HW * COUT * 2
|
||||
|
||||
nb = np.random.default_rng().standard_normal(size=(BS,CIN,HW,HW), dtype=np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float32)
|
||||
nb = np.random.default_rng().standard_normal(size=(BS, CIN, HW, HW), dtype=np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(COUT, CIN, K, K), dtype=np.float32)
|
||||
|
||||
try:
|
||||
import time, torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
import time, torch, torch.mps
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = torch.nn.functional.conv2d(b, c, padding=PADDING)
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
|
||||
except RuntimeError:
|
||||
print("no torch metal conv")
|
||||
print("no torch metal conv")
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad import Device
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return b.conv2d(c, padding=PADDING).realize()
|
||||
return b.conv2d(c, padding=PADDING).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
Device[a.device].synchronize()
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
Device[a.device].synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(5)])
|
||||
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
|
||||
os.environ["METAL"] = "1"
|
||||
import time
|
||||
import numpy as np
|
||||
|
@ -8,17 +9,24 @@ from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_met
|
|||
N = getenv("N", 2048)
|
||||
LID = 2
|
||||
|
||||
a = RawMetalBuffer(N*N, dtypes.float32)
|
||||
a = RawMetalBuffer(N * N, dtypes.float32)
|
||||
|
||||
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nb = np.random.default_rng().standard_normal(
|
||||
size=(N, N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(
|
||||
size=(N, N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
b = RawMetalBuffer.fromCPU(nb)
|
||||
c = RawMetalBuffer.fromCPU(nc)
|
||||
|
||||
FLOPS = N*N*N*2
|
||||
BW = N*N*3*4
|
||||
FLOPS = N * N * N * 2
|
||||
BW = N * N * 3 * 4
|
||||
|
||||
prog = MetalProgram("test", compile_metal(f"""
|
||||
prog = MetalProgram(
|
||||
"test",
|
||||
compile_metal(
|
||||
f"""
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
|
||||
using namespace metal;
|
||||
|
@ -80,46 +88,83 @@ kernel void test(device float *a, device const float *data1, device const float
|
|||
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
|
||||
}}"""))
|
||||
}}"""
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
|
||||
na = a.toCPU().reshape(N,N)
|
||||
comp = nb@nc
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min(
|
||||
[
|
||||
timeit(
|
||||
lambda: prog(
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
global_size=[N // (8 * 4), N // (8 * 4 * LID), 1],
|
||||
local_size=[32, LID, 1],
|
||||
wait=True,
|
||||
)
|
||||
)
|
||||
for _ in range(20)
|
||||
]
|
||||
)
|
||||
na = a.toCPU().reshape(N, N)
|
||||
comp = nb @ nc
|
||||
if N <= 32:
|
||||
print(na)
|
||||
print(comp)
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
|
||||
print(na)
|
||||
print(comp)
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s"
|
||||
)
|
||||
np.testing.assert_allclose(na, comp, atol=1e-3)
|
||||
|
||||
import torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b@c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
a = b @ c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch"
|
||||
)
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.runtime.ops_metal import METAL
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return (b@c).realize()
|
||||
return (b @ c).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad"
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
#os.environ["METAL"] = "1"
|
||||
|
||||
# os.environ["METAL"] = "1"
|
||||
import numpy as np
|
||||
import time, torch, torch.mps
|
||||
|
||||
|
@ -10,6 +11,7 @@ from tinygrad import Device
|
|||
from tinygrad.helpers import colored, getenv, CI
|
||||
|
||||
import os
|
||||
|
||||
os.environ["METAL"] = "1"
|
||||
import time
|
||||
import numpy as np
|
||||
|
@ -18,29 +20,40 @@ from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_met
|
|||
|
||||
N = 16384
|
||||
M = 4096
|
||||
FLOPS = N*M*2
|
||||
FLOPS = N * M * 2
|
||||
|
||||
nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)
|
||||
nb = np.random.default_rng().standard_normal(
|
||||
size=(N), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
nc = np.random.default_rng().standard_normal(
|
||||
size=(N, M), dtype=np.float32
|
||||
) # .astype(np.int32).astype(np.float32)
|
||||
|
||||
import torch, torch.mps
|
||||
b = torch.from_numpy(nb).to('mps')
|
||||
c = torch.from_numpy(nc).to('mps')
|
||||
|
||||
b = torch.from_numpy(nb).to("mps")
|
||||
c = torch.from_numpy(nc).to("mps")
|
||||
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b@c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
a = b @ c
|
||||
torch.mps.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
|
||||
torch_a = (b@c).cpu()
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch"
|
||||
)
|
||||
torch_a = (b @ c).cpu()
|
||||
|
||||
WORKSIZE_ROW = 16
|
||||
WORKSIZE_COL = 1
|
||||
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
|
||||
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
|
||||
prog = compile_metal(f"""
|
||||
GLOBAL_SIZE = [M // (LOCAL_SIZE[0] * LOCAL_SIZE[1] * 4), 1, 1]
|
||||
prog = compile_metal(
|
||||
f"""
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
|
||||
|
@ -86,41 +99,59 @@ kernel void test(device float* data0, const device float* data1, const device fl
|
|||
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
|
||||
}}
|
||||
}}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
prog = MetalProgram("test", prog)
|
||||
# print(prog_string)
|
||||
na = np.zeros(M, dtype=np.float32)
|
||||
b = RawMetalBuffer.fromCPU(nb)
|
||||
c = RawMetalBuffer.fromCPU(nc)
|
||||
|
||||
|
||||
def metalrun():
|
||||
a = RawMetalBuffer.fromCPU(na)
|
||||
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
||||
return a
|
||||
a = RawMetalBuffer.fromCPU(na)
|
||||
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
||||
return a
|
||||
|
||||
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([timeit(metalrun) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal"
|
||||
)
|
||||
metal_a = metalrun().toCPU().reshape(M)
|
||||
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.runtime.ops_metal import METAL
|
||||
|
||||
b = Tensor(nb)
|
||||
c = Tensor(nc)
|
||||
|
||||
|
||||
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
|
||||
@TinyJit
|
||||
def tiny_jit(b, c):
|
||||
return (b@c).realize()
|
||||
return (b @ c).realize()
|
||||
|
||||
|
||||
def tiny_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
st = time.perf_counter()
|
||||
a = tiny_jit(b, c)
|
||||
METAL.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
|
||||
tm = min([tiny_prog(b, c) for _ in range(200)])
|
||||
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
|
||||
print(
|
||||
f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad"
|
||||
)
|
||||
tiny_a = tiny_jit(b, c).numpy()
|
||||
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
|
||||
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
|
||||
|
|
|
@ -2,14 +2,28 @@ import numpy as np
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
||||
N = getenv("N", 4096)
|
||||
CNT = getenv("CNT", 10)
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
a, b = (
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
)
|
||||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = (
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
Tensor.rand(N, N, dtype=dtype_in).realize(),
|
||||
)
|
||||
c = (
|
||||
(a.reshape(N, 1, N) * b.permute(1, 0).reshape(1, N, N))
|
||||
.float()
|
||||
.sum(axis=2)
|
||||
.realize()
|
||||
if getenv("ACCUM_FP32")
|
||||
else (a @ b).realize()
|
||||
)
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
nc = c.numpy()
|
||||
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
|
||||
|
|
|
@ -1,33 +1,37 @@
|
|||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
if gpus:
|
||||
try:
|
||||
# Currently, memory growth needs to be the same across GPUs
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices('GPU')
|
||||
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
# Memory growth must be set before GPUs have been initialized
|
||||
print(e)
|
||||
try:
|
||||
# Currently, memory growth needs to be the same across GPUs
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
|
||||
except RuntimeError as e:
|
||||
# Memory growth must be set before GPUs have been initialized
|
||||
print(e)
|
||||
|
||||
for dtype in [tf.float16, tf.float32]:
|
||||
for N in [256, 512, 1024, 2048, 4096, 8192]:
|
||||
FLOPS = N*N*N*2
|
||||
for N in [256, 512, 1024, 2048, 4096, 8192]:
|
||||
FLOPS = N * N * N * 2
|
||||
|
||||
b = tf.random.uniform((N, N), dtype=dtype)
|
||||
c = tf.random.uniform((N, N), dtype=dtype)
|
||||
b = tf.random.uniform((N, N), dtype=dtype)
|
||||
c = tf.random.uniform((N, N), dtype=dtype)
|
||||
|
||||
b = tf.Variable(b)
|
||||
c = tf.Variable(c)
|
||||
b = tf.Variable(b)
|
||||
c = tf.Variable(c)
|
||||
|
||||
def tf_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tf.matmul(b, c)
|
||||
tf.debugging.check_numerics(a, "Nan or Inf in result") # Ensures that the calculation is done.
|
||||
return time.perf_counter() - st
|
||||
def tf_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = tf.matmul(b, c)
|
||||
tf.debugging.check_numerics(
|
||||
a, "Nan or Inf in result"
|
||||
) # Ensures that the calculation is done.
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([tf_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
||||
tm = min([tf_prog(b, c) for _ in range(20)])
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||
)
|
||||
|
|
|
@ -2,16 +2,19 @@ import time
|
|||
import torch
|
||||
|
||||
for dtype in [torch.float16, torch.float32]:
|
||||
for N in [256, 512, 1024, 2048, 4096]:
|
||||
FLOPS = N*N*N*2
|
||||
for N in [256, 512, 1024, 2048, 4096]:
|
||||
FLOPS = N * N * N * 2
|
||||
|
||||
b = torch.rand((N,N), dtype=dtype).cuda()
|
||||
c = torch.rand((N,N), dtype=dtype).cuda()
|
||||
b = torch.rand((N, N), dtype=dtype).cuda()
|
||||
c = torch.rand((N, N), dtype=dtype).cuda()
|
||||
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b@c
|
||||
torch.cuda.synchronize()
|
||||
return time.perf_counter() - st
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")
|
||||
def torch_prog(b, c):
|
||||
st = time.perf_counter()
|
||||
a = b @ c
|
||||
torch.cuda.synchronize()
|
||||
return time.perf_counter() - st
|
||||
|
||||
tm = min([torch_prog(b, c) for _ in range(20)])
|
||||
print(
|
||||
f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}"
|
||||
)
|
||||
|
|
|
@ -3,28 +3,29 @@
|
|||
M, N, K = 1024, 1024, 1024
|
||||
|
||||
try:
|
||||
import tvm
|
||||
from tvm import te
|
||||
#print(tvm.target.Target.list_kinds())
|
||||
import tvm
|
||||
from tvm import te
|
||||
|
||||
# c, opencl
|
||||
target = tvm.target.Target(target="c")
|
||||
# print(tvm.target.Target.list_kinds())
|
||||
|
||||
# TVM Matrix Multiplication using TE
|
||||
k = te.reduce_axis((0, K), "k")
|
||||
A = te.placeholder((M, K), name="A")
|
||||
B = te.placeholder((K, N), name="B")
|
||||
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
|
||||
# c, opencl
|
||||
target = tvm.target.Target(target="c")
|
||||
|
||||
# Default schedule
|
||||
s = te.create_schedule(C.op)
|
||||
#print(tvm.lower(s, [A, B, C], simple_mode=True))
|
||||
# TVM Matrix Multiplication using TE
|
||||
k = te.reduce_axis((0, K), "k")
|
||||
A = te.placeholder((M, K), name="A")
|
||||
B = te.placeholder((K, N), name="B")
|
||||
C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
|
||||
|
||||
# Output C code
|
||||
func = tvm.build(s, [A, B, C], target=target, name="mmult")
|
||||
print(func.get_source())
|
||||
# Default schedule
|
||||
s = te.create_schedule(C.op)
|
||||
# print(tvm.lower(s, [A, B, C], simple_mode=True))
|
||||
|
||||
# Output C code
|
||||
func = tvm.build(s, [A, B, C], target=target, name="mmult")
|
||||
print(func.get_source())
|
||||
except ImportError:
|
||||
print("** please install TVM for TVM output")
|
||||
print("** please install TVM for TVM output")
|
||||
|
||||
# tinygrad version
|
||||
|
||||
|
@ -34,14 +35,18 @@ from tinygrad.tensor import Tensor
|
|||
# define the compute
|
||||
A = Tensor.rand(M, K, device="clang")
|
||||
B = Tensor.rand(K, N, device="clang")
|
||||
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||
C = (A.reshape(M, 1, K) * B.permute(1, 0).reshape(1, N, K)).sum(axis=2)
|
||||
|
||||
sched = C.lazydata.schedule()
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
|
||||
#lin.hand_coded_optimizations()
|
||||
|
||||
lin = Linearizer(
|
||||
sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False)
|
||||
)
|
||||
# lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
from tinygrad.runtime.ops_clang import renderer
|
||||
|
||||
src = renderer("mmult", lin.uops)
|
||||
print(src)
|
||||
|
|
|
@ -1,50 +1,58 @@
|
|||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def mask_like(like, mask_inx, mask_value = 1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
mask[mask_inx] = mask_value
|
||||
return mask.reshape(like.shape)
|
||||
|
||||
def mask_like(like, mask_inx, mask_value=1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
mask[mask_inx] = mask_value
|
||||
return mask.reshape(like.shape)
|
||||
|
||||
|
||||
def jacobian(func, input):
|
||||
output = func(input)
|
||||
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
J = np.zeros((jo,ji), dtype=np.float32)
|
||||
|
||||
for o in range(jo):
|
||||
input.grad = None
|
||||
output = func(input)
|
||||
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
J = np.zeros((jo, ji), dtype=np.float32)
|
||||
|
||||
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
||||
J[o,i] = grad
|
||||
return J
|
||||
for o in range(jo):
|
||||
input.grad = None
|
||||
output = func(input)
|
||||
|
||||
def numerical_jacobian(func, input, eps = 1e-3):
|
||||
output = func(input)
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.numpy(), o, 1.0)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
NJ = np.zeros((jo, ji), dtype=np.float32)
|
||||
for i, grad in enumerate(input.grad.numpy().reshape(-1)):
|
||||
J[o, i] = grad
|
||||
return J
|
||||
|
||||
for i in range(ji):
|
||||
eps_perturb = mask_like(input.numpy(), i, mask_value = eps)
|
||||
|
||||
output_perturb_add = func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||
output_perturb_sub = func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||
def numerical_jacobian(func, input, eps=1e-3):
|
||||
output = func(input)
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
|
||||
ji = input.numpy().reshape(-1).shape[-1]
|
||||
jo = output.numpy().reshape(-1).shape[-1]
|
||||
NJ = np.zeros((jo, ji), dtype=np.float32)
|
||||
|
||||
NJ[:,i] = grad_approx
|
||||
return NJ
|
||||
for i in range(ji):
|
||||
eps_perturb = mask_like(input.numpy(), i, mask_value=eps)
|
||||
|
||||
def gradcheck(func, input, eps = 1e-3, atol = 1e-3, rtol = 1e-3):
|
||||
NJ = numerical_jacobian(func, input, eps)
|
||||
J = jacobian(func, input)
|
||||
return np.allclose(J, NJ, atol = atol, rtol = rtol)
|
||||
output_perturb_add = (
|
||||
func(Tensor(input.numpy() + eps_perturb)).numpy().reshape(-1)
|
||||
)
|
||||
output_perturb_sub = (
|
||||
func(Tensor(input.numpy() - eps_perturb)).numpy().reshape(-1)
|
||||
)
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2 * eps)
|
||||
|
||||
NJ[:, i] = grad_approx
|
||||
return NJ
|
||||
|
||||
|
||||
def gradcheck(func, input, eps=1e-3, atol=1e-3, rtol=1e-3):
|
||||
NJ = numerical_jacobian(func, input, eps)
|
||||
J = jacobian(func, input)
|
||||
return np.allclose(J, NJ, atol=atol, rtol=rtol)
|
||||
|
|
|
@ -2,49 +2,71 @@ import multiprocessing, subprocess
|
|||
import cloudpickle
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _early_exec_process(qin, qout):
|
||||
while True:
|
||||
path, inp = qin.get()
|
||||
try:
|
||||
qout.put(subprocess.check_output(path, input=inp))
|
||||
except Exception as e:
|
||||
qout.put(e)
|
||||
while True:
|
||||
path, inp = qin.get()
|
||||
try:
|
||||
qout.put(subprocess.check_output(path, input=inp))
|
||||
except Exception as e:
|
||||
qout.put(e)
|
||||
|
||||
|
||||
def enable_early_exec():
|
||||
qin: multiprocessing.Queue = multiprocessing.Queue()
|
||||
qout: multiprocessing.Queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
def early_exec(x):
|
||||
qin.put(x)
|
||||
ret = qout.get()
|
||||
if isinstance(ret, Exception): raise ret
|
||||
else: return ret
|
||||
return early_exec
|
||||
qin: multiprocessing.Queue = multiprocessing.Queue()
|
||||
qout: multiprocessing.Queue = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=_early_exec_process, args=(qin, qout))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
|
||||
def early_exec(x):
|
||||
qin.put(x)
|
||||
ret = qout.get()
|
||||
if isinstance(ret, Exception):
|
||||
raise ret
|
||||
else:
|
||||
return ret
|
||||
|
||||
return early_exec
|
||||
|
||||
|
||||
def proc(itermaker, q) -> None:
|
||||
try:
|
||||
for x in itermaker(): q.put(x)
|
||||
except Exception as e:
|
||||
q.put(e)
|
||||
finally:
|
||||
q.put(None)
|
||||
q.close()
|
||||
try:
|
||||
for x in itermaker():
|
||||
q.put(x)
|
||||
except Exception as e:
|
||||
q.put(e)
|
||||
finally:
|
||||
q.put(None)
|
||||
q.close()
|
||||
|
||||
|
||||
class _CloudpickleFunctionWrapper:
|
||||
def __init__(self, fn): self.fn = fn
|
||||
def __getstate__(self): return cloudpickle.dumps(self.fn)
|
||||
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
|
||||
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, pfn):
|
||||
self.fn = cloudpickle.loads(pfn)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self.fn(*args, **kwargs)
|
||||
|
||||
|
||||
def cross_process(itermaker, maxsize=16):
|
||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
||||
p.start()
|
||||
while True:
|
||||
ret = q.get()
|
||||
if isinstance(ret, Exception): raise ret
|
||||
elif ret is None: break
|
||||
else: yield ret
|
||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||
p = multiprocessing.Process(
|
||||
target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q)
|
||||
)
|
||||
p.start()
|
||||
while True:
|
||||
ret = q.get()
|
||||
if isinstance(ret, Exception):
|
||||
raise ret
|
||||
elif ret is None:
|
||||
break
|
||||
else:
|
||||
yield ret
|
||||
|
|
|
@ -6,37 +6,45 @@ from tinygrad.lazy import LazyBuffer
|
|||
from tinygrad.runtime.ops_gpu import CLBuffer
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
|
||||
|
||||
def print_objects():
|
||||
#gc.collect()
|
||||
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
|
||||
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
|
||||
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
|
||||
gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
|
||||
realized_buffers = [x.realized for x in lazybuffers if x.realized]
|
||||
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
|
||||
# gc.collect()
|
||||
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
|
||||
tensor_ram_used = sum([prod(x.shape) * 4 for x in tensors])
|
||||
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
|
||||
gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
|
||||
realized_buffers = [x.realized for x in lazybuffers if x.realized]
|
||||
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]
|
||||
|
||||
print(f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB")
|
||||
print(f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers")
|
||||
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
|
||||
print(
|
||||
f"{len(tensors)} tensors allocated in {tensor_ram_used/1e9:.2f} GB, GPU using {GlobalCounters.mem_used/1e9:.2f} GB"
|
||||
)
|
||||
print(
|
||||
f"{len(lazybuffers)} lazybuffers {len(realized_buffers)} realized, {len(gpubuffers)} GPU buffers"
|
||||
)
|
||||
print(f"{len(gpubuffers_orphaned)} GPU buffers are orphaned")
|
||||
|
||||
cnt = 0
|
||||
for tb in gpubuffers_orphaned:
|
||||
bb = gc.get_referrers(tb)
|
||||
for b in bb:
|
||||
if b is not gpubuffers and b is not gpubuffers_orphaned:
|
||||
print(tb, "\nreference", type(b), len(b), str(b)[0:150])
|
||||
for x in gc.get_referrers(b):
|
||||
print("double reference", str(x)[0:100])
|
||||
print("\n")
|
||||
if cnt == 10:
|
||||
break
|
||||
cnt += 1
|
||||
cnt = 0
|
||||
for tb in gpubuffers_orphaned:
|
||||
bb = gc.get_referrers(tb)
|
||||
for b in bb:
|
||||
if b is not gpubuffers and b is not gpubuffers_orphaned:
|
||||
print(tb, "\nreference", type(b), len(b), str(b)[0:150])
|
||||
for x in gc.get_referrers(b):
|
||||
print("double reference", str(x)[0:100])
|
||||
print("\n")
|
||||
if cnt == 10:
|
||||
break
|
||||
cnt += 1
|
||||
|
||||
for x in gpubuffers_orphaned:
|
||||
if getattr(x, '_buf', None): del x._buf
|
||||
if getattr(x, '_image', None): del x._image
|
||||
for x in gpubuffers_orphaned:
|
||||
if getattr(x, "_buf", None):
|
||||
del x._buf
|
||||
if getattr(x, "_image", None):
|
||||
del x._image
|
||||
|
||||
return len(gpubuffers_orphaned)
|
||||
|
||||
return len(gpubuffers_orphaned)
|
||||
|
||||
"""
|
||||
import gc
|
||||
|
|
|
@ -7,39 +7,44 @@ from google.protobuf import descriptor as _descriptor
|
|||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
|
||||
b'\n\x19sentencepiece_model.proto\x12\rsentencepiece"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03'
|
||||
)
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sentencepiece_model_pb2", _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
_globals['DESCRIPTOR']._options = None
|
||||
_globals['DESCRIPTOR']._serialized_options = b'H\003'
|
||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._options = None
|
||||
_globals['_TRAINERSPEC'].fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
|
||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._options = None
|
||||
_globals['_TRAINERSPEC'].fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
|
||||
_globals['_TRAINERSPEC']._serialized_start=45
|
||||
_globals['_TRAINERSPEC']._serialized_end=1581
|
||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_start=1517
|
||||
_globals['_TRAINERSPEC_MODELTYPE']._serialized_end=1570
|
||||
_globals['_NORMALIZERSPEC']._serialized_start=1584
|
||||
_globals['_NORMALIZERSPEC']._serialized_end=1793
|
||||
_globals['_SELFTESTDATA']._serialized_start=1795
|
||||
_globals['_SELFTESTDATA']._serialized_end=1916
|
||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_start=1864
|
||||
_globals['_SELFTESTDATA_SAMPLE']._serialized_end=1905
|
||||
_globals['_MODELPROTO']._serialized_start=1919
|
||||
_globals['_MODELPROTO']._serialized_end=2429
|
||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_start=2208
|
||||
_globals['_MODELPROTO_SENTENCEPIECE']._serialized_end=2418
|
||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_start=2323
|
||||
_globals['_MODELPROTO_SENTENCEPIECE_TYPE']._serialized_end=2407
|
||||
_globals["DESCRIPTOR"]._options = None
|
||||
_globals["DESCRIPTOR"]._serialized_options = b"H\003"
|
||||
_globals["_TRAINERSPEC"].fields_by_name["mining_sentence_size"]._options = None
|
||||
_globals["_TRAINERSPEC"].fields_by_name[
|
||||
"mining_sentence_size"
|
||||
]._serialized_options = b"\030\001"
|
||||
_globals["_TRAINERSPEC"].fields_by_name["training_sentence_size"]._options = None
|
||||
_globals["_TRAINERSPEC"].fields_by_name[
|
||||
"training_sentence_size"
|
||||
]._serialized_options = b"\030\001"
|
||||
_globals["_TRAINERSPEC"]._serialized_start = 45
|
||||
_globals["_TRAINERSPEC"]._serialized_end = 1581
|
||||
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_start = 1517
|
||||
_globals["_TRAINERSPEC_MODELTYPE"]._serialized_end = 1570
|
||||
_globals["_NORMALIZERSPEC"]._serialized_start = 1584
|
||||
_globals["_NORMALIZERSPEC"]._serialized_end = 1793
|
||||
_globals["_SELFTESTDATA"]._serialized_start = 1795
|
||||
_globals["_SELFTESTDATA"]._serialized_end = 1916
|
||||
_globals["_SELFTESTDATA_SAMPLE"]._serialized_start = 1864
|
||||
_globals["_SELFTESTDATA_SAMPLE"]._serialized_end = 1905
|
||||
_globals["_MODELPROTO"]._serialized_start = 1919
|
||||
_globals["_MODELPROTO"]._serialized_end = 2429
|
||||
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_start = 2208
|
||||
_globals["_MODELPROTO_SENTENCEPIECE"]._serialized_end = 2418
|
||||
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_start = 2323
|
||||
_globals["_MODELPROTO_SENTENCEPIECE_TYPE"]._serialized_end = 2407
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
|
|
@ -3,84 +3,138 @@ from typing import List
|
|||
from tinygrad.nn.optim import Optimizer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
class LR_Scheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor(
|
||||
[0], requires_grad=False, device=self.optimizer.device
|
||||
)
|
||||
|
||||
def get_lr(self): pass
|
||||
def get_lr(self):
|
||||
pass
|
||||
|
||||
def step(self) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
self.optimizer.lr.assign(self.get_lr()).realize()
|
||||
|
||||
def step(self) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
self.optimizer.lr.assign(self.get_lr()).realize()
|
||||
|
||||
class MultiStepLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
||||
super().__init__(optimizer)
|
||||
self.milestones = milestones
|
||||
self.gamma = gamma
|
||||
def __init__(self, optimizer: Optimizer, milestones: List[int], gamma=0.1):
|
||||
super().__init__(optimizer)
|
||||
self.milestones = milestones
|
||||
self.gamma = gamma
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
if self.epoch_counter.numpy()[0] not in self.milestones:
|
||||
return self.optimizer.lr
|
||||
return self.optimizer.lr * self.gamma
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
if self.epoch_counter.numpy()[0] not in self.milestones:
|
||||
return self.optimizer.lr
|
||||
return self.optimizer.lr * self.gamma
|
||||
|
||||
class ReduceLROnPlateau(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, mode="min", factor=0.1, patience=10, threshold=1e-4, threshold_mode="rel"):
|
||||
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
|
||||
super().__init__(optimizer)
|
||||
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = mode, factor, patience, threshold, threshold_mode
|
||||
self.best = float('inf') if mode == "min" else float('-inf')
|
||||
self.bad_epoch = 0
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
mode="min",
|
||||
factor=0.1,
|
||||
patience=10,
|
||||
threshold=1e-4,
|
||||
threshold_mode="rel",
|
||||
):
|
||||
assert mode in ["min", "max"] and threshold_mode in ["rel", "abs"]
|
||||
super().__init__(optimizer)
|
||||
self.mode, self.factor, self.patience, self.threshold, self.threshold_mode = (
|
||||
mode,
|
||||
factor,
|
||||
patience,
|
||||
threshold,
|
||||
threshold_mode,
|
||||
)
|
||||
self.best = float("inf") if mode == "min" else float("-inf")
|
||||
self.bad_epoch = 0
|
||||
|
||||
if mode == "min": self.threshold *= -1
|
||||
if mode == "min":
|
||||
self.threshold *= -1
|
||||
|
||||
def is_better(self, current: float) -> bool:
|
||||
dynamic_threshold = self.best*(1+self.threshold) if self.threshold_mode == "rel" else self.best+self.threshold
|
||||
if self.mode == "min":
|
||||
return current < dynamic_threshold
|
||||
return current > dynamic_threshold
|
||||
def is_better(self, current: float) -> bool:
|
||||
dynamic_threshold = (
|
||||
self.best * (1 + self.threshold)
|
||||
if self.threshold_mode == "rel"
|
||||
else self.best + self.threshold
|
||||
)
|
||||
if self.mode == "min":
|
||||
return current < dynamic_threshold
|
||||
return current > dynamic_threshold
|
||||
|
||||
def step(self, current: float) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
if self.is_better(current):
|
||||
self.bad_epoch = 0
|
||||
self.best = current
|
||||
else:
|
||||
self.bad_epoch += 1
|
||||
def step(self, current: float) -> None:
|
||||
self.epoch_counter.assign(self.epoch_counter + 1).realize()
|
||||
if self.is_better(current):
|
||||
self.bad_epoch = 0
|
||||
self.best = current
|
||||
else:
|
||||
self.bad_epoch += 1
|
||||
|
||||
if self.bad_epoch > self.patience:
|
||||
self.optimizer.lr *= self.factor
|
||||
self.bad_epoch = 0
|
||||
|
||||
if self.bad_epoch > self.patience:
|
||||
self.optimizer.lr *= self.factor
|
||||
self.bad_epoch = 0
|
||||
|
||||
class CosineAnnealingLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
|
||||
super().__init__(optimizer)
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
self.eta_max = optimizer.lr.numpy()[0]
|
||||
def __init__(self, optimizer: Optimizer, T_max: int, eta_min=0):
|
||||
super().__init__(optimizer)
|
||||
self.T_max = T_max
|
||||
self.eta_min = eta_min
|
||||
self.eta_max = optimizer.lr.numpy()[0]
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return Tensor(
|
||||
[
|
||||
self.eta_min
|
||||
+ 0.5
|
||||
* (self.eta_max - self.eta_min)
|
||||
* (1 + math.cos((self.epoch_counter.numpy()[0] / self.T_max) * math.pi))
|
||||
],
|
||||
device=self.optimizer.device,
|
||||
)
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
|
||||
|
||||
class OneCycleLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
|
||||
anneal_strategy: str = 'linear', cycle_momentum: bool = False):
|
||||
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
|
||||
self.max_lr = Tensor([max_lr]).contiguous()
|
||||
self.min_lr = self.initial_lr/final_div_factor
|
||||
super().__init__(optimizer)
|
||||
self.total_steps = total_steps
|
||||
self.pct_start = pct_start
|
||||
assert anneal_strategy == 'linear', 'only linear annealing supported'
|
||||
assert not cycle_momentum, 'cycle momentum not supported'
|
||||
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
max_lr: float,
|
||||
div_factor: float,
|
||||
final_div_factor: float,
|
||||
total_steps: int,
|
||||
pct_start: float,
|
||||
anneal_strategy: str = "linear",
|
||||
cycle_momentum: bool = False,
|
||||
):
|
||||
self.initial_lr = Tensor([max_lr / div_factor]).contiguous()
|
||||
self.max_lr = Tensor([max_lr]).contiguous()
|
||||
self.min_lr = self.initial_lr / final_div_factor
|
||||
super().__init__(optimizer)
|
||||
self.total_steps = total_steps
|
||||
self.pct_start = pct_start
|
||||
assert anneal_strategy == "linear", "only linear annealing supported"
|
||||
assert not cycle_momentum, "cycle momentum not supported"
|
||||
self.optimizer.lr.assign(self.get_lr()).realize() # update the initial LR
|
||||
|
||||
@staticmethod
|
||||
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor: return ((end - start) * pct + start)
|
||||
@staticmethod
|
||||
def _annealing_linear(start: Tensor, end: Tensor, pct: Tensor) -> Tensor:
|
||||
return (end - start) * pct + start
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return (self.epoch_counter < self.total_steps*self.pct_start).where(
|
||||
self._annealing_linear(self.initial_lr, self.max_lr, self.epoch_counter/(self.total_steps*self.pct_start)),
|
||||
self._annealing_linear(self.max_lr, self.min_lr, (self.epoch_counter-(self.total_steps*self.pct_start))/(self.total_steps*(1-self.pct_start)))
|
||||
)
|
||||
def get_lr(self) -> Tensor:
|
||||
return (self.epoch_counter < self.total_steps * self.pct_start).where(
|
||||
self._annealing_linear(
|
||||
self.initial_lr,
|
||||
self.max_lr,
|
||||
self.epoch_counter / (self.total_steps * self.pct_start),
|
||||
),
|
||||
self._annealing_linear(
|
||||
self.max_lr,
|
||||
self.min_lr,
|
||||
(self.epoch_counter - (self.total_steps * self.pct_start))
|
||||
/ (self.total_steps * (1 - self.pct_start)),
|
||||
),
|
||||
)
|
||||
|
|
|
@ -5,167 +5,290 @@ from pathlib import Path
|
|||
|
||||
|
||||
class BertForQuestionAnswering:
|
||||
def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
|
||||
self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
self.qa_outputs = Linear(hidden_size, 2)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4096,
|
||||
max_position_embeddings=512,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=24,
|
||||
type_vocab_size=2,
|
||||
vocab_size=30522,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
hidden_dropout_prob=0.1,
|
||||
):
|
||||
self.bert = Bert(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
max_position_embeddings,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.qa_outputs = Linear(hidden_size, 2)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt"
|
||||
fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
|
||||
fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt"
|
||||
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt"
|
||||
fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
|
||||
fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt"
|
||||
fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)
|
||||
|
||||
import torch
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
import torch
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if "dropout" in k: continue # skip dropout
|
||||
if "pooler" in k: continue # skip pooler
|
||||
get_child(self, k).assign(v.numpy()).realize()
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
|
||||
def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
|
||||
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.chunk(2, dim=-1)
|
||||
start_logits = start_logits.reshape(-1, 1)
|
||||
end_logits = end_logits.reshape(-1, 1)
|
||||
for k, v in state_dict.items():
|
||||
if "dropout" in k:
|
||||
continue # skip dropout
|
||||
if "pooler" in k:
|
||||
continue # skip pooler
|
||||
get_child(self, k).assign(v.numpy()).realize()
|
||||
|
||||
def __call__(
|
||||
self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor
|
||||
):
|
||||
sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.chunk(2, dim=-1)
|
||||
start_logits = start_logits.reshape(-1, 1)
|
||||
end_logits = end_logits.reshape(-1, 1)
|
||||
|
||||
return Tensor.stack([start_logits, end_logits])
|
||||
|
||||
return Tensor.stack([start_logits, end_logits])
|
||||
|
||||
class Bert:
|
||||
def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
|
||||
self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
max_position_embeddings,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.embeddings = BertEmbeddings(
|
||||
hidden_size,
|
||||
max_position_embeddings,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.encoder = BertEncoder(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids):
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids):
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask)
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
class BertEmbeddings:
|
||||
def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob):
|
||||
self.word_embeddings = Embedding(vocab_size, hidden_size)
|
||||
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
|
||||
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
max_position_embeddings,
|
||||
type_vocab_size,
|
||||
vocab_size,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.word_embeddings = Embedding(vocab_size, hidden_size)
|
||||
self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
|
||||
self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
|
||||
def __call__(self, input_ids, token_type_ids):
|
||||
input_shape = input_ids.shape
|
||||
seq_length = input_shape[1]
|
||||
def __call__(self, input_ids, token_type_ids):
|
||||
input_shape = input_ids.shape
|
||||
seq_length = input_shape[1]
|
||||
|
||||
position_ids = Tensor.arange(seq_length, requires_grad=False).unsqueeze(0).expand(*input_shape)
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
position_ids = (
|
||||
Tensor.arange(seq_length, requires_grad=False)
|
||||
.unsqueeze(0)
|
||||
.expand(*input_shape)
|
||||
)
|
||||
words_embeddings = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = embeddings.dropout(self.dropout)
|
||||
return embeddings
|
||||
|
||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = embeddings.dropout(self.dropout)
|
||||
return embeddings
|
||||
|
||||
class BertEncoder:
|
||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.layer = [BertLayer(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob) for _ in range(num_hidden_layers)]
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
num_hidden_layers,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.layer = [
|
||||
BertLayer(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
for layer in self.layer:
|
||||
hidden_states = layer(hidden_states, attention_mask)
|
||||
return hidden_states
|
||||
|
||||
class BertLayer:
|
||||
def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
|
||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
|
||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.attention = BertAttention(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
)
|
||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size)
|
||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
class BertOutput:
|
||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
|
||||
self.dense = Linear(intermediate_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
|
||||
self.dense = Linear(intermediate_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
|
||||
def __call__(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = hidden_states.dropout(self.dropout)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
def __call__(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = hidden_states.dropout(self.dropout)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
# approximation of the error function
|
||||
def erf(x):
|
||||
t = (1 + 0.3275911 * x.abs()).reciprocal()
|
||||
return x.sign() * (1 - ((((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736) * t + 0.254829592) * t * (-(x.square())).exp())
|
||||
t = (1 + 0.3275911 * x.abs()).reciprocal()
|
||||
return x.sign() * (
|
||||
1
|
||||
- (
|
||||
(((1.061405429 * t + -1.453152027) * t + 1.421413741) * t + -0.284496736)
|
||||
* t
|
||||
+ 0.254829592
|
||||
)
|
||||
* t
|
||||
* (-(x.square())).exp()
|
||||
)
|
||||
|
||||
|
||||
class BertIntermediate:
|
||||
def __init__(self, hidden_size, intermediate_size):
|
||||
self.dense = Linear(hidden_size, intermediate_size)
|
||||
def __init__(self, hidden_size, intermediate_size):
|
||||
self.dense = Linear(hidden_size, intermediate_size)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
x = self.dense(hidden_states)
|
||||
# tinygrad gelu is openai gelu but we need the original bert gelu
|
||||
return x * 0.5 * (1.0 + erf(x / 1.41421))
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
x = self.dense(hidden_states)
|
||||
# tinygrad gelu is openai gelu but we need the original bert gelu
|
||||
return x * 0.5 * (1.0 + erf(x / 1.41421))
|
||||
|
||||
class BertAttention:
|
||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
|
||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
|
||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_probs_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
):
|
||||
self.self = BertSelfAttention(
|
||||
hidden_size, num_attention_heads, attention_probs_dropout_prob
|
||||
)
|
||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
self_output = self.self(hidden_states, attention_mask)
|
||||
attention_output = self.output(self_output, hidden_states)
|
||||
return attention_output
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
self_output = self.self(hidden_states, attention_mask)
|
||||
attention_output = self.output(self_output, hidden_states)
|
||||
return attention_output
|
||||
|
||||
class BertSelfAttention:
|
||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_size = int(hidden_size / num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_size = int(hidden_size / num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query = Linear(hidden_size, self.all_head_size)
|
||||
self.key = Linear(hidden_size, self.all_head_size)
|
||||
self.value = Linear(hidden_size, self.all_head_size)
|
||||
self.query = Linear(hidden_size, self.all_head_size)
|
||||
self.key = Linear(hidden_size, self.all_head_size)
|
||||
self.value = Linear(hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = attention_probs_dropout_prob
|
||||
self.dropout = attention_probs_dropout_prob
|
||||
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
def __call__(self, hidden_states, attention_mask):
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)
|
||||
context_layer = Tensor.scaled_dot_product_attention(
|
||||
query_layer, key_layer, value_layer, attention_mask, self.dropout
|
||||
)
|
||||
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
context_layer = context_layer.reshape(
|
||||
context_layer.shape[0], context_layer.shape[1], self.all_head_size
|
||||
)
|
||||
|
||||
return context_layer
|
||||
return context_layer
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
x = x.reshape(
|
||||
x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size
|
||||
)
|
||||
return x.transpose(1, 2)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
|
||||
return x.transpose(1, 2)
|
||||
|
||||
class BertSelfOutput:
|
||||
def __init__(self, hidden_size, hidden_dropout_prob):
|
||||
self.dense = Linear(hidden_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
def __init__(self, hidden_size, hidden_dropout_prob):
|
||||
self.dense = Linear(hidden_size, hidden_size)
|
||||
self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
|
||||
self.dropout = hidden_dropout_prob
|
||||
|
||||
def __call__(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = hidden_states.dropout(self.dropout)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
def __call__(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = hidden_states.dropout(self.dropout)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
|
|
@ -2,64 +2,99 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||
from tinygrad.helpers import fetch, get_child
|
||||
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
|
||||
self.norm = LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = Linear(dim, 4 * dim)
|
||||
self.pwconv2 = Linear(4 * dim, dim)
|
||||
self.gamma = Tensor.ones(dim)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
return x + x.sequential([
|
||||
self.dwconv, lambda x: x.permute(0, 2, 3, 1), self.norm,
|
||||
self.pwconv1, Tensor.gelu, self.pwconv2, lambda x: (self.gamma * x).permute(0, 3, 1, 2)
|
||||
])
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
|
||||
self.norm = LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = Linear(dim, 4 * dim)
|
||||
self.pwconv2 = Linear(4 * dim, dim)
|
||||
self.gamma = Tensor.ones(dim)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
return x + x.sequential(
|
||||
[
|
||||
self.dwconv,
|
||||
lambda x: x.permute(0, 2, 3, 1),
|
||||
self.norm,
|
||||
self.pwconv1,
|
||||
Tensor.gelu,
|
||||
self.pwconv2,
|
||||
lambda x: (self.gamma * x).permute(0, 3, 1, 2),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConvNeXt:
|
||||
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
|
||||
self.downsample_layers = [
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
|
||||
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
]
|
||||
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
|
||||
self.norm = LayerNorm(dims[-1])
|
||||
self.head = Linear(dims[-1], num_classes)
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
depths=[3, 3, 9, 3],
|
||||
dims=[96, 192, 384, 768],
|
||||
):
|
||||
self.downsample_layers = [
|
||||
[
|
||||
Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
||||
LayerNorm2d(dims[0], eps=1e-6),
|
||||
],
|
||||
*[
|
||||
[
|
||||
LayerNorm2d(dims[i], eps=1e-6),
|
||||
Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
||||
]
|
||||
for i in range(len(dims) - 1)
|
||||
],
|
||||
]
|
||||
self.stages = [
|
||||
[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))
|
||||
]
|
||||
self.norm = LayerNorm(dims[-1])
|
||||
self.head = Linear(dims[-1], num_classes)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
for downsample, stage in zip(self.downsample_layers, self.stages):
|
||||
x = x.sequential(downsample).sequential(stage)
|
||||
return x.mean([-2, -1]).sequential([self.norm, self.head])
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
for downsample, stage in zip(self.downsample_layers, self.stages):
|
||||
x = x.sequential(downsample).sequential(stage)
|
||||
return x.mean([-2, -1]).sequential([self.norm, self.head])
|
||||
|
||||
# *** model definition is done ***
|
||||
|
||||
versions = {
|
||||
"tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]},
|
||||
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
|
||||
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
|
||||
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
|
||||
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]}
|
||||
"tiny": {"depths": [3, 3, 9, 3], "dims": [96, 192, 384, 768]},
|
||||
"small": {"depths": [3, 3, 27, 3], "dims": [96, 192, 384, 768]},
|
||||
"base": {"depths": [3, 3, 9, 3], "dims": [128, 256, 512, 1024]},
|
||||
"large": {"depths": [3, 3, 27, 3], "dims": [192, 384, 768, 1536]},
|
||||
"xlarge": {"depths": [3, 3, 27, 3], "dims": [256, 512, 1024, 2048]},
|
||||
}
|
||||
|
||||
|
||||
def get_model(version, load_weights=False):
|
||||
model = ConvNeXt(**versions[version])
|
||||
if load_weights:
|
||||
from tinygrad.nn.state import torch_load
|
||||
weights = torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model']
|
||||
for k,v in weights.items():
|
||||
mv = get_child(model, k)
|
||||
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
||||
return model
|
||||
model = ConvNeXt(**versions[version])
|
||||
if load_weights:
|
||||
from tinygrad.nn.state import torch_load
|
||||
|
||||
weights = torch_load(
|
||||
fetch(
|
||||
f"https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth"
|
||||
)
|
||||
)["model"]
|
||||
for k, v in weights.items():
|
||||
mv = get_child(model, k)
|
||||
mv.assign(v.reshape(mv.shape).to(mv.device)).realize()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = get_model("tiny", True)
|
||||
model = get_model("tiny", True)
|
||||
|
||||
# load image
|
||||
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
|
||||
img = Tensor(preprocess(chicken_img))
|
||||
# load image
|
||||
from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
|
||||
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
img = Tensor(preprocess(chicken_img))
|
||||
|
||||
out = model(img).numpy()
|
||||
print(_LABELS[out.argmax()])
|
||||
Tensor.training = False
|
||||
Tensor.no_grad = True
|
||||
|
||||
out = model(img).numpy()
|
||||
print(_LABELS[out.argmax()])
|
||||
|
|
|
@ -4,161 +4,218 @@ from tinygrad.nn import BatchNorm2d
|
|||
from tinygrad.helpers import get_child, fetch
|
||||
from tinygrad.nn.state import torch_load
|
||||
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
|
||||
oup = expand_ratio * input_filters
|
||||
if expand_ratio != 1:
|
||||
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
|
||||
self._bn0 = BatchNorm2d(oup, track_running_stats=track_running_stats)
|
||||
else:
|
||||
self._expand_conv = None
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
has_se,
|
||||
track_running_stats=True,
|
||||
):
|
||||
oup = expand_ratio * input_filters
|
||||
if expand_ratio != 1:
|
||||
self._expand_conv = Tensor.glorot_uniform(oup, input_filters, 1, 1)
|
||||
self._bn0 = BatchNorm2d(oup, track_running_stats=track_running_stats)
|
||||
else:
|
||||
self._expand_conv = None
|
||||
|
||||
self.strides = strides
|
||||
if strides == (2,2):
|
||||
self.pad = [(kernel_size-1)//2-1, (kernel_size-1)//2]*2
|
||||
else:
|
||||
self.pad = [(kernel_size-1)//2]*4
|
||||
self.strides = strides
|
||||
if strides == (2, 2):
|
||||
self.pad = [(kernel_size - 1) // 2 - 1, (kernel_size - 1) // 2] * 2
|
||||
else:
|
||||
self.pad = [(kernel_size - 1) // 2] * 4
|
||||
|
||||
self._depthwise_conv = Tensor.glorot_uniform(oup, 1, kernel_size, kernel_size)
|
||||
self._bn1 = BatchNorm2d(oup, track_running_stats=track_running_stats)
|
||||
self._depthwise_conv = Tensor.glorot_uniform(oup, 1, kernel_size, kernel_size)
|
||||
self._bn1 = BatchNorm2d(oup, track_running_stats=track_running_stats)
|
||||
|
||||
self.has_se = has_se
|
||||
if self.has_se:
|
||||
num_squeezed_channels = max(1, int(input_filters * se_ratio))
|
||||
self._se_reduce = Tensor.glorot_uniform(num_squeezed_channels, oup, 1, 1)
|
||||
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
|
||||
self._se_expand = Tensor.glorot_uniform(oup, num_squeezed_channels, 1, 1)
|
||||
self._se_expand_bias = Tensor.zeros(oup)
|
||||
self.has_se = has_se
|
||||
if self.has_se:
|
||||
num_squeezed_channels = max(1, int(input_filters * se_ratio))
|
||||
self._se_reduce = Tensor.glorot_uniform(num_squeezed_channels, oup, 1, 1)
|
||||
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
|
||||
self._se_expand = Tensor.glorot_uniform(oup, num_squeezed_channels, 1, 1)
|
||||
self._se_expand_bias = Tensor.zeros(oup)
|
||||
|
||||
self._project_conv = Tensor.glorot_uniform(output_filters, oup, 1, 1)
|
||||
self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats)
|
||||
self._project_conv = Tensor.glorot_uniform(output_filters, oup, 1, 1)
|
||||
self._bn2 = BatchNorm2d(output_filters, track_running_stats=track_running_stats)
|
||||
|
||||
def __call__(self, inputs):
|
||||
x = inputs
|
||||
if self._expand_conv:
|
||||
x = self._bn0(x.conv2d(self._expand_conv)).swish()
|
||||
x = x.conv2d(self._depthwise_conv, padding=self.pad, stride=self.strides, groups=self._depthwise_conv.shape[0])
|
||||
x = self._bn1(x).swish()
|
||||
def __call__(self, inputs):
|
||||
x = inputs
|
||||
if self._expand_conv:
|
||||
x = self._bn0(x.conv2d(self._expand_conv)).swish()
|
||||
x = x.conv2d(
|
||||
self._depthwise_conv,
|
||||
padding=self.pad,
|
||||
stride=self.strides,
|
||||
groups=self._depthwise_conv.shape[0],
|
||||
)
|
||||
x = self._bn1(x).swish()
|
||||
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(
|
||||
self._se_reduce, self._se_reduce_bias
|
||||
).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
if x.shape == inputs.shape:
|
||||
x = x.add(inputs)
|
||||
return x
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
if x.shape == inputs.shape:
|
||||
x = x.add(inputs)
|
||||
return x
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self, number=0, classes=1000, has_se=True, track_running_stats=True, input_channels=3, has_fc_output=True):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
(1.0, 1.0), # b0
|
||||
(1.0, 1.1), # b1
|
||||
(1.1, 1.2), # b2
|
||||
(1.2, 1.4), # b3
|
||||
(1.4, 1.8), # b4
|
||||
(1.6, 2.2), # b5
|
||||
(1.8, 2.6), # b6
|
||||
(2.0, 3.1), # b7
|
||||
(2.2, 3.6), # b8
|
||||
(4.3, 5.3), # l2
|
||||
][max(number,0)]
|
||||
def __init__(
|
||||
self,
|
||||
number=0,
|
||||
classes=1000,
|
||||
has_se=True,
|
||||
track_running_stats=True,
|
||||
input_channels=3,
|
||||
has_fc_output=True,
|
||||
):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
(1.0, 1.0), # b0
|
||||
(1.0, 1.1), # b1
|
||||
(1.1, 1.2), # b2
|
||||
(1.2, 1.4), # b3
|
||||
(1.4, 1.8), # b4
|
||||
(1.6, 2.2), # b5
|
||||
(1.8, 2.6), # b6
|
||||
(2.0, 3.1), # b7
|
||||
(2.2, 3.6), # b8
|
||||
(4.3, 5.3), # l2
|
||||
][max(number, 0)]
|
||||
|
||||
def round_filters(filters):
|
||||
multiplier = global_params[0]
|
||||
divisor = 8
|
||||
filters *= multiplier
|
||||
new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
|
||||
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
||||
new_filters += divisor
|
||||
return int(new_filters)
|
||||
def round_filters(filters):
|
||||
multiplier = global_params[0]
|
||||
divisor = 8
|
||||
filters *= multiplier
|
||||
new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
|
||||
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
|
||||
new_filters += divisor
|
||||
return int(new_filters)
|
||||
|
||||
def round_repeats(repeats):
|
||||
return int(math.ceil(global_params[1] * repeats))
|
||||
def round_repeats(repeats):
|
||||
return int(math.ceil(global_params[1] * repeats))
|
||||
|
||||
out_channels = round_filters(32)
|
||||
self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3)
|
||||
self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
|
||||
blocks_args = [
|
||||
[1, 3, (1,1), 1, 32, 16, 0.25],
|
||||
[2, 3, (2,2), 6, 16, 24, 0.25],
|
||||
[2, 5, (2,2), 6, 24, 40, 0.25],
|
||||
[3, 3, (2,2), 6, 40, 80, 0.25],
|
||||
[3, 5, (1,1), 6, 80, 112, 0.25],
|
||||
[4, 5, (2,2), 6, 112, 192, 0.25],
|
||||
[1, 3, (1,1), 6, 192, 320, 0.25],
|
||||
]
|
||||
out_channels = round_filters(32)
|
||||
self._conv_stem = Tensor.glorot_uniform(out_channels, input_channels, 3, 3)
|
||||
self._bn0 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
|
||||
blocks_args = [
|
||||
[1, 3, (1, 1), 1, 32, 16, 0.25],
|
||||
[2, 3, (2, 2), 6, 16, 24, 0.25],
|
||||
[2, 5, (2, 2), 6, 24, 40, 0.25],
|
||||
[3, 3, (2, 2), 6, 40, 80, 0.25],
|
||||
[3, 5, (1, 1), 6, 80, 112, 0.25],
|
||||
[4, 5, (2, 2), 6, 112, 192, 0.25],
|
||||
[1, 3, (1, 1), 6, 192, 320, 0.25],
|
||||
]
|
||||
|
||||
if self.number == -1:
|
||||
blocks_args = [
|
||||
[1, 3, (2,2), 1, 32, 40, 0.25],
|
||||
[1, 3, (2,2), 1, 40, 80, 0.25],
|
||||
[1, 3, (2,2), 1, 80, 192, 0.25],
|
||||
[1, 3, (2,2), 1, 192, 320, 0.25],
|
||||
]
|
||||
elif self.number == -2:
|
||||
blocks_args = [
|
||||
[1, 9, (8,8), 1, 32, 320, 0.25],
|
||||
]
|
||||
if self.number == -1:
|
||||
blocks_args = [
|
||||
[1, 3, (2, 2), 1, 32, 40, 0.25],
|
||||
[1, 3, (2, 2), 1, 40, 80, 0.25],
|
||||
[1, 3, (2, 2), 1, 80, 192, 0.25],
|
||||
[1, 3, (2, 2), 1, 192, 320, 0.25],
|
||||
]
|
||||
elif self.number == -2:
|
||||
blocks_args = [
|
||||
[1, 9, (8, 8), 1, 32, 320, 0.25],
|
||||
]
|
||||
|
||||
self._blocks = []
|
||||
for num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio in blocks_args:
|
||||
input_filters, output_filters = round_filters(input_filters), round_filters(output_filters)
|
||||
for n in range(round_repeats(num_repeats)):
|
||||
self._blocks.append(MBConvBlock(kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se=has_se, track_running_stats=track_running_stats))
|
||||
input_filters = output_filters
|
||||
strides = (1,1)
|
||||
self._blocks = []
|
||||
for (
|
||||
num_repeats,
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
) in blocks_args:
|
||||
input_filters, output_filters = round_filters(input_filters), round_filters(
|
||||
output_filters
|
||||
)
|
||||
for n in range(round_repeats(num_repeats)):
|
||||
self._blocks.append(
|
||||
MBConvBlock(
|
||||
kernel_size,
|
||||
strides,
|
||||
expand_ratio,
|
||||
input_filters,
|
||||
output_filters,
|
||||
se_ratio,
|
||||
has_se=has_se,
|
||||
track_running_stats=track_running_stats,
|
||||
)
|
||||
)
|
||||
input_filters = output_filters
|
||||
strides = (1, 1)
|
||||
|
||||
in_channels = round_filters(320)
|
||||
out_channels = round_filters(1280)
|
||||
self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1)
|
||||
self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
|
||||
if has_fc_output:
|
||||
self._fc = Tensor.glorot_uniform(out_channels, classes)
|
||||
self._fc_bias = Tensor.zeros(classes)
|
||||
else:
|
||||
self._fc = None
|
||||
in_channels = round_filters(320)
|
||||
out_channels = round_filters(1280)
|
||||
self._conv_head = Tensor.glorot_uniform(out_channels, in_channels, 1, 1)
|
||||
self._bn1 = BatchNorm2d(out_channels, track_running_stats=track_running_stats)
|
||||
if has_fc_output:
|
||||
self._fc = Tensor.glorot_uniform(out_channels, classes)
|
||||
self._fc_bias = Tensor.zeros(classes)
|
||||
else:
|
||||
self._fc = None
|
||||
|
||||
def forward(self, x):
|
||||
x = self._bn0(x.conv2d(self._conv_stem, padding=(0,1,0,1), stride=2)).swish()
|
||||
x = x.sequential(self._blocks)
|
||||
x = self._bn1(x.conv2d(self._conv_head)).swish()
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x = x.reshape(shape=(-1, x.shape[1]))
|
||||
return x.linear(self._fc, self._fc_bias) if self._fc is not None else x
|
||||
def forward(self, x):
|
||||
x = self._bn0(x.conv2d(self._conv_stem, padding=(0, 1, 0, 1), stride=2)).swish()
|
||||
x = x.sequential(self._blocks)
|
||||
x = self._bn1(x.conv2d(self._conv_head)).swish()
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x = x.reshape(shape=(-1, x.shape[1]))
|
||||
return x.linear(self._fc, self._fc_bias) if self._fc is not None else x
|
||||
|
||||
def load_from_pretrained(self):
|
||||
model_urls = {
|
||||
0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||
1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||
2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||
3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
||||
}
|
||||
def load_from_pretrained(self):
|
||||
model_urls = {
|
||||
0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||
1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||
2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||
3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth",
|
||||
}
|
||||
|
||||
b0 = torch_load(fetch(model_urls[self.number]))
|
||||
for k,v in b0.items():
|
||||
if k.endswith("num_batches_tracked"): continue
|
||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
||||
if cat in k:
|
||||
k = k.replace('.bias', '_bias')
|
||||
k = k.replace('.weight', '')
|
||||
b0 = torch_load(fetch(model_urls[self.number]))
|
||||
for k, v in b0.items():
|
||||
if k.endswith("num_batches_tracked"):
|
||||
continue
|
||||
for cat in [
|
||||
"_conv_head",
|
||||
"_conv_stem",
|
||||
"_depthwise_conv",
|
||||
"_expand_conv",
|
||||
"_fc",
|
||||
"_project_conv",
|
||||
"_se_reduce",
|
||||
"_se_expand",
|
||||
]:
|
||||
if cat in k:
|
||||
k = k.replace(".bias", "_bias")
|
||||
k = k.replace(".weight", "")
|
||||
|
||||
#print(k, v.shape)
|
||||
mv = get_child(self, k)
|
||||
vnp = v #.astype(np.float32)
|
||||
vnp = vnp if k != '_fc' else vnp.cpu().T
|
||||
#vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||
|
||||
if mv.shape == vnp.shape:
|
||||
mv.assign(vnp.to(mv.device))
|
||||
else:
|
||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||
# print(k, v.shape)
|
||||
mv = get_child(self, k)
|
||||
vnp = v # .astype(np.float32)
|
||||
vnp = vnp if k != "_fc" else vnp.cpu().T
|
||||
# vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||
|
||||
if mv.shape == vnp.shape:
|
||||
mv.assign(vnp.to(mv.device))
|
||||
else:
|
||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||
|
|
|
@ -2,151 +2,275 @@ from typing import Tuple, Union, Optional, Dict
|
|||
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
|
||||
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(1, end, 1, dim//2, 2)
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[: (dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return Tensor.stack([Tensor.cos(freqs), Tensor.sin(freqs)], dim=-1).reshape(
|
||||
1, end, 1, dim // 2, 2
|
||||
)
|
||||
|
||||
|
||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||
def complex_mult(A, c, d):
|
||||
a,b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
|
||||
ro = a*c - b*d
|
||||
co = a*d + b*c
|
||||
return ro.cat(co, dim=-1)
|
||||
a, b = A[:, :, :, :, 0:1], A[:, :, :, :, 1:2]
|
||||
ro = a * c - b * d
|
||||
co = a * d + b * c
|
||||
return ro.cat(co, dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_emb(xq, xk, freqs_cis) -> Tuple[Tensor, Tensor]:
|
||||
assert freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
||||
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
||||
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
|
||||
c, d = freqs_cis[:, :xq.shape[1], :, :, 0:1], freqs_cis[:, :xq.shape[1], :, :, 1:2]
|
||||
xq_out = complex_mult(xq, c, d)
|
||||
xk_out = complex_mult(xk, c, d)
|
||||
return xq_out.flatten(3), xk_out.flatten(3)
|
||||
assert (
|
||||
freqs_cis.shape[1] == xq.shape[1] and freqs_cis.shape[1] == xk.shape[1]
|
||||
), f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
||||
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
||||
assert len(xq.shape) == 5 and len(xk.shape) == 5 and len(freqs_cis.shape) == 5
|
||||
c, d = (
|
||||
freqs_cis[:, : xq.shape[1], :, :, 0:1],
|
||||
freqs_cis[:, : xq.shape[1], :, :, 1:2],
|
||||
)
|
||||
xq_out = complex_mult(xq, c, d)
|
||||
xk_out = complex_mult(xk, c, d)
|
||||
return xq_out.flatten(3), xk_out.flatten(3)
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x.reshape(bs, seqlen, n_kv_heads, 1, head_dim)
|
||||
.expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
def repeat_kv(x:Tensor, n_rep:int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1: return x
|
||||
return x.reshape(bs, seqlen, n_kv_heads, 1, head_dim).expand(bs, seqlen, n_kv_heads, n_rep, head_dim).reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
|
||||
|
||||
class RMSNorm:
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
self.eps = eps
|
||||
self.weight = Tensor.ones(dim)
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
self.eps = eps
|
||||
self.weight = Tensor.ones(dim)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
# TODO: convert to float?
|
||||
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
# TODO: convert to float?
|
||||
return (x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()) * self.weight
|
||||
|
||||
class Attention:
|
||||
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||
self.head_dim = dim // n_heads
|
||||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = (
|
||||
n_kv_heads if n_kv_heads is not None else n_heads
|
||||
) # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||
self.head_dim = dim // n_heads
|
||||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
|
||||
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
start_pos: Union[Variable, int],
|
||||
freqs_cis: Tensor,
|
||||
mask: Optional[Tensor],
|
||||
) -> Tensor:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
bsz, seqlen, n_heads, head_dim = xq.shape
|
||||
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||
# create kv cache
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k, self.cache_v = Tensor.zeros(
|
||||
bsz, self.max_context, self.n_kv_heads, self.head_dim
|
||||
), Tensor.zeros(bsz, self.max_context, self.n_kv_heads, self.head_dim)
|
||||
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1)
|
||||
values = self.cache_v.shrink((None, (0, start_pos), None, None)).cat(xv, dim=1)
|
||||
|
||||
# update the cache
|
||||
self.cache_k.assign(keys.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
self.cache_v.assign(values.pad((None,(0,self.max_context-start_pos-seqlen),None,None)).contiguous()).realize()
|
||||
# update the cache
|
||||
self.cache_k.assign(
|
||||
keys.pad(
|
||||
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
self.cache_v.assign(
|
||||
values.pad(
|
||||
(None, (0, self.max_context - start_pos - seqlen), None, None)
|
||||
).contiguous()
|
||||
).realize()
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
|
||||
xq, keys, values = (
|
||||
xq.transpose(1, 2),
|
||||
keys.transpose(1, 2),
|
||||
values.transpose(1, 2),
|
||||
)
|
||||
attn = (
|
||||
xq.scaled_dot_product_attention(keys, values, mask)
|
||||
.transpose(1, 2)
|
||||
.reshape(bsz, seqlen, -1)
|
||||
)
|
||||
return self.wo(attn)
|
||||
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2).reshape(bsz, seqlen, -1)
|
||||
return self.wo(attn)
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim, hidden_dim, linear=nn.Linear):
|
||||
self.w1 = linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||
def __init__(self, dim, hidden_dim, linear=nn.Linear):
|
||||
self.w1 = linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.w2(
|
||||
self.w1(x).silu() * self.w3(x)
|
||||
) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear):
|
||||
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
||||
self.feed_forward = FeedForward(dim, hidden_dim, linear)
|
||||
self.attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = RMSNorm(dim, norm_eps)
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
max_context: int,
|
||||
linear=nn.Linear,
|
||||
):
|
||||
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
||||
self.feed_forward = FeedForward(dim, hidden_dim, linear)
|
||||
self.attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = RMSNorm(dim, norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
start_pos: Union[Variable, int],
|
||||
freqs_cis: Tensor,
|
||||
mask: Optional[Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).realize()
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size, linear=nn.Linear, n_kv_heads=None, rope_theta=10000, max_context=1024, jit=True):
|
||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear) for _ in range(n_layers)]
|
||||
self.norm = RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||
self.output = linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context * 2, rope_theta)
|
||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
norm_eps: float,
|
||||
vocab_size,
|
||||
linear=nn.Linear,
|
||||
n_kv_heads=None,
|
||||
rope_theta=10000,
|
||||
max_context=1024,
|
||||
jit=True,
|
||||
):
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear
|
||||
)
|
||||
for _ in range(n_layers)
|
||||
]
|
||||
self.norm = RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||
self.output = linear(dim, vocab_size, bias=False)
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
dim // n_heads, self.max_context * 2, rope_theta
|
||||
)
|
||||
self.forward_jit = TinyJit(self.forward) if jit else None
|
||||
|
||||
def forward(self, tokens:Tensor, start_pos:Union[Variable,int], temperature:float=0.0):
|
||||
_bsz, seqlen = tokens.shape
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos+seqlen),None,None,None))
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
def forward(
|
||||
self, tokens: Tensor, start_pos: Union[Variable, int], temperature: float = 0.0
|
||||
):
|
||||
_bsz, seqlen = tokens.shape
|
||||
freqs_cis = self.freqs_cis.shrink(
|
||||
(None, (start_pos, start_pos + seqlen), None, None, None)
|
||||
)
|
||||
mask = (
|
||||
Tensor.full(
|
||||
(1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32
|
||||
)
|
||||
.triu(start_pos + 1)
|
||||
.realize()
|
||||
if seqlen > 1
|
||||
else None
|
||||
)
|
||||
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h))
|
||||
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
|
||||
h = self.tok_embeddings(tokens)
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
logits = self.output(self.norm(h))
|
||||
return (logits[:, -1, :] / (temperature + 1e-10)).softmax().flatten().realize()
|
||||
|
||||
def __call__(self, tokens: Tensor, start_pos: Variable, temperature: float = 0.0):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
if tokens.shape[0:2] == (1, 1) and self.forward_jit and getenv("JIT", 1):
|
||||
assert start_pos > 0
|
||||
return self.forward_jit(
|
||||
tokens,
|
||||
Variable("start_pos", 1, self.max_context).bind(start_pos),
|
||||
temperature,
|
||||
)
|
||||
return self.forward(tokens, start_pos, temperature)
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
if tokens.shape[0:2] == (1,1) and self.forward_jit and getenv("JIT", 1):
|
||||
assert start_pos > 0
|
||||
return self.forward_jit(tokens, Variable("start_pos", 1, self.max_context).bind(start_pos), temperature)
|
||||
return self.forward(tokens, start_pos, temperature)
|
||||
|
||||
# *** helpers ***
|
||||
|
||||
def convert_from_huggingface(weights:Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
|
||||
def permute(v: Tensor, n_heads: int):
|
||||
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
|
||||
|
||||
keymap = {
|
||||
"model.embed_tokens.weight": "tok_embeddings.weight",
|
||||
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight" for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight" for x in ["q", "k", "v", "o"] for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight" for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight" for x, y in {"gate": "1", "down": "2", "up": "3"}.items() for l in range(len(model.layers))},
|
||||
"model.norm.weight": "norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
sd = {}
|
||||
for k, v in weights.items():
|
||||
if ".rotary_emb." in k: continue
|
||||
v = v.to(Device.DEFAULT)
|
||||
if "model.layers" in k:
|
||||
if "q_proj" in k:
|
||||
v = permute(v, n_heads)
|
||||
elif "k_proj" in k:
|
||||
v = permute(v, n_kv_heads)
|
||||
sd[keymap[k]] = v
|
||||
return sd
|
||||
def convert_from_huggingface(
|
||||
weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int
|
||||
):
|
||||
def permute(v: Tensor, n_heads: int):
|
||||
return (
|
||||
v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1])
|
||||
.transpose(1, 2)
|
||||
.reshape(*v.shape[:2])
|
||||
)
|
||||
|
||||
keymap = {
|
||||
"model.embed_tokens.weight": "tok_embeddings.weight",
|
||||
**{
|
||||
f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
|
||||
for x in ["q", "k", "v", "o"]
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
**{
|
||||
f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
|
||||
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
|
||||
for l in range(len(model.layers))
|
||||
},
|
||||
"model.norm.weight": "norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
sd = {}
|
||||
for k, v in weights.items():
|
||||
if ".rotary_emb." in k:
|
||||
continue
|
||||
v = v.to(Device.DEFAULT)
|
||||
if "model.layers" in k:
|
||||
if "q_proj" in k:
|
||||
v = permute(v, n_heads)
|
||||
elif "k_proj" in k:
|
||||
v = permute(v, n_kv_heads)
|
||||
sd[keymap[k]] = v
|
||||
return sd
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,150 +3,229 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.helpers import fetch, get_child
|
||||
|
||||
|
||||
class BasicBlock:
|
||||
expansion = 1
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
|
||||
assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
]
|
||||
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
|
||||
assert (
|
||||
groups == 1 and base_width == 64
|
||||
), "BasicBlock only supports groups=1 and base_width=64"
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, padding=1, stride=1, bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + x.sequential(self.downsample)
|
||||
out = out.relu()
|
||||
return out
|
||||
def __call__(self, x):
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = self.bn2(self.conv2(out))
|
||||
out = out + x.sequential(self.downsample)
|
||||
out = out.relu()
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck:
|
||||
# NOTE: stride_in_1x1=False, this is the v1.5 variant
|
||||
expansion = 4
|
||||
# NOTE: stride_in_1x1=False, this is the v1.5 variant
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
|
||||
width = int(planes * (base_width / 64.0)) * groups
|
||||
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
|
||||
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion*planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion*planes)
|
||||
]
|
||||
def __init__(
|
||||
self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64
|
||||
):
|
||||
width = int(planes * (base_width / 64.0)) * groups
|
||||
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
width,
|
||||
kernel_size=1,
|
||||
stride=stride if stride_in_1x1 else 1,
|
||||
bias=False,
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(
|
||||
width,
|
||||
width,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1 if stride_in_1x1 else stride,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(
|
||||
width, self.expansion * planes, kernel_size=1, bias=False
|
||||
)
|
||||
self.bn3 = nn.BatchNorm2d(self.expansion * planes)
|
||||
self.downsample = []
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.downsample = [
|
||||
nn.Conv2d(
|
||||
in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
]
|
||||
|
||||
def __call__(self, x):
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = self.bn2(self.conv2(out)).relu()
|
||||
out = self.bn3(self.conv3(out))
|
||||
out = out + x.sequential(self.downsample)
|
||||
out = out.relu()
|
||||
return out
|
||||
|
||||
def __call__(self, x):
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = self.bn2(self.conv2(out)).relu()
|
||||
out = self.bn3(self.conv3(out))
|
||||
out = out + x.sequential(self.downsample)
|
||||
out = out.relu()
|
||||
return out
|
||||
|
||||
class ResNet:
|
||||
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
|
||||
self.num = num
|
||||
self.block = {
|
||||
18: BasicBlock,
|
||||
34: BasicBlock,
|
||||
50: Bottleneck,
|
||||
101: Bottleneck,
|
||||
152: Bottleneck
|
||||
}[num]
|
||||
def __init__(
|
||||
self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False
|
||||
):
|
||||
self.num = num
|
||||
self.block = {
|
||||
18: BasicBlock,
|
||||
34: BasicBlock,
|
||||
50: Bottleneck,
|
||||
101: Bottleneck,
|
||||
152: Bottleneck,
|
||||
}[num]
|
||||
|
||||
self.num_blocks = {
|
||||
18: [2,2,2,2],
|
||||
34: [3,4,6,3],
|
||||
50: [3,4,6,3],
|
||||
101: [3,4,23,3],
|
||||
152: [3,8,36,3]
|
||||
}[num]
|
||||
self.num_blocks = {
|
||||
18: [2, 2, 2, 2],
|
||||
34: [3, 4, 6, 3],
|
||||
50: [3, 4, 6, 3],
|
||||
101: [3, 4, 23, 3],
|
||||
152: [3, 8, 36, 3],
|
||||
}[num]
|
||||
|
||||
self.in_planes = 64
|
||||
self.in_planes = 64
|
||||
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
|
||||
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
|
||||
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(
|
||||
self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.layer4 = self._make_layer(
|
||||
self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1
|
||||
)
|
||||
self.fc = (
|
||||
nn.Linear(512 * self.block.expansion, num_classes)
|
||||
if num_classes is not None
|
||||
else None
|
||||
)
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
|
||||
strides = [stride] + [1] * (num_blocks-1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
if block == Bottleneck:
|
||||
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
|
||||
else:
|
||||
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
|
||||
self.in_planes = planes * block.expansion
|
||||
return layers
|
||||
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
if block == Bottleneck:
|
||||
layers.append(
|
||||
block(
|
||||
self.in_planes,
|
||||
planes,
|
||||
stride,
|
||||
stride_in_1x1,
|
||||
self.groups,
|
||||
self.base_width,
|
||||
)
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
block(self.in_planes, planes, stride, self.groups, self.base_width)
|
||||
)
|
||||
self.in_planes = planes * block.expansion
|
||||
return layers
|
||||
|
||||
def forward(self, x):
|
||||
is_feature_only = self.fc is None
|
||||
if is_feature_only: features = []
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
|
||||
out = out.sequential(self.layer1)
|
||||
if is_feature_only: features.append(out)
|
||||
out = out.sequential(self.layer2)
|
||||
if is_feature_only: features.append(out)
|
||||
out = out.sequential(self.layer3)
|
||||
if is_feature_only: features.append(out)
|
||||
out = out.sequential(self.layer4)
|
||||
if is_feature_only: features.append(out)
|
||||
if not is_feature_only:
|
||||
out = out.mean([2,3])
|
||||
out = self.fc(out).log_softmax()
|
||||
return out
|
||||
return features
|
||||
def forward(self, x):
|
||||
is_feature_only = self.fc is None
|
||||
if is_feature_only:
|
||||
features = []
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
|
||||
out = out.sequential(self.layer1)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer2)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer3)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
out = out.sequential(self.layer4)
|
||||
if is_feature_only:
|
||||
features.append(out)
|
||||
if not is_feature_only:
|
||||
out = out.mean([2, 3])
|
||||
out = self.fc(out).log_softmax()
|
||||
return out
|
||||
return features
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return self.forward(x)
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.forward(x)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
# TODO replace with fake torch load
|
||||
def load_from_pretrained(self):
|
||||
# TODO replace with fake torch load
|
||||
|
||||
model_urls = {
|
||||
(18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
(34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
(50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
(50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
|
||||
(101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
(152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
}
|
||||
model_urls = {
|
||||
(18, 1, 64): "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
||||
(34, 1, 64): "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
||||
(50, 1, 64): "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
||||
(
|
||||
50,
|
||||
32,
|
||||
4,
|
||||
): "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||
(101, 1, 64): "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
||||
(152, 1, 64): "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
||||
}
|
||||
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
for k, v in torch_load(fetch(self.url)).items():
|
||||
obj: Tensor = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
for k, v in torch_load(fetch(self.url)).items():
|
||||
obj: Tensor = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
|
||||
if 'fc.' in k and obj.shape != dat.shape:
|
||||
print("skipping fully connected layer")
|
||||
continue # Skip FC if transfer learning
|
||||
if "fc." in k and obj.shape != dat.shape:
|
||||
print("skipping fully connected layer")
|
||||
continue # Skip FC if transfer learning
|
||||
|
||||
# TODO: remove or when #777 is merged
|
||||
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (
|
||||
k,
|
||||
obj.shape,
|
||||
dat.shape,
|
||||
)
|
||||
obj.assign(dat)
|
||||
|
||||
# TODO: remove or when #777 is merged
|
||||
assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
|
||||
obj.assign(dat)
|
||||
|
||||
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
||||
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
||||
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
|
||||
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
|
||||
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
|
||||
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
|
||||
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(
|
||||
50, num_classes=num_classes, groups=32, width_per_group=4
|
||||
)
|
||||
|
|
|
@ -4,233 +4,379 @@ import tinygrad.nn as nn
|
|||
from extra.models.resnet import ResNet
|
||||
import numpy as np
|
||||
|
||||
|
||||
def nms(boxes, scores, thresh=0.5):
|
||||
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
to_process, keep = scores.argsort()[::-1], []
|
||||
while to_process.size > 0:
|
||||
cur, to_process = to_process[0], to_process[1:]
|
||||
keep.append(cur)
|
||||
inter_x1 = np.maximum(x1[cur], x1[to_process])
|
||||
inter_y1 = np.maximum(y1[cur], y1[to_process])
|
||||
inter_x2 = np.minimum(x2[cur], x2[to_process])
|
||||
inter_y2 = np.minimum(y2[cur], y2[to_process])
|
||||
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
|
||||
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
|
||||
to_process = to_process[np.where(iou <= thresh)[0]]
|
||||
return keep
|
||||
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
to_process, keep = scores.argsort()[::-1], []
|
||||
while to_process.size > 0:
|
||||
cur, to_process = to_process[0], to_process[1:]
|
||||
keep.append(cur)
|
||||
inter_x1 = np.maximum(x1[cur], x1[to_process])
|
||||
inter_y1 = np.maximum(y1[cur], y1[to_process])
|
||||
inter_x2 = np.minimum(x2[cur], x2[to_process])
|
||||
inter_y2 = np.minimum(y2[cur], y2[to_process])
|
||||
inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(
|
||||
0, inter_y2 - inter_y1 + 1
|
||||
)
|
||||
iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
|
||||
to_process = to_process[np.where(iou <= thresh)[0]]
|
||||
return keep
|
||||
|
||||
|
||||
def decode_bbox(offsets, anchors):
|
||||
dx, dy, dw, dh = np.rollaxis(offsets, 1)
|
||||
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
|
||||
cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights
|
||||
pred_cx, pred_cy = dx * widths + cx, dy * heights + cy
|
||||
pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights
|
||||
pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h
|
||||
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
|
||||
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
|
||||
dx, dy, dw, dh = np.rollaxis(offsets, 1)
|
||||
widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
|
||||
cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights
|
||||
pred_cx, pred_cy = dx * widths + cx, dy * heights + cy
|
||||
pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights
|
||||
pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h
|
||||
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
|
||||
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
|
||||
|
||||
|
||||
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
|
||||
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
|
||||
anchors = []
|
||||
for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
|
||||
s, ar = np.array(s), np.array(ar)
|
||||
h_ratios = np.sqrt(ar)
|
||||
w_ratios = 1 / h_ratios
|
||||
ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
|
||||
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
|
||||
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
|
||||
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
|
||||
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
|
||||
shifts_x = shifts_x.reshape(-1)
|
||||
shifts_y = shifts_y.reshape(-1)
|
||||
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
|
||||
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
|
||||
return anchors
|
||||
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
|
||||
anchors = []
|
||||
for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
|
||||
s, ar = np.array(s), np.array(ar)
|
||||
h_ratios = np.sqrt(ar)
|
||||
w_ratios = 1 / h_ratios
|
||||
ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
|
||||
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
|
||||
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
|
||||
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
|
||||
shifts_x, shifts_y = np.meshgrid(
|
||||
np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h
|
||||
)
|
||||
shifts_x = shifts_x.reshape(-1)
|
||||
shifts_y = shifts_y.reshape(-1)
|
||||
shifts = np.stack(
|
||||
[shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32
|
||||
)
|
||||
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
|
||||
return anchors
|
||||
|
||||
|
||||
class RetinaNet:
|
||||
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
|
||||
assert isinstance(backbone, ResNet)
|
||||
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
|
||||
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
||||
self.num_anchors, self.num_classes = num_anchors, num_classes
|
||||
assert len(scales) == len(aspect_ratios) and all(self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios))
|
||||
def __init__(
|
||||
self,
|
||||
backbone: ResNet,
|
||||
num_classes=264,
|
||||
num_anchors=9,
|
||||
scales=None,
|
||||
aspect_ratios=None,
|
||||
):
|
||||
assert isinstance(backbone, ResNet)
|
||||
scales = (
|
||||
tuple(
|
||||
(i, int(i * 2 ** (1 / 3)), int(i * 2 ** (2 / 3)))
|
||||
for i in 2 ** np.arange(5, 10)
|
||||
)
|
||||
if scales is None
|
||||
else scales
|
||||
)
|
||||
aspect_ratios = (
|
||||
((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
|
||||
)
|
||||
self.num_anchors, self.num_classes = num_anchors, num_classes
|
||||
assert len(scales) == len(aspect_ratios) and all(
|
||||
self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)
|
||||
)
|
||||
|
||||
self.backbone = ResNetFPN(backbone)
|
||||
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
|
||||
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
|
||||
self.backbone = ResNetFPN(backbone)
|
||||
self.head = RetinaHead(
|
||||
self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes
|
||||
)
|
||||
self.anchor_gen = lambda input_size: generate_anchors(
|
||||
input_size,
|
||||
self.backbone.compute_grid_sizes(input_size),
|
||||
scales,
|
||||
aspect_ratios,
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
return self.forward(x)
|
||||
def forward(self, x):
|
||||
return self.head(self.backbone(x))
|
||||
def __call__(self, x):
|
||||
return self.forward(x)
|
||||
|
||||
def load_from_pretrained(self):
|
||||
model_urls = {
|
||||
(50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
||||
(50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
||||
}
|
||||
self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
|
||||
from torch.hub import load_state_dict_from_url
|
||||
state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
|
||||
state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat)
|
||||
def forward(self, x):
|
||||
return self.head(self.backbone(x))
|
||||
|
||||
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
|
||||
def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
|
||||
anchors = self.anchor_gen(input_size)
|
||||
grid_sizes = self.backbone.compute_grid_sizes(input_size)
|
||||
split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
|
||||
detections = []
|
||||
for i, predictions_per_image in enumerate(predictions):
|
||||
h, w = input_size if image_sizes is None else image_sizes[i]
|
||||
def load_from_pretrained(self):
|
||||
model_urls = {
|
||||
(
|
||||
50,
|
||||
1,
|
||||
64,
|
||||
): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
|
||||
(
|
||||
50,
|
||||
32,
|
||||
4,
|
||||
): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
|
||||
}
|
||||
self.url = model_urls[
|
||||
(
|
||||
self.backbone.body.num,
|
||||
self.backbone.body.groups,
|
||||
self.backbone.body.base_width,
|
||||
)
|
||||
]
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
predictions_per_image = np.split(predictions_per_image, split_idx)
|
||||
offsets_per_image = [br[:, :4] for br in predictions_per_image]
|
||||
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
|
||||
state_dict = load_state_dict_from_url(
|
||||
self.url, progress=True, map_location="cpu"
|
||||
)
|
||||
state_dict = state_dict["model"] if "model" in state_dict.keys() else state_dict
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat)
|
||||
|
||||
image_boxes, image_scores, image_labels = [], [], []
|
||||
for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
|
||||
# remove low scoring boxes
|
||||
scores_per_level = scores_per_level.flatten()
|
||||
keep_idxs = scores_per_level > score_thresh
|
||||
scores_per_level = scores_per_level[keep_idxs]
|
||||
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
|
||||
def postprocess_detections(
|
||||
self,
|
||||
predictions,
|
||||
input_size=(800, 800),
|
||||
image_sizes=None,
|
||||
orig_image_sizes=None,
|
||||
score_thresh=0.05,
|
||||
topk_candidates=1000,
|
||||
nms_thresh=0.5,
|
||||
):
|
||||
anchors = self.anchor_gen(input_size)
|
||||
grid_sizes = self.backbone.compute_grid_sizes(input_size)
|
||||
split_idx = np.cumsum(
|
||||
[int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]]
|
||||
)
|
||||
detections = []
|
||||
for i, predictions_per_image in enumerate(predictions):
|
||||
h, w = input_size if image_sizes is None else image_sizes[i]
|
||||
|
||||
# keep topk
|
||||
topk_idxs = np.where(keep_idxs)[0]
|
||||
num_topk = min(len(topk_idxs), topk_candidates)
|
||||
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
|
||||
topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]
|
||||
predictions_per_image = np.split(predictions_per_image, split_idx)
|
||||
offsets_per_image = [br[:, :4] for br in predictions_per_image]
|
||||
scores_per_image = [cl[:, 4:] for cl in predictions_per_image]
|
||||
|
||||
# bbox coords from offsets
|
||||
anchor_idxs = topk_idxs // self.num_classes
|
||||
labels_per_level = topk_idxs % self.num_classes
|
||||
boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
|
||||
# clip to image size
|
||||
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
|
||||
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
|
||||
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)
|
||||
image_boxes, image_scores, image_labels = [], [], []
|
||||
for offsets_per_level, scores_per_level, anchors_per_level in zip(
|
||||
offsets_per_image, scores_per_image, anchors
|
||||
):
|
||||
# remove low scoring boxes
|
||||
scores_per_level = scores_per_level.flatten()
|
||||
keep_idxs = scores_per_level > score_thresh
|
||||
scores_per_level = scores_per_level[keep_idxs]
|
||||
|
||||
image_boxes.append(boxes_per_level)
|
||||
image_scores.append(scores_per_level)
|
||||
image_labels.append(labels_per_level)
|
||||
# keep topk
|
||||
topk_idxs = np.where(keep_idxs)[0]
|
||||
num_topk = min(len(topk_idxs), topk_candidates)
|
||||
sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
|
||||
topk_idxs, scores_per_level = (
|
||||
topk_idxs[sort_idxs],
|
||||
scores_per_level[sort_idxs],
|
||||
)
|
||||
|
||||
image_boxes = np.concatenate(image_boxes)
|
||||
image_scores = np.concatenate(image_scores)
|
||||
image_labels = np.concatenate(image_labels)
|
||||
# bbox coords from offsets
|
||||
anchor_idxs = topk_idxs // self.num_classes
|
||||
labels_per_level = topk_idxs % self.num_classes
|
||||
boxes_per_level = decode_bbox(
|
||||
offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
|
||||
)
|
||||
# clip to image size
|
||||
clipped_x = boxes_per_level[:, 0::2].clip(0, w)
|
||||
clipped_y = boxes_per_level[:, 1::2].clip(0, h)
|
||||
boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(
|
||||
-1, 4
|
||||
)
|
||||
|
||||
# nms for each class
|
||||
keep_mask = np.zeros_like(image_scores, dtype=bool)
|
||||
for class_id in np.unique(image_labels):
|
||||
curr_indices = np.where(image_labels == class_id)[0]
|
||||
curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
|
||||
keep_mask[curr_indices[curr_keep_indices]] = True
|
||||
keep = np.where(keep_mask)[0]
|
||||
keep = keep[image_scores[keep].argsort()[::-1]]
|
||||
image_boxes.append(boxes_per_level)
|
||||
image_scores.append(scores_per_level)
|
||||
image_labels.append(labels_per_level)
|
||||
|
||||
# resize bboxes back to original size
|
||||
image_boxes = image_boxes[keep]
|
||||
if orig_image_sizes is not None:
|
||||
resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w
|
||||
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
|
||||
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
|
||||
# xywh format
|
||||
image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)
|
||||
image_boxes = np.concatenate(image_boxes)
|
||||
image_scores = np.concatenate(image_scores)
|
||||
image_labels = np.concatenate(image_labels)
|
||||
|
||||
# nms for each class
|
||||
keep_mask = np.zeros_like(image_scores, dtype=bool)
|
||||
for class_id in np.unique(image_labels):
|
||||
curr_indices = np.where(image_labels == class_id)[0]
|
||||
curr_keep_indices = nms(
|
||||
image_boxes[curr_indices], image_scores[curr_indices], nms_thresh
|
||||
)
|
||||
keep_mask[curr_indices[curr_keep_indices]] = True
|
||||
keep = np.where(keep_mask)[0]
|
||||
keep = keep[image_scores[keep].argsort()[::-1]]
|
||||
|
||||
# resize bboxes back to original size
|
||||
image_boxes = image_boxes[keep]
|
||||
if orig_image_sizes is not None:
|
||||
resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w
|
||||
resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
|
||||
image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
|
||||
# xywh format
|
||||
image_boxes = np.concatenate(
|
||||
[image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1
|
||||
)
|
||||
|
||||
detections.append(
|
||||
{
|
||||
"boxes": image_boxes,
|
||||
"scores": image_scores[keep],
|
||||
"labels": image_labels[keep],
|
||||
}
|
||||
)
|
||||
return detections
|
||||
|
||||
detections.append({"boxes":image_boxes, "scores":image_scores[keep], "labels":image_labels[keep]})
|
||||
return detections
|
||||
|
||||
class ClassificationHead:
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
||||
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
|
||||
def __call__(self, x):
|
||||
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
|
||||
return out[0].cat(*out[1:], dim=1).sigmoid()
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.conv = flatten(
|
||||
[
|
||||
(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||
lambda x: x.relu(),
|
||||
)
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
self.cls_logits = nn.Conv2d(
|
||||
in_channels, num_anchors * num_classes, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
out = [
|
||||
self.cls_logits(feat.sequential(self.conv))
|
||||
.permute(0, 2, 3, 1)
|
||||
.reshape(feat.shape[0], -1, self.num_classes)
|
||||
for feat in x
|
||||
]
|
||||
return out[0].cat(*out[1:], dim=1).sigmoid()
|
||||
|
||||
|
||||
class RegressionHead:
|
||||
def __init__(self, in_channels, num_anchors):
|
||||
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
|
||||
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
|
||||
def __call__(self, x):
|
||||
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
|
||||
return out[0].cat(*out[1:], dim=1)
|
||||
def __init__(self, in_channels, num_anchors):
|
||||
self.conv = flatten(
|
||||
[
|
||||
(
|
||||
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
|
||||
lambda x: x.relu(),
|
||||
)
|
||||
for _ in range(4)
|
||||
]
|
||||
)
|
||||
self.bbox_reg = nn.Conv2d(
|
||||
in_channels, num_anchors * 4, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
out = [
|
||||
self.bbox_reg(feat.sequential(self.conv))
|
||||
.permute(0, 2, 3, 1)
|
||||
.reshape(feat.shape[0], -1, 4)
|
||||
for feat in x
|
||||
]
|
||||
return out[0].cat(*out[1:], dim=1)
|
||||
|
||||
|
||||
class RetinaHead:
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
|
||||
self.regression_head = RegressionHead(in_channels, num_anchors)
|
||||
def __call__(self, x):
|
||||
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
|
||||
out = pred_bbox.cat(pred_class, dim=-1)
|
||||
return out
|
||||
def __init__(self, in_channels, num_anchors, num_classes):
|
||||
self.classification_head = ClassificationHead(
|
||||
in_channels, num_anchors, num_classes
|
||||
)
|
||||
self.regression_head = RegressionHead(in_channels, num_anchors)
|
||||
|
||||
def __call__(self, x):
|
||||
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
|
||||
out = pred_bbox.cat(pred_class, dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class ResNetFPN:
|
||||
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
|
||||
self.out_channels = out_channels
|
||||
self.body = resnet
|
||||
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
|
||||
self.fpn = FPN(in_channels_list, out_channels)
|
||||
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
|
||||
self.out_channels = out_channels
|
||||
self.body = resnet
|
||||
in_channels_list = [
|
||||
(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers
|
||||
]
|
||||
self.fpn = FPN(in_channels_list, out_channels)
|
||||
|
||||
# this is needed to decouple inference from postprocessing (anchors generation)
|
||||
def compute_grid_sizes(self, input_size):
|
||||
return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])
|
||||
# this is needed to decouple inference from postprocessing (anchors generation)
|
||||
def compute_grid_sizes(self, input_size):
|
||||
return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])
|
||||
|
||||
def __call__(self, x):
|
||||
out = self.body.bn1(self.body.conv1(x)).relu()
|
||||
out = out.pad2d([1, 1, 1, 1]).max_pool2d((3, 3), 2)
|
||||
out = out.sequential(self.body.layer1)
|
||||
p3 = out.sequential(self.body.layer2)
|
||||
p4 = p3.sequential(self.body.layer3)
|
||||
p5 = p4.sequential(self.body.layer4)
|
||||
return self.fpn([p3, p4, p5])
|
||||
|
||||
def __call__(self, x):
|
||||
out = self.body.bn1(self.body.conv1(x)).relu()
|
||||
out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
|
||||
out = out.sequential(self.body.layer1)
|
||||
p3 = out.sequential(self.body.layer2)
|
||||
p4 = p3.sequential(self.body.layer3)
|
||||
p5 = p4.sequential(self.body.layer4)
|
||||
return self.fpn([p3, p4, p5])
|
||||
|
||||
class ExtraFPNBlock:
|
||||
def __init__(self, in_channels, out_channels):
|
||||
self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
||||
self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
||||
self.use_P5 = in_channels == out_channels
|
||||
def __init__(self, in_channels, out_channels):
|
||||
self.p6 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.p7 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
self.use_P5 = in_channels == out_channels
|
||||
|
||||
def __call__(self, p, c):
|
||||
p5, c5 = p[-1], c[-1]
|
||||
x = p5 if self.use_P5 else c5
|
||||
p6 = self.p6(x)
|
||||
p7 = self.p7(p6.relu())
|
||||
p.extend([p6, p7])
|
||||
return p
|
||||
|
||||
def __call__(self, p, c):
|
||||
p5, c5 = p[-1], c[-1]
|
||||
x = p5 if self.use_P5 else c5
|
||||
p6 = self.p6(x)
|
||||
p7 = self.p7(p6.relu())
|
||||
p.extend([p6, p7])
|
||||
return p
|
||||
|
||||
class FPN:
|
||||
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
|
||||
self.inner_blocks, self.layer_blocks = [], []
|
||||
for in_channels in in_channels_list:
|
||||
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
|
||||
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
|
||||
self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
||||
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
|
||||
self.inner_blocks, self.layer_blocks = [], []
|
||||
for in_channels in in_channels_list:
|
||||
self.inner_blocks.append(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
)
|
||||
self.layer_blocks.append(
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
)
|
||||
self.extra_blocks = (
|
||||
ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
last_inner = self.inner_blocks[-1](x[-1])
|
||||
results = [self.layer_blocks[-1](last_inner)]
|
||||
for idx in range(len(x) - 2, -1, -1):
|
||||
inner_lateral = self.inner_blocks[idx](x[idx])
|
||||
def __call__(self, x):
|
||||
last_inner = self.inner_blocks[-1](x[-1])
|
||||
results = [self.layer_blocks[-1](last_inner)]
|
||||
for idx in range(len(x) - 2, -1, -1):
|
||||
inner_lateral = self.inner_blocks[idx](x[idx])
|
||||
|
||||
# upsample to inner_lateral's shape
|
||||
(ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
|
||||
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
|
||||
inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]
|
||||
# upsample to inner_lateral's shape
|
||||
(ih, iw), (oh, ow), prefix = (
|
||||
last_inner.shape[-2:],
|
||||
inner_lateral.shape[-2:],
|
||||
last_inner.shape[:-2],
|
||||
)
|
||||
eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
|
||||
inner_top_down = (
|
||||
last_inner.reshape(*prefix, ih, 1, iw, 1)
|
||||
.expand(*prefix, ih, eh, iw, ew)
|
||||
.reshape(*prefix, ih * eh, iw * ew)[:, :, :oh, :ow]
|
||||
)
|
||||
|
||||
last_inner = inner_lateral + inner_top_down
|
||||
results.insert(0, self.layer_blocks[idx](last_inner))
|
||||
if self.extra_blocks is not None:
|
||||
results = self.extra_blocks(results, x)
|
||||
return results
|
||||
|
||||
last_inner = inner_lateral + inner_top_down
|
||||
results.insert(0, self.layer_blocks[idx](last_inner))
|
||||
if self.extra_blocks is not None:
|
||||
results = self.extra_blocks(results, x)
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
backbone = ResNeXt50_32X4D()
|
||||
retina = RetinaNet(backbone)
|
||||
retina.load_from_pretrained()
|
||||
from extra.models.resnet import ResNeXt50_32X4D
|
||||
|
||||
backbone = ResNeXt50_32X4D()
|
||||
retina = RetinaNet(backbone)
|
||||
retina.load_from_pretrained()
|
||||
|
|
|
@ -7,196 +7,278 @@ from pathlib import Path
|
|||
|
||||
|
||||
class RNNT:
|
||||
def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
|
||||
self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
|
||||
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
|
||||
self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)
|
||||
def __init__(
|
||||
self,
|
||||
input_features=240,
|
||||
vocab_size=29,
|
||||
enc_hidden_size=1024,
|
||||
pred_hidden_size=320,
|
||||
joint_hidden_size=512,
|
||||
pre_enc_layers=2,
|
||||
post_enc_layers=3,
|
||||
pred_layers=2,
|
||||
stack_time_factor=2,
|
||||
dropout=0.32,
|
||||
):
|
||||
self.encoder = Encoder(
|
||||
input_features,
|
||||
enc_hidden_size,
|
||||
pre_enc_layers,
|
||||
post_enc_layers,
|
||||
stack_time_factor,
|
||||
dropout,
|
||||
)
|
||||
self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
|
||||
self.joint = Joint(
|
||||
vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||
)
|
||||
|
||||
@TinyJit
|
||||
def __call__(self, x, y, hc=None):
|
||||
f, _ = self.encoder(x, None)
|
||||
g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
|
||||
out = self.joint(f, g)
|
||||
return out.realize()
|
||||
@TinyJit
|
||||
def __call__(self, x, y, hc=None):
|
||||
f, _ = self.encoder(x, None)
|
||||
g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
|
||||
out = self.joint(f, g)
|
||||
return out.realize()
|
||||
|
||||
def decode(self, x, x_lens):
|
||||
logits, logit_lens = self.encoder(x, x_lens)
|
||||
outputs = []
|
||||
for b in range(logits.shape[0]):
|
||||
inseq = logits[b, :, :].unsqueeze(1)
|
||||
logit_len = logit_lens[b]
|
||||
seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item()))
|
||||
outputs.append(seq)
|
||||
return outputs
|
||||
def decode(self, x, x_lens):
|
||||
logits, logit_lens = self.encoder(x, x_lens)
|
||||
outputs = []
|
||||
for b in range(logits.shape[0]):
|
||||
inseq = logits[b, :, :].unsqueeze(1)
|
||||
logit_len = logit_lens[b]
|
||||
seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item()))
|
||||
outputs.append(seq)
|
||||
return outputs
|
||||
|
||||
def _greedy_decode(self, logits, logit_len):
|
||||
hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
|
||||
labels = []
|
||||
label = Tensor.zeros(1, 1, requires_grad=False)
|
||||
mask = Tensor.zeros(1, requires_grad=False)
|
||||
for time_idx in range(logit_len):
|
||||
logit = logits[time_idx, :, :].unsqueeze(0)
|
||||
not_blank = True
|
||||
added = 0
|
||||
while not_blank and added < 30:
|
||||
if len(labels) > 0:
|
||||
mask = (mask + 1).clip(0, 1)
|
||||
label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
|
||||
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
||||
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
||||
not_blank = k != 28
|
||||
if not_blank:
|
||||
labels.append(k)
|
||||
hc = jhc[:, :, 29:] + 1 - 1
|
||||
added += 1
|
||||
return labels
|
||||
def _greedy_decode(self, logits, logit_len):
|
||||
hc = Tensor.zeros(
|
||||
self.prediction.rnn.layers,
|
||||
2,
|
||||
self.prediction.hidden_size,
|
||||
requires_grad=False,
|
||||
)
|
||||
labels = []
|
||||
label = Tensor.zeros(1, 1, requires_grad=False)
|
||||
mask = Tensor.zeros(1, requires_grad=False)
|
||||
for time_idx in range(logit_len):
|
||||
logit = logits[time_idx, :, :].unsqueeze(0)
|
||||
not_blank = True
|
||||
added = 0
|
||||
while not_blank and added < 30:
|
||||
if len(labels) > 0:
|
||||
mask = (mask + 1).clip(0, 1)
|
||||
label = (
|
||||
Tensor(
|
||||
[[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]],
|
||||
requires_grad=False,
|
||||
)
|
||||
+ 1
|
||||
- 1
|
||||
)
|
||||
jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
|
||||
k = jhc[0, 0, :29].argmax(axis=0).numpy()
|
||||
not_blank = k != 28
|
||||
if not_blank:
|
||||
labels.append(k)
|
||||
hc = jhc[:, :, 29:] + 1 - 1
|
||||
added += 1
|
||||
return labels
|
||||
|
||||
@TinyJit
|
||||
def _pred_joint(self, logit, label, hc, mask):
|
||||
g, hc = self.prediction(label, hc, mask)
|
||||
j = self.joint(logit, g)[0]
|
||||
j = j.pad(((0, 1), (0, 1), (0, 0)))
|
||||
out = j.cat(hc, dim=2)
|
||||
return out.realize()
|
||||
@TinyJit
|
||||
def _pred_joint(self, logit, label, hc, mask):
|
||||
g, hc = self.prediction(label, hc, mask)
|
||||
j = self.joint(logit, g)[0]
|
||||
j = j.pad(((0, 1), (0, 1), (0, 0)))
|
||||
out = j.cat(hc, dim=2)
|
||||
return out.realize()
|
||||
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
||||
fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)
|
||||
def load_from_pretrained(self):
|
||||
fn = Path(__file__).parents[1] / "weights/rnnt.pt"
|
||||
fetch(
|
||||
"https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1",
|
||||
fn,
|
||||
)
|
||||
|
||||
import torch
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")["state_dict"]
|
||||
import torch
|
||||
|
||||
# encoder
|
||||
for i in range(2):
|
||||
self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
for i in range(3):
|
||||
self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
with open(fn, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")["state_dict"]
|
||||
|
||||
# prediction
|
||||
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
|
||||
for i in range(2):
|
||||
self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
|
||||
self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())
|
||||
# encoder
|
||||
for i in range(2):
|
||||
self.encoder.pre_rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.pre_rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
for i in range(3):
|
||||
self.encoder.post_rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.encoder.post_rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
|
||||
# joint
|
||||
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
|
||||
self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
|
||||
self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
|
||||
self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())
|
||||
# prediction
|
||||
self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
|
||||
for i in range(2):
|
||||
self.prediction.rnn.cells[i].weights_ih.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].weights_hh.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].bias_ih.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy()
|
||||
)
|
||||
self.prediction.rnn.cells[i].bias_hh.assign(
|
||||
state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy()
|
||||
)
|
||||
|
||||
# joint
|
||||
self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
|
||||
self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
|
||||
self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
|
||||
self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())
|
||||
|
||||
|
||||
class LSTMCell:
|
||||
def __init__(self, input_size, hidden_size, dropout):
|
||||
self.dropout = dropout
|
||||
def __init__(self, input_size, hidden_size, dropout):
|
||||
self.dropout = dropout
|
||||
|
||||
self.weights_ih = Tensor.uniform(hidden_size * 4, input_size)
|
||||
self.bias_ih = Tensor.uniform(hidden_size * 4)
|
||||
self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size)
|
||||
self.bias_hh = Tensor.uniform(hidden_size * 4)
|
||||
self.weights_ih = Tensor.uniform(hidden_size * 4, input_size)
|
||||
self.bias_ih = Tensor.uniform(hidden_size * 4)
|
||||
self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size)
|
||||
self.bias_hh = Tensor.uniform(hidden_size * 4)
|
||||
|
||||
def __call__(self, x, hc):
|
||||
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)
|
||||
def __call__(self, x, hc):
|
||||
gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[: x.shape[0]].linear(
|
||||
self.weights_hh.T, self.bias_hh
|
||||
)
|
||||
|
||||
i, f, g, o = gates.chunk(4, 1)
|
||||
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
||||
i, f, g, o = gates.chunk(4, 1)
|
||||
i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
|
||||
|
||||
c = (f * hc[x.shape[0]:]) + (i * g)
|
||||
h = (o * c.tanh()).dropout(self.dropout)
|
||||
c = (f * hc[x.shape[0] :]) + (i * g)
|
||||
h = (o * c.tanh()).dropout(self.dropout)
|
||||
|
||||
return Tensor.cat(h, c).realize()
|
||||
return Tensor.cat(h, c).realize()
|
||||
|
||||
|
||||
class LSTM:
|
||||
def __init__(self, input_size, hidden_size, layers, dropout):
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.layers = layers
|
||||
def __init__(self, input_size, hidden_size, layers, dropout):
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.layers = layers
|
||||
|
||||
self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]
|
||||
self.cells = [
|
||||
LSTMCell(input_size, hidden_size, dropout)
|
||||
if i == 0
|
||||
else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0)
|
||||
for i in range(layers)
|
||||
]
|
||||
|
||||
def __call__(self, x, hc):
|
||||
@TinyJit
|
||||
def _do_step(x_, hc_):
|
||||
return self.do_step(x_, hc_)
|
||||
def __call__(self, x, hc):
|
||||
@TinyJit
|
||||
def _do_step(x_, hc_):
|
||||
return self.do_step(x_, hc_)
|
||||
|
||||
if hc is None:
|
||||
hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)
|
||||
if hc is None:
|
||||
hc = Tensor.zeros(
|
||||
self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False
|
||||
)
|
||||
|
||||
output = None
|
||||
for t in range(x.shape[0]):
|
||||
hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this?
|
||||
if output is None:
|
||||
output = hc[-1:, :x.shape[1]]
|
||||
else:
|
||||
output = output.cat(hc[-1:, :x.shape[1]], dim=0).realize()
|
||||
output = None
|
||||
for t in range(x.shape[0]):
|
||||
hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this?
|
||||
if output is None:
|
||||
output = hc[-1:, : x.shape[1]]
|
||||
else:
|
||||
output = output.cat(hc[-1:, : x.shape[1]], dim=0).realize()
|
||||
|
||||
return output, hc
|
||||
return output, hc
|
||||
|
||||
def do_step(self, x, hc):
|
||||
new_hc = [x]
|
||||
for i, cell in enumerate(self.cells):
|
||||
new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
|
||||
return Tensor.stack(new_hc[1:]).realize()
|
||||
def do_step(self, x, hc):
|
||||
new_hc = [x]
|
||||
for i, cell in enumerate(self.cells):
|
||||
new_hc.append(cell(new_hc[i][: x.shape[0]], hc[i]))
|
||||
return Tensor.stack(new_hc[1:]).realize()
|
||||
|
||||
|
||||
class StackTime:
|
||||
def __init__(self, factor):
|
||||
self.factor = factor
|
||||
def __init__(self, factor):
|
||||
self.factor = factor
|
||||
|
||||
def __call__(self, x, x_lens):
|
||||
x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
|
||||
x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor)
|
||||
return x, x_lens / self.factor if x_lens is not None else None
|
||||
def __call__(self, x, x_lens):
|
||||
x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
|
||||
x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor)
|
||||
return x, x_lens / self.factor if x_lens is not None else None
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
|
||||
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
|
||||
self.stack_time = StackTime(stack_time_factor)
|
||||
self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
pre_layers,
|
||||
post_layers,
|
||||
stack_time_factor,
|
||||
dropout,
|
||||
):
|
||||
self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
|
||||
self.stack_time = StackTime(stack_time_factor)
|
||||
self.post_rnn = LSTM(
|
||||
stack_time_factor * hidden_size, hidden_size, post_layers, dropout
|
||||
)
|
||||
|
||||
def __call__(self, x, x_lens):
|
||||
x, _ = self.pre_rnn(x, None)
|
||||
x, x_lens = self.stack_time(x, x_lens)
|
||||
x, _ = self.post_rnn(x, None)
|
||||
return x.transpose(0, 1), x_lens
|
||||
def __call__(self, x, x_lens):
|
||||
x, _ = self.pre_rnn(x, None)
|
||||
x, x_lens = self.stack_time(x, x_lens)
|
||||
x, _ = self.post_rnn(x, None)
|
||||
return x.transpose(0, 1), x_lens
|
||||
|
||||
|
||||
class Prediction:
|
||||
def __init__(self, vocab_size, hidden_size, layers, dropout):
|
||||
self.hidden_size = hidden_size
|
||||
def __init__(self, vocab_size, hidden_size, layers, dropout):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.emb = Embedding(vocab_size - 1, hidden_size)
|
||||
self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)
|
||||
self.emb = Embedding(vocab_size - 1, hidden_size)
|
||||
self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)
|
||||
|
||||
def __call__(self, x, hc, m):
|
||||
emb = self.emb(x) * m
|
||||
x_, hc = self.rnn(emb.transpose(0, 1), hc)
|
||||
return x_.transpose(0, 1), hc
|
||||
def __call__(self, x, hc, m):
|
||||
emb = self.emb(x) * m
|
||||
x_, hc = self.rnn(emb.transpose(0, 1), hc)
|
||||
return x_.transpose(0, 1), hc
|
||||
|
||||
|
||||
class Joint:
|
||||
def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
|
||||
self.dropout = dropout
|
||||
def __init__(
|
||||
self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout
|
||||
):
|
||||
self.dropout = dropout
|
||||
|
||||
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
|
||||
self.l2 = Linear(joint_hidden_size, vocab_size)
|
||||
self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
|
||||
self.l2 = Linear(joint_hidden_size, vocab_size)
|
||||
|
||||
def __call__(self, f, g):
|
||||
(_, T, H), (B, U, H2) = f.shape, g.shape
|
||||
f = f.unsqueeze(2).expand(B, T, U, H)
|
||||
g = g.unsqueeze(1).expand(B, T, U, H2)
|
||||
def __call__(self, f, g):
|
||||
(_, T, H), (B, U, H2) = f.shape, g.shape
|
||||
f = f.unsqueeze(2).expand(B, T, U, H)
|
||||
g = g.unsqueeze(1).expand(B, T, U, H2)
|
||||
|
||||
inp = f.cat(g, dim=3)
|
||||
t = self.l1(inp).relu()
|
||||
t = t.dropout(self.dropout)
|
||||
return self.l2(t)
|
||||
inp = f.cat(g, dim=3)
|
||||
t = self.l1(inp).relu()
|
||||
t = t.dropout(self.dropout)
|
||||
return self.l2(t)
|
||||
|
|
|
@ -1,64 +1,104 @@
|
|||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
|
||||
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
ff_dim,
|
||||
prenorm=False,
|
||||
act=lambda x: x.relu(),
|
||||
dropout=0.1,
|
||||
):
|
||||
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.head_size = embed_dim // num_heads
|
||||
self.prenorm, self.act = prenorm, act
|
||||
self.dropout = dropout
|
||||
self.num_heads = num_heads
|
||||
self.head_size = embed_dim // num_heads
|
||||
self.prenorm, self.act = prenorm, act
|
||||
self.dropout = dropout
|
||||
|
||||
self.query = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.key = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.value = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.query = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
self.key = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
self.value = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
|
||||
self.out = (Tensor.scaled_uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.out = (
|
||||
Tensor.scaled_uniform(embed_dim, embed_dim),
|
||||
Tensor.zeros(embed_dim),
|
||||
)
|
||||
|
||||
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
|
||||
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
self.ff1 = (Tensor.scaled_uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
|
||||
self.ff2 = (Tensor.scaled_uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))
|
||||
|
||||
self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
|
||||
self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
|
||||
self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
|
||||
self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
|
||||
|
||||
def attn(self, x):
|
||||
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
|
||||
query, key, value = [x.linear(*y).reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)).transpose(1,2) for y in [self.query, self.key, self.value]]
|
||||
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(1,2)
|
||||
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
|
||||
def attn(self, x):
|
||||
# x: (bs, time, embed_dim) -> (bs, time, embed_dim)
|
||||
query, key, value = [
|
||||
x.linear(*y)
|
||||
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size))
|
||||
.transpose(1, 2)
|
||||
for y in [self.query, self.key, self.value]
|
||||
]
|
||||
attention = Tensor.scaled_dot_product_attention(query, key, value).transpose(
|
||||
1, 2
|
||||
)
|
||||
return attention.reshape(
|
||||
shape=(x.shape[0], -1, self.num_heads * self.head_size)
|
||||
).linear(*self.out)
|
||||
|
||||
def __call__(self, x):
|
||||
if self.prenorm:
|
||||
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
|
||||
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(
|
||||
*self.ff2
|
||||
).dropout(self.dropout)
|
||||
else:
|
||||
x = x + self.attn(x).dropout(self.dropout)
|
||||
x = x.layernorm().linear(*self.ln1)
|
||||
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(
|
||||
self.dropout
|
||||
)
|
||||
x = x.layernorm().linear(*self.ln2)
|
||||
return x
|
||||
|
||||
def __call__(self, x):
|
||||
if self.prenorm:
|
||||
x = x + self.attn(x.layernorm().linear(*self.ln1)).dropout(self.dropout)
|
||||
x = x + self.act(x.layernorm().linear(*self.ln2).linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
||||
else:
|
||||
x = x + self.attn(x).dropout(self.dropout)
|
||||
x = x.layernorm().linear(*self.ln1)
|
||||
x = x + self.act(x.linear(*self.ff1)).linear(*self.ff2).dropout(self.dropout)
|
||||
x = x.layernorm().linear(*self.ln2)
|
||||
return x
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
||||
self.maxlen, self.syms = maxlen, syms
|
||||
self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False)
|
||||
self.tbs = []
|
||||
for i in range(layers):
|
||||
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
|
||||
self.final = Tensor.scaled_uniform(embed_dim, syms)
|
||||
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
||||
self.maxlen, self.syms = maxlen, syms
|
||||
self.embed = Tensor.scaled_uniform(
|
||||
maxlen + syms, embed_dim, requires_grad=False
|
||||
)
|
||||
self.tbs = []
|
||||
for i in range(layers):
|
||||
self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
|
||||
self.final = Tensor.scaled_uniform(embed_dim, syms)
|
||||
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.numpy().astype(np.int32)
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen+self.syms), dtype=np.float32)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
||||
onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms)
|
||||
|
||||
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
|
||||
x = x.sequential(self.tbs)
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
def forward(self, x):
|
||||
bs = x.shape[0]
|
||||
xnp = x.numpy().astype(np.int32)
|
||||
onehot = np.zeros((bs, x.shape[1], self.maxlen + self.syms), dtype=np.float32)
|
||||
for i in range(x.shape[1]):
|
||||
onehot[range(bs), i, i] = 1
|
||||
onehot[range(bs), i, self.maxlen + xnp[:, i]] = 1
|
||||
onehot = onehot.reshape(bs * x.shape[1], self.maxlen + self.syms)
|
||||
|
||||
x = (
|
||||
Tensor(onehot, device=x.device)
|
||||
.dot(self.embed)
|
||||
.reshape(shape=(bs, x.shape[1], -1))
|
||||
)
|
||||
x = x.sequential(self.tbs)
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).log_softmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue