1
0
Fork 0

layernorm fixes transformer instability

pull/221/head
George Hotz 2020-12-28 12:58:15 -05:00
parent 628d21f899
commit 2e89e75dcb
2 changed files with 16 additions and 5 deletions

View File

@ -23,6 +23,12 @@ def make_dataset():
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
def layernorm(x, eps=1e-5):
layer_mean = x.mean(axis=(0,1))
y = (x - layer_mean.reshape(shape=[1, 1, -1]))
layer_var = (y*y).mean(axis=(0,1))
return y.div(layer_var.add(eps).reshape(shape=[1, 1, -1]))
class TransformerBlock:
def __init__(self, embed_dim, num_heads):
# Multi-Head Attention
@ -55,12 +61,16 @@ class TransformerBlock:
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3))
weights = score.softmax() # (bs, num_heads, T, T)
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
x = inputs + attention.reshape(shape=(-1, self.num_heads * self.head_size)).dot(self.final)
# layernorm
x = x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
x = layernorm(x)
x = x.reshape(shape=(-1, self.num_heads * self.head_size))
x = x + x.dot(self.ff1).relu().dot(self.ff2)
# layernorm
x = x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
x = layernorm(x)
x = x.reshape(shape=(-1, self.num_heads * self.head_size))
return x.reshape(shape=(bs, -1, self.num_heads * self.head_size))
class Transformer:
@ -93,7 +103,7 @@ if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = make_dataset()
optim = Adam(get_parameters(model), lr=0.001)
train(model, X_train, Y_train, optim, 500)
train(model, X_train, Y_train, optim, 500, BS=16)
evaluate(model, X_test, Y_test, num_classes=10)

View File

@ -28,3 +28,4 @@ class BatchNorm2D:
def normalize(self, x, mean, var):
x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])