# extra/optimization/extract_policynet.py
# Trains a policy network over kernel-optimization actions from a beam-search cache.
import os, sys, sqlite3, pickle, random
from tqdm import tqdm, trange
from copy import deepcopy
from tinygrad.nn import Linear
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import (
get_parameters,
get_state_dict,
safe_save,
safe_load,
load_state_dict,
)
from tinygrad.features.search import actions
from extra.optimization.helpers import (
load_worlds,
ast_str_to_lin,
lin_to_feats,
assert_same_lin,
)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.helpers import getenv
# stuff needed to unpack a kernel
from tinygrad.ops import (
LazyOp,
TernaryOps,
BinaryOps,
UnaryOps,
ReduceOps,
BufferOps,
MemBuffer,
ConstBuffer,
)
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
# inf/nan at module scope — presumably so eval() of serialized kernel ASTs in
# dataset_from_cache can resolve inf/nan literals; TODO confirm cache format needs this
inf, nan = float("inf"), float("nan")
from tinygrad.codegen.kernel import Opt, OptOps
INNER = 256  # hidden-layer width of the policy MLP
class PolicyNet:
    """Three-layer MLP mapping a 1021-dim kernel feature vector to
    log-probabilities over 1 + len(actions) choices (index 0 = "stop")."""

    def __init__(self):
        # NOTE: attribute names l1/l2/l3 define the saved state-dict layout — keep them.
        self.l1 = Linear(1021, INNER)
        self.l2 = Linear(INNER, INNER)
        self.l3 = Linear(INNER, 1 + len(actions))

    def __call__(self, x):
        """Return log-softmax scores for each action."""
        hidden = self.l1(x).relu()
        hidden = self.l2(hidden).relu()
        # NOTE(review): 0.9 is an unusually aggressive drop rate — intentional per original
        hidden = hidden.dropout(0.9)
        return self.l3(hidden).log_softmax()
def dataset_from_cache(fn):
    """Build a (features, action-index) dataset from a beam-search sqlite cache.

    Each row of the beam_search table holds a serialized kernel AST (column 0)
    and a pickled list of Opts (last column). For every opt we record the
    pre-opt kernel features and the action taken; after the last opt we record
    the terminal features with action 0 ("stop").

    Returns (X, A): parallel lists of feature vectors and action indices.
    """
    conn = sqlite3.connect(fn)
    try:
        cur = conn.cursor()
        cur.execute("SELECT * FROM beam_search")
        X, A = [], []
        for row in tqdm(cur.fetchall()):
            Xs, As = [], []
            try:
                # SECURITY: eval()/pickle.loads() on cache contents — only ever
                # point this at a trusted local tinygrad cache.
                lin = Linearizer(eval(row[0]))
                opts = pickle.loads(row[-1])
                for opt in opts:
                    Xs.append(lin_to_feats(lin, use_sts=True))
                    As.append(actions.index(opt))
                    lin.apply_opt(opt)
                # terminal state: action 0 means "stop optimizing"
                Xs.append(lin_to_feats(lin, use_sts=True))
                As.append(0)
            except Exception:
                # fix: skip the whole row on failure. Previously the partial
                # Xs/As (which can have mismatched lengths if actions.index or
                # apply_opt raised mid-row) were still merged into X/A,
                # silently corrupting the feature/label pairing.
                continue
            X += Xs
            A += As
        return X, A
    finally:
        conn.close()  # fix: connection was previously leaked
if __name__ == "__main__":
    # Either regenerate the dataset from the sqlite beam-search cache, or load
    # the previously saved tensors.
    if getenv("REGEN"):
        X, V = dataset_from_cache(
            sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
        )
        safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset_policy")
    else:
        ld = safe_load("/tmp/dataset_policy")
        X, V = ld["X"].numpy(), ld["V"].numpy()
    print(X.shape, V.shape)

    # shuffle, then hold out the last 256 samples as the test split
    order = list(range(X.shape[0]))
    random.shuffle(order)
    X, V = X[order], V[order]
    ratio = -256
    X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
    X, V = X[:ratio], V[:ratio]
    print(X.shape, V.shape)

    net = PolicyNet()
    # if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
    optim = Adam(get_parameters(net))

    def get_minibatch(X, Y, bs):
        # sample bs rows with replacement from the training split
        xs, ys = [], []
        for _ in range(bs):
            sel = random.randint(0, len(X) - 1)
            xs.append(X[sel])
            ys.append(Y[sel])
        return Tensor(xs), Tensor(ys)

    Tensor.no_grad, Tensor.training = False, True
    losses = []
    test_losses = []
    test_accuracy = 0
    test_loss = float("inf")
    for i in (t := trange(500)):
        x, y = get_minibatch(X, V, bs=256)
        out = net(x)
        loss = out.sparse_categorical_crossentropy(y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        cat = out.argmax(axis=-1)
        accuracy = (cat == y).mean()
        t.set_description(
            f"loss {loss.numpy():7.2f} accuracy {accuracy.numpy()*100:7.2f}%, test loss {test_loss:7.2f} test accuracy {test_accuracy*100:7.2f}%"
        )
        losses.append(loss.numpy().item())
        test_losses.append(test_loss)
        # fix: was `if i % 10:`, which ran the test evaluation on 9 of every
        # 10 steps instead of every 10th step
        if i % 10 == 0:
            # fix: disable dropout while evaluating on the held-out split
            Tensor.training = False
            out = net(X_test)
            # fix: dropped the stray .square() (leftover from the value-net
            # MSE code) so test loss is a plain CE value comparable to `loss`
            test_loss = (
                out.sparse_categorical_crossentropy(V_test).mean().numpy().item()
            )
            cat = out.argmax(axis=-1)
            # fix: accuracy was computed against the *training* minibatch
            # labels `y`; compare against the test labels instead
            test_accuracy = (cat == V_test).mean().numpy()
            Tensor.training = True

    safe_save(get_state_dict(net), "/tmp/policynet.safetensors")

    import matplotlib.pyplot as plt

    plt.plot(losses[10:])
    plt.plot(test_losses[10:])
    plt.show()