This reverts commit ab645317c9
.
pull/1285/head
parent
0aed3f73da
commit
940b6fd21a
|
@ -120,8 +120,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.ops import LazyOp, BinaryOps, LoadOps
|
||||
|
||||
# the 2+3 from before
|
||||
# added some 0s, otherwise Tensor([2]) will be folded into a constant without using LoadOps.FROM
|
||||
result = Tensor([2, 0]) + Tensor([3, 0])
|
||||
result = Tensor([2]) + Tensor([3])
|
||||
print(type(result.lazydata), result.lazydata) # let's look at the lazydata of result
|
||||
|
||||
# you'll see it has a LazyOp
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import dataclasses
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps, OpType
|
||||
import itertools
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.helpers import dtypes
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
|
||||
|
@ -191,26 +192,5 @@ class TestTinygrad(unittest.TestCase):
|
|||
for _, dtype in dtypes.fields().items():
|
||||
assert dtype.itemsize == Tensor.randn(3, dtype=dtype).element_size(), f"Tensor.element_size() not matching Tensor.dtype.itemsize for {dtype}"
|
||||
|
||||
def test_constant_fold(self):
|
||||
def helper_assert_all_const(op: OpType):
|
||||
if isinstance(op.op, LoadOps): assert op.op == LoadOps.CONST
|
||||
else:
|
||||
for buf in op.buffers: helper_assert_all_const(buf.op)
|
||||
helper_assert_all_const(Tensor(2).lazydata.op)
|
||||
helper_assert_all_const(Tensor(2).reshape([1, 1, 1]).lazydata.op)
|
||||
helper_assert_all_const(Tensor([2]).lazydata.op)
|
||||
helper_assert_all_const(Tensor([2]).reshape([1, 1, 1]).lazydata.op)
|
||||
helper_assert_all_const((Tensor(2)+Tensor(3)).lazydata.op)
|
||||
helper_assert_all_const((Tensor(2)+Tensor([3])).lazydata.op)
|
||||
helper_assert_all_const((Tensor([[2]])+Tensor([3])).lazydata.op)
|
||||
with self.assertRaises(AssertionError):
|
||||
helper_assert_all_const((Tensor([2, 0])+Tensor([3, 0])).lazydata.op)
|
||||
|
||||
def test_constant_fold_shape(self):
|
||||
self.assertEqual(Tensor(3).shape, ())
|
||||
self.assertEqual(Tensor([3]).shape, (1,))
|
||||
self.assertEqual(Tensor([[3]]).shape, (1, 1))
|
||||
self.assertEqual(Tensor([[[3]]]).shape, (1, 1, 1))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -67,11 +67,8 @@ class Tensor:
|
|||
|
||||
if data.__class__ is np.ndarray:
|
||||
data = cast(np.ndarray, data)
|
||||
if data.size == 1: # constant fold
|
||||
self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtypes.from_np(data.dtype), device, data.flat[0]).reshape(data.shape)
|
||||
else:
|
||||
data = LazyBuffer.fromCPU(data)
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
data = LazyBuffer.fromCPU(data)
|
||||
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
|
||||
return
|
||||
|
||||
raise RuntimeError(f"can't create Tensor from {data}")
|
||||
|
|
Loading…
Reference in New Issue