1
0
Fork 0
tinygrab/tinygrad/graph.py

312 lines
11 KiB
Python

import os, atexit, functools
from collections import defaultdict
from typing import Dict, List
from tinygrad.ops import (
ScheduleItem,
UnaryOps,
BinaryOps,
ReduceOps,
MovementOps,
LoadOps,
BufferOps,
TernaryOps,
Op,
OpType,
LazyOp,
)
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import NumNode
# **** debugging and graphing ****
# Number of logged schedule items, keyed by their dominant op type.
cnts: Dict[OpType, int] = defaultdict(int)
if DEBUG >= 2:

    def print_globalcounters():
        """
        Print a one-line run summary built from GlobalCounters.

        The line reports the average GFLOPS and GB/s over the accumulated time,
        followed by the kernel count, total GOPS, total GB and the elapsed time in
        milliseconds. Nothing is printed when no time was accumulated, which also
        keeps the averages from dividing by zero.
        """
        elapsed_s = GlobalCounters.time_sum_s
        if elapsed_s == 0:
            return
        gops = GlobalCounters.global_ops * 1e-9
        gbytes = GlobalCounters.global_mem * 1e-9
        print(
            f"avg: {gops/elapsed_s:8.2f} GFLOPS {gbytes/elapsed_s:8.2f} GB/s",
            f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {gops:8.2f} GOPS {gbytes:8.2f} GB {elapsed_s*1e3:8.2f} ms",
        )

    # Emit the summary once, when the interpreter exits.
    atexit.register(print_globalcounters)
if GRAPH:
    import subprocess

    import networkx as nx

    # Module-wide graph that log_schedule_item fills in while the program runs.
    G = nx.DiGraph()

    def save_graph_exit():
        """
        Dump the op-type counters and write the collected graph to disk.

        Prints every entry of ``cnts``, writes ``G`` to ``{GRAPHPATH}.dot`` and then
        asks Graphviz ``dot`` to render ``{GRAPHPATH}.svg`` from it.

        ``dot`` is invoked with an argument list rather than a shell string, so a
        GRAPHPATH containing spaces or shell metacharacters is passed through
        verbatim and always names the same file that was just written. Rendering is
        best effort: a missing ``dot`` binary or a non-zero exit status is reported
        but never raised, because this runs as an exit handler.
        """
        for k, v in cnts.items():
            print(k, v)
        print("saving", G)
        dot_path = f"{GRAPHPATH}.dot"
        svg_path = f"{GRAPHPATH}.svg"
        nx.drawing.nx_pydot.write_dot(G, dot_path)
        # -Gnslimit=100 can make it finish, but you won't like results
        try:
            subprocess.run(["dot", "-Tsvg", dot_path, "-o", svg_path], check=False)
        except OSError as err:
            # e.g. Graphviz is not installed; the .dot file is still on disk
            print(f"could not run dot: {err}")

    # Save the graph once, when the interpreter exits.
    atexit.register(save_graph_exit)
# Next unused graph node id; shared by nm() and add_st_node().
node_count = 0


def nm(x):
    """
    Return the graph node id of ``x``, assigning a fresh one on first use.

    The id is stored on the object itself as ``x.node_id``. An object that
    already carries that attribute keeps it; otherwise the current value of the
    global ``node_count`` is stored on it and the counter is advanced.

    :param x: Any object that accepts a ``node_id`` attribute.
    :return: The ``node_id`` of ``x``.
    """
    global node_count
    try:
        return x.node_id
    except AttributeError:
        pass
    x.node_id = node_count
    node_count += 1
    return x.node_id
def get_sop(op: List[Op]):
    """
    Build a short edge label from a list of ops.

    Members of ``BufferOps`` are dropped first. What remains is rendered by size:

    * more than 6 ops: just the count, as a string;
    * 3 to 6 ops: the first three characters of each op name;
    * 2 or fewer ops: the full op names.

    Names are the part of ``str(op)`` after the first ``"."``, and they are
    joined with ``"."`` in reverse order of the input list.

    :param op: The ops to summarise.
    :return: The label string.
    """
    kept = [o for o in op if o not in BufferOps]
    if len(kept) > 6:
        return str(len(kept))
    names = (str(o).split(".")[1] for o in reversed(kept))
    if len(kept) > 2:
        # mid-sized lists get abbreviated names to keep the label short
        names = (n[0:3] for n in names)
    return ".".join(names)
def str_dtype(dtyp):
    """
    Return the dtype suffix to append to a node label.

    The first seven characters of ``str(dtyp)`` are dropped (presumably a
    ``"dtypes."`` prefix; confirm against the dtype's string form). If the
    remainder is ``"float"`` nothing is added to the label; otherwise the
    remainder is returned on a new line.

    :param dtyp: Object whose string form names the dtype.
    :return: ``""`` for float, else ``"\\n"`` followed by the dtype name.
    """
    name = str(dtyp)[7:]
    if name == "float":
        return ""
    return "\n" + name
"""
Add a node to the ShapeTracker graph.
:param nmx: The source node in the graph.
:type nmx: Node
:param nmo: The destination node in the graph.
:type nmo: Node
:param label: The label for the edge connecting the new node and the destination node.
:type label: str
:param st: The ShapeTracker object that holds information about the shape and strides of the array.
:type st: ShapeTracker
"""
@functools.lru_cache(None)
def add_st_node(nmx, nmo, label, st: ShapeTracker):
global node_count
inter_node = node_count
node_count += 1
offset = st.expr_node(NumNode(0))[0]
G.add_node(
inter_node,
style="filled",
fillcolor="#80ff8080",
color="black",
label=f"{st.shape}\n{st.real_strides()}"
+ (f"\n{offset}" if offset != 0 else ""),
)
G.add_edge(nmx, inter_node, color="#00000060")
G.add_edge(inter_node, nmo, label=label, color="#00000060")
"""
If LOGOPS environment variable is set and not empty, open the file in append mode.
Otherwise, assign None to logops.
"""
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
def log_schedule_item(si: ScheduleItem):
    """
    Record a schedule item in the op log, the op-type counters and the graph.

    Steps, in order:

    1. If ``logops`` is open and the root op is not a LoadOp, append the AST's
       string form to the log.
    2. Stop unless DEBUG or GRAPH is enabled.
    3. For ``LoadOps.CONTIGUOUS`` make the output share the node id of its first
       input's base. CONST and CONTIGUOUS items are not counted or drawn.
    4. Pick the item's dominant op type (earliest entry of ``oporder`` present in
       the AST) and bump its counter in ``cnts``.
    5. When GRAPH is enabled, add an edge from every loaded input to the output
       (through an intermediate ShapeTracker node when the view is not
       contiguous) and label and colour the output node.

    An input node that has no label yet is labelled with its own shape and its
    own dtype. It used to get the dtype of ``si.out``, which mislabels inputs
    whenever the item changes dtype.

    :param si: The schedule item to log.
    """
    if logops and si.ast.op not in LoadOps:
        logops.write(str(si.ast) + "\n")
    if not DEBUG and not GRAPH:
        return
    if si.ast.op == LoadOps.CONTIGUOUS:
        # the output reuses the node of its input's base instead of getting its own
        setattr(si.out, "node_id", nm(si.inputs[0].base))
    if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}:
        return

    # walk the AST once and reuse the result for both the op list and the loads
    lazyops = list(si.ast.get_lazyops())
    op: List[Op] = [x.op for x in lazyops]
    oporder = [
        LoadOps,
        TernaryOps,
        ReduceOps,
        BinaryOps,
        UnaryOps,
        MovementOps,
        BufferOps,
    ]
    # min() returns the first minimal element, same as sorted(...)[0]
    optype = type(min(op, key=lambda x: oporder.index(type(x))))
    cnts[optype] += 1
    if not GRAPH:
        return

    assert si.out.base == si.out, "all outputs based"
    top_colors = {
        LoadOps: "#FFFFa0",
        UnaryOps: "#c0c0c0",
        ReduceOps: "#8080ff",
        BinaryOps: "#c0c0c0",
        MovementOps: "#80ff80",
        TernaryOps: "#c0c0c0",
        BufferOps: "#FF8080",
    }

    # get inputs for shapetrackers
    input_to_st = defaultdict(list)
    for lo in lazyops:
        if lo.op == BufferOps.LOAD:
            # NOTE(review): idx appears to be 1-based over si.inputs (0 presumably
            # being the output buffer) -- confirm against the buffer numbering
            input_to_st[si.inputs[lo.arg.idx - 1]].append(lo.arg.st)

    # add them to the graph, potentially with a movement op separating them
    edge_label = get_sop(op)
    for x, trackers in input_to_st.items():
        for st in dedup(trackers):
            if st.contiguous:
                G.add_edge(nm(x), nm(si.out), label=edge_label, color="#00000060")
            else:
                add_st_node(nm(x), nm(si.out), edge_label, st)
        if "label" not in G.nodes[nm(x)]:
            # label the input with its own dtype (assumes inputs expose .dtype
            # the same way si.out does)
            G.nodes[nm(x)]["label"] = str(x.shape) + str_dtype(x.dtype)

    out_id = nm(si.out)
    if out_id not in G.nodes:
        G.add_node(out_id)

    out_label = str(si.out.shape)
    if optype == ReduceOps:
        # reductions also show the input shapes they collapse
        out_label = str(set(x.shape for x in si.inputs)) + "\n" + out_label
    out_label += str_dtype(si.out.dtype)
    if si.ast.op in LoadOps:
        out_label += f"\n{si.ast.op}"

    G.nodes[out_id]["label"] = out_label
    G.nodes[out_id]["fillcolor"] = top_colors[optype]
    G.nodes[out_id]["color"] = "black"
    G.nodes[out_id]["style"] = "filled"
def _tree(lazydata, prefix=""):
    """
    Recursively render ``lazydata`` as the lines of a box-drawing tree.

    An object whose class is named ``LazyBuffer`` renders as a single
    ``realized`` line (dtype name and shape) when it is realized; otherwise its
    ``op`` is rendered with the prefix ``"LB "``. Any other node renders as
    ``<op name> <arg>`` (a falsy ``arg`` renders as an empty string), followed by
    the subtrees of its ``src`` entries.

    Children are attached with branch connectors so nesting stays readable:
    ``┣`` for every child but the last, ``┗`` for the last one, and ``┃`` to
    carry the vertical line past the rows of a non-last subtree. Without these
    connectors every depth renders identically.

    Parameters:
        lazydata: The node to render (needs ``op.name``, ``arg`` and ``src``, or
            the LazyBuffer attributes described above).
        prefix (str): Text placed before the op name of this node only; it is
            not passed on to children.
    Returns:
        list: One string per output line, root first.
    """
    if type(lazydata).__name__ == "LazyBuffer":
        if lazydata.realized:
            return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"]
        return _tree(lazydata.op, "LB ")
    head = f"{prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"
    if len(lazydata.src) == 0:
        return [f"━━ {head}"]
    lines = [f"━┳ {head}"]
    childs = [_tree(c) for c in lazydata.src]
    for c in childs[:-1]:
        # non-last child: branch on its first row, keep the trunk on the rest
        lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
    # last child: close the trunk, indent its remaining rows with blanks
    return lines + [" ┗" + childs[-1][0]] + ["  " + l for l in childs[-1][1:]]
def print_tree(lazydata: LazyOp):
    """
    Print the tree rendering of ``lazydata`` with line numbers.

    Each line produced by ``_tree`` is prefixed with its zero-based index,
    right-justified to width three, and the whole listing is printed in a single
    ``print`` call.

    Parameters:
        lazydata (LazyOp): Root of the tree to print.
    Returns:
        None.
    """
    numbered = (
        f"{str(idx).rjust(3)} {row}" for idx, row in enumerate(_tree(lazydata))
    )
    print("\n".join(numbered))
def graph_uops(uops: List[UOp]):
    """
    Render a list of UOps as a left-to-right Graphviz graph.

    Every uop except ``UOps.END`` becomes a node identified by its position in
    ``uops``. Its label is the uop name, the arg when it is not None, and the
    dtype; its fill colour comes from a fixed per-uop table (white when the uop
    is not listed). Each entry of ``u.vin`` contributes an edge from that input
    to ``u``. The graph is written to ``/tmp/uops.dot`` and rendered to
    ``/tmp/uops.svg`` with ``dot``.

    :param uops: The uops to draw; every ``vin`` entry must also be in the list.
    """
    import networkx as nx

    fill_by_uop = {
        UOps.ALU: "#ffffc0",
        UOps.LOAD: "#ffc0c0",
        UOps.STORE: "#c0ffc0",
        UOps.SPECIAL: "#c0c0ff",
        UOps.CONST: "#e0e0e0",
        UOps.DEFINE_GLOBAL: "#ffe0b0",
        UOps.DEFINE_LOCAL: "#ffe0d0",
        UOps.DEFINE_ACC: "#f0ffe0",
        UOps.LOOP: "#c8a0e0",
        UOps.PHI: "#e0ffc0",
        UOps.BARRIER: "#ff8080",
        UOps.IF: "#c8b0c0",
    }
    graph = nx.DiGraph()
    for u in uops:
        if u.uop == UOps.END:
            continue
        this_id = uops.index(u)
        arg_text = (" " + str(u.arg)) if u.arg is not None else ""
        graph.add_node(
            this_id,
            label=f"{str(u.uop)[5:]}{arg_text}\n{str(u.dtype)}",
            style="filled",
            fillcolor=fill_by_uop.get(u.uop, "#ffffff"),
        )
        for v in u.vin:
            graph.add_edge(uops.index(v), this_id)
    # fixed output location, independent of the module-level GRAPHPATH setting
    out_base = "/tmp/uops"
    nx.drawing.nx_pydot.write_dot(graph, f"{out_base}.dot")
    os.system(f"dot -Grankdir=LR -Tsvg {out_base}.dot -o {out_base}.svg")