1
0
Fork 0

simple fix (#2543)

pull/2544/head^2
George Hotz 2023-12-01 09:42:15 -08:00 committed by GitHub
parent 04483f8187
commit d8175a4380
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 4 deletions

View File

@ -328,9 +328,9 @@ jobs:
# Prefer packages from the rocm repository over system packages
echo -e 'Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' | sudo tee /etc/apt/preferences.d/rocm-pin-600
sudo apt update
sudo apt install --allow-unauthenticated -y rocm-hip-libraries hip-dev
sudo apt install --no-install-recommends --allow-unauthenticated -y rocm-hip-libraries hip-dev
- name: Install Python Dependencies
run: pip install -e '.[testing]'
run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test HIP compilation on RDNA3 [gfx1100]
run: |
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/hip/lib

View File

@ -32,7 +32,7 @@ class CUDAProgram:
subprocess.run(["ptxas", f"-arch={CUDADevice.default_arch_name}", "-o", fn, fn+".ptx"], check=True)
print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
except Exception as e: print("failed to generate SASS", str(e))
if not CUDACPU:
self.module = init_c_var(cuda.CUmodule(), lambda x: check(cuda.cuModuleLoadData(ctypes.byref(x), prg)))
check(cuda.cuModuleGetFunction(ctypes.byref(prg := cuda.CUfunction()), self.module, name.encode("utf-8")))

View File

@ -68,7 +68,7 @@ class HIPDevice(Compiled):
def __init__(self, device:str):
self.device = int(device.split(":")[1]) if ":" in device else 0
if self.device == 0 and not MOCKHIP: HIPDevice.default_arch_name = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device))).gcnArchName.decode()
from tinygrad.runtime.graph.hip import HIPGraph
super().__init__(MallocAllocator if MOCKHIP else HIPAllocator(self.device), LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, functools.partial(HIPProgram, self.device), HIPGraph)
def synchronize(self): hip.hipDeviceSynchronize()