clean up vit code
parent
6884add850
commit
835869974c
|
@ -23,8 +23,8 @@ dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg
|
|||
#for x in dat.keys():
|
||||
# print(x, dat[x].shape, dat[x].dtype)
|
||||
|
||||
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
|
||||
m.conv_bias.assign(dat['embedding/bias'])
|
||||
m.conv[0].assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
|
||||
m.conv[1].assign(dat['embedding/bias'])
|
||||
|
||||
m.norm[0].assign(dat['Transformer/encoder_norm/scale'])
|
||||
m.norm[1].assign(dat['Transformer/encoder_norm/bias'])
|
||||
|
|
|
@ -51,7 +51,6 @@ class TransformerBlock:
|
|||
return x
|
||||
|
||||
class Transformer:
|
||||
# L = layers, H = embed_dim, A = num_heads
|
||||
def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
|
||||
self.maxlen, self.syms = maxlen, syms
|
||||
self.embed = Tensor.uniform(maxlen+syms, embed_dim, requires_grad=False)
|
||||
|
@ -75,18 +74,16 @@ class Transformer:
|
|||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
||||
class ViT:
|
||||
def __init__(self, embed_dim=192):
|
||||
self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
|
||||
self.conv_bias = Tensor.zeros(embed_dim)
|
||||
def __init__(self, layers=12, embed_dim=192, num_heads=3):
|
||||
self.conv = (Tensor.uniform(embed_dim, 3, 16, 16), Tensor.zeros(embed_dim))
|
||||
self.cls_token = Tensor.ones(1, 1, embed_dim)
|
||||
self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768, prenorm=True) for i in range(12)]
|
||||
self.pos_embed = Tensor.ones(1, 197, embed_dim)
|
||||
self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
|
||||
self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, ff_dim=embed_dim*4, prenorm=True) for i in range(layers)]
|
||||
self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))
|
||||
self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
|
||||
|
||||
def patch_embed(self, x):
|
||||
x = x.conv2d(self.conv_weight, stride=16)
|
||||
x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
|
||||
x = x.conv2d(*self.conv, stride=16)
|
||||
x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue