pytorch/ufunc_defs.bzl

26 lines
800 B
Python

load("@bazel_skylib//lib:paths.bzl", "paths")
load(":build_variables.bzl", "aten_ufunc_headers")
# Ufunc names derived from the header list: each entry of aten_ufunc_headers is
# reduced to its final path component and then stripped of its file extension
# (a header such as "some/dir/add.h" contributes "add").
aten_ufunc_names = [
    paths.split_extension(header_file)[0]
    for header_file in [paths.basename(header) for header in aten_ufunc_headers]
]
def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
    """Lists the generated CPU source file names, one per ufunc.

    Args:
      gencode_pattern: format string into which each "UfuncCPU_<name>.cpp"
        file name is substituted. The default "{}" yields the bare file names.

    Returns:
      A list of strings, one per entry of aten_ufunc_names, in the same order.
    """
    file_names = ["UfuncCPU_{}.cpp".format(ufunc) for ufunc in aten_ufunc_names]
    return [gencode_pattern.format(file_name) for file_name in file_names]
def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
    """Lists the generated CPU kernel source file names, one per ufunc.

    Args:
      gencode_pattern: format string into which each
        "UfuncCPUKernel_<name>.cpp" file name is substituted. The default "{}"
        yields the bare file names.

    Returns:
      A list of strings, one per entry of aten_ufunc_names, in the same order.
    """
    file_names = ["UfuncCPUKernel_{}.cpp".format(ufunc) for ufunc in aten_ufunc_names]
    return [gencode_pattern.format(file_name) for file_name in file_names]
def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Lists the generated CUDA source file names, one per ufunc.

    Args:
      gencode_pattern: format string into which each "UfuncCUDA_<name>.cu"
        file name is substituted. The default "{}" yields the bare file names.

    Returns:
      A list of strings, one per entry of aten_ufunc_names, in the same order.
    """
    file_names = ["UfuncCUDA_{}.cu".format(ufunc) for ufunc in aten_ufunc_names]
    return [gencode_pattern.format(file_name) for file_name in file_names]