1
0
Fork 0
tinygrab/tinygrad/runtime/ops_llvm.py

140 lines
5.2 KiB
Python

import ctypes
from typing import ClassVar, Tuple
from tinygrad.device import Compiled, MallocAllocator
from tinygrad.helpers import getenv, DEBUG, diskcache, cpu_time_execution
from ctypes import CFUNCTYPE
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
import llvmlite.binding as llvm
# Boolean flag derived from the LLVMOPT setting read through `getenv`.
# It is checked in `LLVM.__init__` and is the default value of the
# `llvmopt` parameter of `compile_llvm`.
LLVMOPT = bool(getenv("LLVMOPT"))
class LLVM:
    """
    Holder for the llvmlite objects shared by this backend.

    All state is stored on the class itself. The first `LLVM()` call performs the setup; later calls see that `LLVM.engine` is already set and return immediately.

    Attributes:
        target_machine (ClassVar[llvm.targets.TargetMachine]): Target machine created for the process triple.
        engine (ClassVar[llvm.executionengine.ExecutionEngine]): MCJIT execution engine built on top of `target_machine`.
        optimizer (ClassVar[llvm.passmanagers.ModulePassManager]): Module pass manager that receives the analysis passes (and, with LLVMOPT, the builder-populated passes).
    """

    target_machine: ClassVar[llvm.targets.TargetMachine] = None
    engine: ClassVar[llvm.executionengine.ExecutionEngine] = None
    optimizer: ClassVar[llvm.passmanagers.ModulePassManager] = None

    def __init__(self):
        """
        Set up the class-level LLVM state on first use.

        Initializes llvmlite for the native target, creates the module pass manager and the target machine, optionally applies the LLVMOPT configuration, and finally creates the MCJIT engine from an empty module. Does nothing when `LLVM.engine` is already set.
        """
        already_initialized = LLVM.engine is not None
        if already_initialized:
            return

        llvm.initialize()
        llvm.initialize_native_target()
        llvm.initialize_native_asmprinter()
        llvm.initialize_native_asmparser()

        native_target = llvm.Target.from_triple(llvm.get_process_triple())
        LLVM.optimizer = llvm.create_module_pass_manager()
        # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
        LLVM.target_machine = native_target.create_target_machine(opt=2)
        LLVM.target_machine.add_analysis_passes(LLVM.optimizer)

        # TODO: this makes compile times so much faster
        if LLVMOPT:
            LLVM._apply_llvmopt_settings()

        LLVM.target_machine.set_asm_verbosity(True)
        # The engine is created from an empty module tagged with the process triple.
        empty_module = llvm.parse_assembly("")
        empty_module.triple = llvm.get_process_triple()
        LLVM.engine = llvm.create_mcjit_compiler(empty_module, LLVM.target_machine)

    @staticmethod
    def _apply_llvmopt_settings():
        """Set the LLVM options and populate `LLVM.optimizer` for the LLVMOPT case."""
        # this makes sum the same speed as torch, it also doubles the (slow) conv speed
        llvm.set_option("", "-force-vector-interleave=4")
        if DEBUG >= 4:
            llvm.set_option("", "--debug-only=loop-vectorize")
        # full debug output would be: llvm.set_option(str(), '--debug')

        # does this do anything?
        pass_builder = llvm.create_pass_manager_builder()
        pass_builder.opt_level = 3
        pass_builder.size_level = 0
        pass_builder.loop_vectorize = True
        pass_builder.slp_vectorize = True
        pass_builder.populate(LLVM.optimizer)
@diskcache
def compile_llvm(prg, llvmopt=LLVMOPT) -> bytes:
    """
    Turn LLVM assembly text into a native object file.

    The text is parsed and verified, run through the shared `LLVM.optimizer`, and emitted as object code by `LLVM.target_machine`. When `DEBUG >= 5` the generated assembly is printed before the object is returned.

    Args:
        prg: The program in LLVM assembly format, as accepted by `llvm.parse_assembly`.
        llvmopt: Not read inside the function body. NOTE(review): presumably it exists so the `diskcache` key differs between LLVMOPT settings; confirm against `diskcache`.

    Returns:
        bytes: The object code returned by `emit_object`.
    """
    module = llvm.parse_assembly(prg)
    module.verify()

    # Constructing LLVM() ensures the class-level state exists before it is used.
    backend = LLVM()
    backend.optimizer.run(module)

    machine = LLVM.target_machine
    if DEBUG >= 5:
        print(machine.emit_assembly(module))
    return machine.emit_object(module)
class LLVMProgram:
    """
    A compiled object file loaded into the shared LLVM execution engine.

    Construction adds the object file to `LLVM.engine` and looks up the address of the named function; calling the instance invokes that address through ctypes.

    Attributes:
        name (str): The function name looked up in the engine.
        lib (bytes): The object file contents the program was built from.
        fxn: The address returned by `get_function_address` for `name`.
        cfunc: The ctypes function wrapper; (re)created on every call.

    Args:
        name (str): The function name to look up.
        lib (bytes): The compiled object file as bytes.
    """

    def __init__(self, name: str, lib: bytes):
        self.name = name
        self.lib = lib
        # LLVM() runs first so the engine exists before the object file is wrapped.
        engine = LLVM().engine
        object_file = llvm.object_file.ObjectFileRef.from_data(lib)
        engine.add_object_file(object_file)
        self.fxn = engine.get_function_address(name)

    def __call__(self, *bufs, vals: Tuple[int, ...] = (), wait=False):
        """
        Invoke the compiled function.

        A ctypes prototype returning `c_int` is built with one `c_void_p` parameter per buffer followed by one `c_int32` parameter per value, bound to the stored function address, and run through `cpu_time_execution`.

        Args:
            *bufs: Arguments passed in the `c_void_p` positions.
            vals (Tuple[int, ...]): Arguments passed in the `c_int32` positions. Defaults to an empty tuple.
            wait (bool): Forwarded to `cpu_time_execution` as `enable`. Defaults to False.

        Returns:
            Whatever `cpu_time_execution` returns. NOTE(review): presumably the elapsed time when `wait` is true; confirm against that helper.
        """
        pointer_params = [ctypes.c_void_p] * len(bufs)
        int_params = [ctypes.c_int32] * len(vals)
        prototype = CFUNCTYPE(ctypes.c_int, *pointer_params, *int_params)
        self.cfunc = prototype(self.fxn)

        def invoke():
            return self.cfunc(*bufs, *vals)

        return cpu_time_execution(invoke, enable=wait)
# Module-level device object for the LLVM backend, built by passing `Compiled`
# the allocator, the linearizer options (float4, local and shared all disabled),
# `uops_to_llvm_ir`, `compile_llvm` and `LLVMProgram`, in that order.
# NOTE(review): the role of each positional argument is inferred from its name;
# confirm against the signature of `Compiled`.
LLVMDevice = Compiled(
    MallocAllocator,
    LinearizerOptions(supports_float4=False, has_local=False, has_shared=False),
    uops_to_llvm_ir,
    compile_llvm,
    LLVMProgram,
)