1
0
Fork 0

affine is always the last dim

pull/295/head
George Hotz 2021-11-29 15:22:49 -05:00
parent e86f7a4aa3
commit ca160504e1
2 changed files with 11 additions and 6 deletions

View File

@ -12,7 +12,7 @@ import io
from extra.utils import fetch
from tinygrad.tensor import Tensor
from models.transformer import TransformerBlock
from models.transformer import TransformerBlock, layernorm
class ViT:
def __init__(self):
self.conv_weight = Tensor.uniform(192, 3, 16, 16)
@ -37,7 +37,8 @@ class ViT:
for l in self.tbs:
x = l(x)
#print(x.sum())
x = x.affine(self.norm)
print(x.shape)
x = layernorm(x, 192).affine(self.norm)
return x[:, 0].affine(self.head)
m = ViT()
@ -77,8 +78,8 @@ for i in range(12):
m.tbs[i].ln2[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/scale'])
m.tbs[i].ln2[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/bias'])
url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
#url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"
# category labels
import ast
@ -96,6 +97,8 @@ img = img[y0:y0+224, x0:x0+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
Tensor.training = False
out = m.forward(Tensor(img))

View File

@ -235,6 +235,7 @@ class Tensor:
def gelu(x):
# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
#import torch; return Tensor(torch.nn.functional.gelu(torch.tensor(x.data)).numpy())
return 0.5 * x * (1 + (x * 0.7978845608 * (1 + 0.044715 * x * x)).tanh())
def leakyrelu(self, neg_slope=0.01):
@ -286,10 +287,11 @@ class Tensor:
return self._pool2d(*kernel_size).max(axis=(3,5))
def affine(self, params):
shp = [1] * (len(self.shape)-1) + [-1]
if len(params[0].shape) == 1: # elementwise affine
return self.mul(params[0].reshape(shape=[1, -1])).add(params[1].reshape(shape=[1, -1]))
return self.mul(params[0].reshape(shape=shp)).add(params[1].reshape(shape=shp))
else:
return self.dot(params[0]).add(params[1].reshape(shape=[1, -1]))
return self.dot(params[0]).add(params[1].reshape(shape=shp))
# An instantiation of the Function is the Context
class Function: