Fork 0

brown img

George Hotz 2022-09-05 15:20:18 -07:00
parent 98d6264987
commit 5a685b93ac
1 changed files with 85 additions and 11 deletions

View File

@ -206,6 +206,7 @@ class CrossAttention:
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
self.scale = d_head ** -0.5
self.num_heads = n_heads
self.head_size = d_head
self.to_out = [Linear(n_heads*d_head, query_dim)]
@ -215,12 +216,14 @@ class CrossAttention:
context = x if context is None else context
q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time)
v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
score = q.dot(k) * (1 / np.sqrt(self.head_size))
score = q.dot(k) * self.scale
#print("score", score.shape, score.numpy())
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size)
@ -260,13 +263,13 @@ class BasicTransformerBlock:
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer:
def __init__(self, channels, context_dim, n_heads, d_head):
self.norm = Normalize(channels)
assert channels == n_heads * d_head
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
@ -280,6 +283,7 @@ class SpatialTransformer:
x = x.reshape(b, c, h*w).permute(0,2,1)
for block in self.transformer_blocks:
x = block(x, context=context)
#print(x.shape, x.numpy())
x = x.permute(0,2,1).reshape(b, c, h, w)
ret = self.proj_out(x) + x_in
@ -366,15 +370,18 @@ class UNetModel:
else: x = bb(x)
return x
saved_inputs = []
for i,b in enumerate(self.input_blocks):
print("input block", i)
for bb in b:
x = run(x, bb)
#if i == 3:
# print(x.numpy())
# return None
#if i == 1:
for bb in self.middle_block:
x = run(x, bb)
for i,b in enumerate(self.output_blocks):
@ -500,19 +507,19 @@ class CLIPTextTransformer:
def __call__(self, input_ids):
  """Embed the token ids, run the encoder under a causal mask, and apply the final layer norm.

  Args:
    input_ids: sequence of token ids; its length determines both the position
      ids (0 .. len-1) and the size of the causal mask.

  Returns:
    The output of ``self.final_layer_norm`` applied to the encoder output.
  """
  seq_len = len(input_ids)
  x = self.embeddings(input_ids, list(range(seq_len)))
  # -inf strictly above the diagonal, 0 on and below it. Sized from the input
  # instead of a fixed 77 so the mask always matches the position ids built above.
  # NOTE(review): presumably the encoder adds this mask to the attention scores -- confirm.
  causal_attention_mask = np.triu(np.full((1, 1, seq_len, seq_len), -np.inf, dtype=np.float32), k=1)
  x = self.encoder(x, Tensor(causal_attention_mask, device=x.device))
  return self.final_layer_norm(x)
class StableDiffusion:
def __init__(self):
#self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
#self.first_stage_model = AutoencoderKL()
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
def __call__(self, x, timesteps, context):
return self.model.diffusion_model(x, timesteps, context)
#def __call__(self, x, timesteps, context):
#return self.model.diffusion_model(x, timesteps, context)
#return self.first_stage_model(x)
# ** ldm.models.autoencoder.AutoencoderKL (done!)
@ -556,7 +563,74 @@ if __name__ == "__main__":
# "a horse sized cat eating a bagel"
# run through CLIP to get context
phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
context = model.cond_stage_model.transformer.text_model(phrase)
print("got CLIP context", context.shape)
phrase = [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
unconditional_context = model.cond_stage_model.transformer.text_model(phrase)
print("got unconditional CLIP context", unconditional_context.shape)
def get_model_output(latent, unconditional_guidance_scale=7.5):
  """Run the diffusion model with and without the prompt context and blend the two outputs.

  Evaluates ``model.model.diffusion_model`` twice on the same ``latent``: once with
  ``unconditional_context`` and once with ``context`` (both read from the enclosing
  scope), then returns ``uncond + unconditional_guidance_scale * (cond - uncond)``.

  Args:
    latent: input handed unchanged to both model evaluations.
    unconditional_guidance_scale: weight applied to the ``(cond - uncond)`` difference.
      Defaults to 7.5, the value that used to be hard-coded in the body.

  Returns:
    The blended model output ``e_t``.
  """
  # NOTE(review): the timestep is fixed to 1 for every call -- confirm this is intended.
  timesteps = Tensor([1])
  unconditional_output = model.model.diffusion_model(latent, timesteps, unconditional_context)
  # kept in its own name so the ``latent`` argument is not overwritten by a model output
  conditional_output = model.model.diffusion_model(latent, timesteps, context)
  return unconditional_output + unconditional_guidance_scale * (conditional_output - unconditional_output)
def get_x_prev_and_pred_x0(x, e_t, index):
  """Compute the previous-step sample and the predicted x0 from ``x`` and the model output ``e_t``.

  Args:
    x: current sample; anything supporting scalar multiply, add and subtract.
    e_t: model output for ``x``, combinable with ``x`` by ``+`` / ``-``.
    index: step index. Accepted for interface compatibility; the schedule
      constants below are hard-coded and do not depend on it.

  Returns:
    ``(x_prev, pred_x0)`` where
    ``pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt(a_t)`` and
    ``x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t``.
  """
  # NOTE(review): these look like schedule values for one specific step -- confirm
  # where they come from before using this for more than a single step.
  a_t, a_prev, sigma_t, sqrt_one_minus_at = 0.9983, 0.9991, 0., 0.0413
  pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
  # direction pointing to x_t
  dir_xt = math.sqrt(1. - a_prev - sigma_t**2) * e_t
  # The stochastic term (sigma_t * randn * temperature) is intentionally not built:
  # it was never added to x_prev, and sigma_t is 0 here, so it only cost a random draw.
  x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
  return x_prev, pred_x0
# start with random noise
latent = Tensor.randn(1,4,64,64)
# run a single denoising step: one guided model evaluation, then one x_prev / pred_x0 update
index = 0
e_t = get_model_output(latent)
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
#e_t_next = get_model_output(x_prev)
#e_t_prime = (e_t + e_t_next) / 2
#x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
latent = x_prev
# sanity check
#latent = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
# upsample latent space to image with autoencoder
x = model.first_stage_model.post_quant_conv(latent)
x = model.first_stage_model.decoder(x)
# make image correct size and scale
x = (x + 1.0) / 2.0
x = x.reshape(3,512,512).permute(1,2,0)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
# save image
from PIL import Image
im = Image.fromarray(dat)
# load apple latent space