brown img
parent 98d6264987, commit 5a685b93ac
@@ -206,6 +206,7 @@ class CrossAttention:
    self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
    self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
    self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
    self.scale = d_head ** -0.5
    self.num_heads = n_heads
    self.head_size = d_head
    self.to_out = [Linear(n_heads*d_head, query_dim)]
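    # shape sketch (illustrative numbers only, e.g. SD's first attention block:
    # query_dim=320, n_heads=8, d_head=40): to_q maps (bs, tokens, 320) to
    # (bs, tokens, 320), which __call__ below splits into 8 heads of size 40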
@@ -215,12 +216,14 @@ class CrossAttention:
    context = x if context is None else context
    #print(x.numpy())
    q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
    #print(q.numpy())
    q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
    k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time)
    v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)

    score = q.dot(k) * (1 / np.sqrt(self.head_size))
    score = q.dot(k) * self.scale
    #print("score", score.shape, score.numpy())
    #exit(0)

    weights = score.softmax() # (bs, num_heads, time, time)
    attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size)
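    # per head this computes softmax(q @ k^T * scale) @ v; e.g. for self-attention
    # over a 64x64 latent, score is (bs, 8, 4096, 4096), softmaxed over the last axis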
@@ -260,13 +263,13 @@ class BasicTransformerBlock:
    #print(self.norm1(x).numpy())
    x = self.attn1(self.norm1(x)) + x
    x = self.attn2(self.norm2(x), context=context) + x
    #print(x.numpy())
    x = self.ff(self.norm3(x)) + x
    return x

class SpatialTransformer:
  def __init__(self, channels, context_dim, n_heads, d_head):
    self.norm = Normalize(channels)
    assert channels == n_heads * d_head
    self.proj_in = Conv2d(channels, n_heads * d_head, 1)
    self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
    self.proj_out = Conv2d(n_heads * d_head, channels, 1)
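    # each BasicTransformerBlock above is pre-norm: attn1 self-attends, attn2
    # cross-attends to the CLIP context, and every sub-layer gets a residual add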
@@ -280,6 +283,7 @@ class SpatialTransformer:
    x = x.reshape(b, c, h*w).permute(0,2,1)
    for block in self.transformer_blocks:
      x = block(x, context=context)
    #print(x.numpy())
    #print(x.shape, x.numpy())
    x = x.permute(0,2,1).reshape(b, c, h, w)
    ret = self.proj_out(x) + x_in
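    # (b,c,h,w) -> (b, h*w, c) turns each spatial position into a token for
    # attention; the permute/reshape afterwards undoes it, and proj_out(x) + x_in
    # is the block-level residual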
@@ -366,15 +370,18 @@ class UNetModel:
      else: x = bb(x)
      return x

    #print(emb.numpy())

    saved_inputs = []
    for i,b in enumerate(self.input_blocks):
      print("input block", i)
      for bb in b:
        x = run(x, bb)
      #if i == 3:
      # print(x.numpy())
      # return None
      #if i == 1:
      #print(x.numpy())
      #exit(0)
      saved_inputs.append(x)
      #print(x.numpy())
    for bb in self.middle_block:
      x = run(x, bb)
    for i,b in enumerate(self.output_blocks):
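    # every input block's activation is stashed in saved_inputs; the output
    # blocks (past the end of this hunk) pop and concatenate them as the UNet's
    # skip connections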
@@ -500,19 +507,19 @@ class CLIPTextTransformer:

  def __call__(self, input_ids):
    x = self.embeddings(input_ids, list(range(len(input_ids))))
    print(x.numpy())
    #print(x.numpy())
    causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1)
    x = self.encoder(x, Tensor(causal_attention_mask, device=x.device))
    return self.final_layer_norm(x)

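# note: np.triu(..., k=1) keeps -inf strictly above the diagonal and 0 elsewhere,
# so each of the 77 tokens attends only to itself and earlier positions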

class StableDiffusion:
  def __init__(self):
    #self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
    #self.first_stage_model = AutoencoderKL()
    self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
    self.first_stage_model = AutoencoderKL()
    self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))

  def __call__(self, x, timesteps, context):
    return self.model.diffusion_model(x, timesteps, context)
  #def __call__(self, x, timesteps, context):
  #return self.model.diffusion_model(x, timesteps, context)
  #return self.first_stage_model(x)
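# the namedtuple wrappers appear to mirror the ldm checkpoint's attribute paths
# (model.diffusion_model, cond_stage_model.transformer.text_model, ...) so the
# state-dict keys can be walked onto the right tensors when loading weights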

# ** ldm.models.autoencoder.AutoencoderKL (done!)
@@ -556,7 +563,74 @@ if __name__ == "__main__":
    w.assign(v.astype(np.float32))

# "a horse sized cat eating a bagel"
|
||||
# run through CLIP to get context
|
||||
phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
|
||||
context = model.cond_stage_model.transformer.text_model(phrase)
|
||||
print("got CLIP context", context.shape)
|
||||
|
||||
phrase = [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
|
||||
unconditional_context = model.cond_stage_model.transformer.text_model(phrase)
|
||||
print("got unconditional CLIP context", unconditional_context.shape)
|
||||
|
||||
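  # classifier-free guidance needs both encodings: the prompt context above and
  # an unconditional context from an "empty" prompt (start token + padding);
  # get_model_output below blends the two noise predictions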
  def get_model_output(latent):
    # put into diffuser
    timesteps = Tensor([1])
    unconditional_latent = model.model.diffusion_model(latent, timesteps, unconditional_context)

    latent = model.model.diffusion_model(latent, timesteps, context)
    unconditional_guidance_scale = 7.5
    e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
    return e_t

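  # guidance: e_t = e_uncond + scale * (e_cond - e_uncond); scale = 7.5 pushes
  # the prediction away from the unconditional output and toward the prompt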
  def get_x_prev_and_pred_x0(x, e_t, index):
    temperature = 1
    a_t, a_prev, sigma_t, sqrt_one_minus_at = 0.9983, 0.9991, 0., 0.0413
    pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)

    # direction pointing to x_t
    dir_xt = math.sqrt(1. - a_prev - sigma_t**2) * e_t
    noise = sigma_t * Tensor.randn(*x.shape) * temperature

    x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt #+ noise
    return x_prev, pred_x0

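  # DDIM step (deterministic here since sigma_t = 0): recover the model's x0
  # estimate, pred_x0 = (x - sqrt(1-a_t)*e_t) / sqrt(a_t), then step toward it:
  # x_prev = sqrt(a_prev)*pred_x0 + sqrt(1-a_prev-sigma_t^2)*e_t; the alphas are
  # hardcoded for this single step and the index argument is unused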
  # start with random noise
  latent = Tensor.randn(1,4,64,64)

  # is this the diffusion?
  index = 0
  e_t = get_model_output(latent)
  print(e_t.numpy())
  x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
  #e_t_next = get_model_output(x_prev)
  #e_t_prime = (e_t + e_t_next) / 2
  #x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)

  latent = x_prev

  print(latent.numpy())
  #exit(0)

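  # this is a single denoising step; a full DDIM sampler would loop the
  # get_model_output / get_x_prev_and_pred_x0 pair over a timestep schedule
  # (the commented e_t_prime lines sketch a second-order, two-eval variant)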
  # sanity check
  #latent = Tensor(np.load("datasets/stable_diffusion_apple.npy"))

  # upsample latent space to image with autoencoder
  x = model.first_stage_model.post_quant_conv(latent)
  x = model.first_stage_model.decoder(x)

  # make image correct size and scale
  x = (x + 1.0) / 2.0
  x = x.reshape(3,512,512).permute(1,2,0)
  dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
  print(dat.shape)
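  # the VAE decoder maps the (1,4,64,64) latent up 8x to a 512x512 RGB image in
  # [-1,1]; (x+1)/2 rescales to [0,1], then clip and *255 give uint8 pixels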

  # save image
  from PIL import Image
  im = Image.fromarray(dat)
  im.save("/tmp/rendered.png")
  exit(0)

"""
|
||||
# load apple latent space
|
||||
|
|