1
0
Fork 0

Revert "revert t as tensor, constant folding should be done better"

This reverts commit 1d800a94ad.
pull/554/head
George Hotz 2023-02-10 23:06:00 -06:00
parent 1d800a94ad
commit c0ea538ba0
1 changed files with 2 additions and 2 deletions

View File

@ -55,7 +55,7 @@ class RMSprop(Optimizer):
class Adam(Optimizer):
def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
super().__init__(params)
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
self.m = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
@ -68,7 +68,7 @@ class Adam(Optimizer):
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
self.realize(self.m + self.v)
self.realize([self.t] + self.m + self.v)
def get_parameters(obj) -> List[Tensor]:
parameters : List[Tensor] = []