1
0
Fork 0

allow zerosized tensors (#1659)

* allow zerosized tensors

* works with numpy
pull/1718/head
nimlgen 2023-08-30 20:39:24 +03:00 committed by GitHub
parent f9cb31fdc2
commit 355b02dc3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 8 deletions

View File

@ -53,10 +53,11 @@ def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._
class TestAllocators(unittest.TestCase):
def test_lru_allocator_reusage(self):
mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
def test():
lru_allocator = FakeAllocator(2048)
traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
assert GlobalCounters.mem_cached - mc == 16*dtypes.float32.itemsize, "Buffer should be cached"
for _ in range(32):
def __test():
buf = alloc(lru_allocator, 16, dtypes.float32)
@ -69,19 +70,20 @@ class TestAllocators(unittest.TestCase):
buf = alloc(lru_allocator, 16, dtypes.float32)
assert usedbuf != buf, "Nobody should get used buffer"
__test()
assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
assert GlobalCounters.mem_used - mu == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
test()
check_gc()
def test_lru_allocator_cache_free(self):
mc, mu = GlobalCounters.mem_cached, GlobalCounters.mem_used
def test():
lru_allocator = FakeAllocator(128)
refs = []
for _ in range(32):
refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
for sz in range(32):
for sz in range(1, 32):
alloc_free_trace(lru_allocator, sz, dtypes.float32)
assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
assert GlobalCounters.mem_used + GlobalCounters.mem_cached - mc - mu <= 128, "Should not allocate on device more than allowed (128)"
for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
test()
check_gc()

View File

@ -220,5 +220,9 @@ class TestTinygrad(unittest.TestCase):
x = Tensor.randn(1, 1, 1)
x.dot(layer).mean().backward()
def test_zerosized_tensors(self):
Tensor([]).realize()
Tensor([]).numpy()
if __name__ == '__main__':
unittest.main()

View File

@ -42,13 +42,13 @@ class RawBufferCopyIn(RawBuffer):
@classmethod
def fromCPU(cls, x:np.ndarray, **kwargs):
ret = cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
ret._copyin(x)
if x.size > 0: ret._copyin(x)
return ret
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
def toCPU(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.toCPU(), x.reshape(-1))
# this one is simple enough that i moved it out of the runtimes
@ -61,7 +61,7 @@ class RawBufferCopyInOut(RawBufferCopyIn):
def toCPU(self) -> np.ndarray:
x: np.ndarray = np.empty(self.size, dtype=self.dtype.np)
self._copyout(x)
if x.size > 0: self._copyout(x)
return x
class RawBufferTransfer(RawBuffer):
@ -91,7 +91,7 @@ class LRUAllocator:
while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM removing lru buffers.
bucket, epoch = self.aging_order[device].popleft()
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
newbuf = self._do_alloc(size, dtype, device, **kwargs)
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
self.buffer_info[newbuf] = (size, dtype, device)
return newbuf
def _free_buffer(self, buf_to_free):