Revert "revert t as tensor, constant folding should be done better"
This reverts commit 1d800a94ad
.
pull/554/head
parent
1d800a94ad
commit
c0ea538ba0
|
@ -55,7 +55,7 @@ class RMSprop(Optimizer):
|
|||
class Adam(Optimizer):
|
||||
def __init__(self, params : List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
|
||||
super().__init__(params)
|
||||
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0
|
||||
self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0], requires_grad=False).realize()
|
||||
|
||||
self.m = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
|
||||
self.v = [Tensor.zeros(*t.shape, device=params[0].device, requires_grad=False) for t in self.params]
|
||||
|
@ -68,7 +68,7 @@ class Adam(Optimizer):
|
|||
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
|
||||
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
|
||||
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
|
||||
self.realize(self.m + self.v)
|
||||
self.realize([self.t] + self.m + self.v)
|
||||
|
||||
def get_parameters(obj) -> List[Tensor]:
|
||||
parameters : List[Tensor] = []
|
||||
|
|
Loading…
Reference in New Issue