shape has to be a kwarg now, idk why this didn't break before
parent
2db670ef26
commit
9ae3e9daf3
|
@ -40,7 +40,7 @@ class TinyConvNet:
|
|||
x.data = x.data.reshape((-1, 1, 28, 28)) # hacks
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1).logsoftmax()
|
||||
|
||||
def train(model, optim, steps, BS=128):
|
||||
|
|
|
@ -127,7 +127,7 @@ class Reshape(Function):
|
|||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return grad_output.reshape(in_shape), None
|
||||
return grad_output.reshape(in_shape)
|
||||
register('reshape', Reshape)
|
||||
|
||||
class LogSoftmax(Function):
|
||||
|
|
Loading…
Reference in New Issue