fix type of helpers.prod, add test cases (#1859)
parent
e67306ba04
commit
1b46de1a3e
|
@ -1,6 +1,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens
|
||||
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
|
||||
VARIABLE = ContextVar("VARIABLE", 0)
|
||||
|
||||
|
@ -130,5 +131,12 @@ class TestStripParens(unittest.TestCase):
|
|||
def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
|
||||
def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))
|
||||
|
||||
class TestProd(unittest.TestCase):
|
||||
def test_empty(self): self.assertEqual(1, prod(tuple()))
|
||||
def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
|
||||
def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
|
||||
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
|
||||
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,11 +1,11 @@
|
|||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar
|
||||
|
||||
# TODO: fix types for prod
|
||||
#from math import prod # noqa: F401 # pylint:disable=unused-import
|
||||
def prod(x:Iterable): return functools.reduce(operator.__mul__, x, 1)
|
||||
T = TypeVar("T")
|
||||
# NOTE: it returns int 1 if x is empty regardless of the type of x
|
||||
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
|
||||
|
||||
# NOTE: helpers is not allowed to import from anything else in tinygrad
|
||||
OSX = platform.system() == "Darwin"
|
||||
|
|
|
@ -357,6 +357,7 @@ def _realize_from(buffer: LazyBuffer) -> None:
|
|||
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
|
||||
# TODO: make this generic
|
||||
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
|
||||
elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
|
||||
|
@ -365,6 +366,7 @@ def _realize_from(buffer: LazyBuffer) -> None:
|
|||
buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())
|
||||
|
||||
def _realize_empty(buffer: LazyBuffer) -> None:
|
||||
assert all_int(buffer.shape), "does not support symbolic shape"
|
||||
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
|
||||
|
||||
def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
|
|
|
@ -2,6 +2,7 @@ import math
|
|||
from typing import Optional, Union, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.symbolic import all_int
|
||||
|
||||
class BatchNorm2d:
|
||||
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
|
||||
|
@ -44,6 +45,7 @@ class Conv2d:
|
|||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
assert all_int(self.weight.shape), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
|
||||
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
|
||||
|
||||
|
@ -58,6 +60,7 @@ class ConvTranspose2d:
|
|||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
|
||||
self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
|
||||
assert all_int(self.weight.shape), "does not support symbolic shape"
|
||||
bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
|
||||
self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
|
||||
|
||||
|
|
|
@ -416,9 +416,11 @@ class Tensor:
|
|||
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
|
||||
|
||||
def mean(self, axis=None, keepdim=False):
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
out = self.sum(axis=axis, keepdim=keepdim)
|
||||
return out * (prod(out.shape)/prod(self.shape))
|
||||
def std(self, axis=None, keepdim=False, correction=1):
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
||||
return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
|
||||
def _softmax(self, axis):
|
||||
|
@ -730,7 +732,7 @@ class Tensor:
|
|||
|
||||
@property
|
||||
def ndim(self) -> int: return len(self.shape)
|
||||
def numel(self) -> int: return prod(self.shape)
|
||||
def numel(self) -> sint: return prod(self.shape)
|
||||
def element_size(self) -> int: return self.dtype.itemsize
|
||||
def nbytes(self) -> int: return self.numel() * self.element_size()
|
||||
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
||||
|
|
Loading…
Reference in New Issue