
fix type of helpers.prod, add test cases (#1859)

pull/1866/head
chenyu 2023-09-13 14:16:55 -07:00 committed by GitHub
parent e67306ba04
commit 1b46de1a3e
5 changed files with 21 additions and 6 deletions

test/test_helpers.py

@@ -1,6 +1,7 @@
import unittest
import numpy as np
-from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens
+from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
+from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)
@@ -130,5 +131,12 @@ class TestStripParens(unittest.TestCase):
  def test_nested(self): self.assertEqual("1+(2+3)", strip_parens("(1+(2+3))"))
  def test_casted_no_strip(self): self.assertEqual("(int)(1+2)", strip_parens("(int)(1+2)"))

+class TestProd(unittest.TestCase):
+  def test_empty(self): self.assertEqual(1, prod(tuple()))
+  def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
+  def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
+  def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
+  def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
+
if __name__ == '__main__':
  unittest.main()
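What these tests pin down: prod folds plain ints as it multiplies, and an int times a Variable falls through to the symbolic node's reflected multiply, so both argument orders collapse to the same node and render as "(a*12)". A minimal sketch of that behavior, assuming a tinygrad checkout at this commit:

from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable

a = Variable("a", 1, 5)          # symbolic dimension bounded to [1, 5]
print(prod((2, 3, 5)))           # 30 -- all-int input stays a plain int
print(prod((a, 3, 4)).render())  # (a*12) -- int factors fold into one constant
print(prod((3, 4, a)).render())  # (a*12) -- same node regardless of order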

tinygrad/helpers.py

@@ -1,11 +1,11 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator
import numpy as np
-from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any
+from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar
-# TODO: fix types for prod
-#from math import prod # noqa: F401 # pylint:disable=unused-import
-def prod(x:Iterable): return functools.reduce(operator.__mul__, x, 1)
+T = TypeVar("T")
+# NOTE: it returns int 1 if x is empty regardless of the type of x
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = platform.system() == "Darwin"
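Why Union[T,int] rather than plain T: functools.reduce starts from the int 1, so an empty iterable returns that int unchanged (the NOTE above), while a non-empty one returns whatever the element type multiplies to: int for concrete shapes, a symbolic Node for shapes containing Variables. A quick sketch, assuming this checkout:

from tinygrad.helpers import prod

print(prod((2, 3, 5)))   # 30 -- T=int flows straight through
print(prod((2.0, 3.5)))  # 7.0 -- T=float, hence the Union[T, int] return type
print(prod(tuple()))     # 1 -- empty input yields the int start value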

tinygrad/realize.py

@@ -357,6 +357,7 @@ def _realize_from(buffer: LazyBuffer) -> None:
  if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
  # TODO: make this generic
  if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
+    assert all_int(buffer.shape), "does not support symbolic shape"
    buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
    rawbuf.realized.readinto(cast(RawBufferMapped, buffer.realized)._buffer())
  elif isinstance(rawbuf.realized, RawBufferTransfer) and issubclass(Device[buffer.device].buffer, RawBufferTransfer) and P2P >= 1:
@@ -365,6 +366,7 @@ def _realize_from(buffer: LazyBuffer) -> None:
    buffer.realized = Device[buffer.device].buffer.fromCPU(rawbuf.toCPU(), **buffer._device_extra_args())

def _realize_empty(buffer: LazyBuffer) -> None:
+  assert all_int(buffer.shape), "does not support symbolic shape"
  buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())

def _realize_rand(buffer: LazyBuffer) -> None:
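The asserts guard the allocation path: with a symbolic shape, prod(buffer.shape) is a Node rather than an int, and a raw buffer cannot be sized with it. all_int comes from tinygrad.shape.symbolic; a minimal standalone sketch of the guard, with a hypothetical stand-in for the symbolic type:

from typing import Any, Tuple

def all_int(t: Tuple[Any, ...]) -> bool:
  # sketch of the real helper: True only when every dim is a concrete int,
  # so prod(t) is guaranteed to be an int a buffer can be allocated with
  return all(isinstance(s, int) for s in t)

class FakeNode: pass                  # hypothetical stand-in for a symbolic dim

assert all_int((4, 16))               # concrete shape: safe to allocate
assert not all_int((FakeNode(), 16))  # symbolic dim: must not reach prod()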

tinygrad/nn/__init__.py

@@ -2,6 +2,7 @@ import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
+from tinygrad.shape.symbolic import all_int

class BatchNorm2d:
  def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
@@ -44,6 +45,7 @@ class Conv2d:
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
    self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    assert all_int(self.weight.shape), "does not support symbolic shape"
    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
@@ -58,6 +60,7 @@ class ConvTranspose2d:
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
    self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5))
+    assert all_int(self.weight.shape), "does not support symbolic shape"
    bound = 1 / math.sqrt(prod(self.weight.shape[1:]))
    self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None
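Conv weights are built from int constructor arguments, so these asserts should never fire; they narrow the type now that prod can return a symbolic node. The bound itself is the usual fan-in rule: prod(weight.shape[1:]) is in_channels//groups * kh * kw. A worked example with hypothetical sizes:

import math

weight_shape = (32, 16, 3, 3)         # hypothetical Conv2d: (out, in//groups, kh, kw)
fan_in = math.prod(weight_shape[1:])  # 16*3*3 = 144 inputs feeding each output
bound = 1 / math.sqrt(fan_in)         # 1/12; bias is drawn from U(-bound, bound)
print(fan_in, round(bound, 4))        # 144 0.0833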

tinygrad/tensor.py

@@ -416,9 +416,11 @@ class Tensor:
  def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

  def mean(self, axis=None, keepdim=False):
+    assert all_int(self.shape), "does not support symbolic shape"
    out = self.sum(axis=axis, keepdim=keepdim)
    return out * (prod(out.shape)/prod(self.shape))
  def std(self, axis=None, keepdim=False, correction=1):
+    assert all_int(self.shape), "does not support symbolic shape"
    square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
  def _softmax(self, axis):
@@ -730,7 +732,7 @@ class Tensor:
  @property
  def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return prod(self.shape)
+  def numel(self) -> sint: return prod(self.shape)
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
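mean() needs a concrete element count: the factor prod(out.shape)/prod(self.shape) is 1/N for the N elements reduced away, which is only a plain Python fraction when both shapes are all ints, and std() divides by the same count minus correction. numel(), by contrast, may stay symbolic, which the int -> sint return change records (sint being tinygrad's Node-or-int alias). A small sketch of the mean scaling, assuming this checkout:

from tinygrad.tensor import Tensor

t = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3)
# reducing axis=1: out.shape == (2,), so the factor is 2/6 == 1/3 == 1/N
print(t.mean(axis=1).numpy())                   # [2. 5.]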