1
0
Fork 0

clean up vit code

pull/295/head
George Hotz 2021-11-30 10:58:03 -05:00
parent 6884add850
commit 835869974c
2 changed files with 7 additions and 10 deletions

View File

@@ -23,8 +23,8 @@ dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg
 #for x in dat.keys():
 # print(x, dat[x].shape, dat[x].dtype)
-m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
-m.conv_bias.assign(dat['embedding/bias'])
+m.conv[0].assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
+m.conv[1].assign(dat['embedding/bias'])
 m.norm[0].assign(dat['Transformer/encoder_norm/scale'])
 m.norm[1].assign(dat['Transformer/encoder_norm/bias'])

View File

@@ -51,7 +51,6 @@ class TransformerBlock:
return x
class Transformer:
# L = layers, H = embed_dim, A = num_heads
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
self.maxlen, self.syms = maxlen, syms
self.embed = Tensor.uniform(maxlen+syms, embed_dim, requires_grad=False)
@@ -75,18 +74,16 @@ class Transformer:
 return x.reshape(shape=(bs, -1, x.shape[-1]))
 class ViT:
-def __init__(self, embed_dim=192):
-self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
-self.conv_bias = Tensor.zeros(embed_dim)
+def __init__(self, layers=12, embed_dim=192, num_heads=3):
+self.conv = (Tensor.uniform(embed_dim, 3, 16, 16), Tensor.zeros(embed_dim))
 self.cls_token = Tensor.ones(1, 1, embed_dim)
-self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768, prenorm=True) for i in range(12)]
 self.pos_embed = Tensor.ones(1, 197, embed_dim)
-self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
+self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, ff_dim=embed_dim*4, prenorm=True) for i in range(layers)]
 self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))
+self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
 def patch_embed(self, x):
-x = x.conv2d(self.conv_weight, stride=16)
-x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
+x = x.conv2d(*self.conv, stride=16)
 x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
 return x