add thneed optimizer (#23772)

* add thneed optimizer

* local work group opt

* kernels and final mods

* release files

* build system touchups

* fix kernel path, rand inputs for self test

* broken since extra is gone

* update model replay ref

Co-authored-by: Comma Device <device@comma.ai>
pull/23774/head
George Hotz 2022-02-15 16:32:00 -07:00 committed by GitHub
parent 7176f5c401
commit 90beaebefb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 903 additions and 7 deletions

View File

@ -426,7 +426,9 @@ selfdrive/modeld/transforms/transform.cl
selfdrive/modeld/thneed/thneed.*
selfdrive/modeld/thneed/serialize.cc
selfdrive/modeld/thneed/compile.cc
selfdrive/modeld/thneed/optimizer.cc
selfdrive/modeld/thneed/include/*
selfdrive/modeld/thneed/kernels/*.cl
selfdrive/modeld/runners/snpemodel.cc
selfdrive/modeld/runners/snpemodel.h

View File

@ -1,3 +1,5 @@
import os
Import('env', 'arch', 'cereal', 'messaging', 'common', 'gpucommon', 'visionipc')
lenv = env.Clone()
@ -23,6 +25,7 @@ common_src = [
thneed_src = [
"thneed/thneed.cc",
"thneed/serialize.cc",
"thneed/optimizer.cc",
"runners/thneedmodel.cc",
]
@ -62,12 +65,16 @@ common_model = lenv.Object(common_src)
# build thneed model
if use_thneed and arch in ("aarch64", "larch64"):
fn = "../../models/supercombo"
compiler = lenv.Program('thneed/compile', ["thneed/compile.cc"]+common_model, LIBS=libs)
cmd = f"cd {Dir('.').abspath} && {compiler[0].abspath} ../../models/supercombo.dlc ../../models/supercombo.thneed --binary"
cmd = f"cd {Dir('.').abspath} && {compiler[0].abspath} {fn}.dlc {fn}.thneed --binary"
lib_paths = ':'.join(Dir(p).abspath for p in lenv["LIBPATH"])
cenv = Environment(ENV={'LD_LIBRARY_PATH': f"{lib_paths}:{lenv['ENV']['LD_LIBRARY_PATH']}"})
cenv.Command("../../models/supercombo.thneed", ["../../models/supercombo.dlc", compiler], cmd)
kernel_path = os.path.join(Dir('.').abspath, "thneed", "kernels")
cenv = Environment(ENV={'LD_LIBRARY_PATH': f"{lib_paths}:{lenv['ENV']['LD_LIBRARY_PATH']}", 'KERNEL_PATH': kernel_path})
kernels = [os.path.join(kernel_path, x) for x in os.listdir(kernel_path) if x.endswith(".cl")]
cenv.Command(fn + ".thneed", [fn + ".dlc", kernels, compiler], cmd)
lenv.Program('_dmonitoringmodeld', [
"dmonitoringmodeld.cc",

View File

@ -7,6 +7,7 @@
#include <cstring>
#include "selfdrive/common/util.h"
#include "selfdrive/common/timing.h"
void PrintErrorStringAndExit() {
std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
@ -158,8 +159,14 @@ void SNPEModel::execute(float *net_input_buf, int buf_size) {
float *outputs_golden = (float *)malloc(output_size*sizeof(float));
memcpy(outputs_golden, output, output_size*sizeof(float));
memset(output, 0, output_size*sizeof(float));
memset(recurrent, 0, recurrent_size*sizeof(float));
thneed->execute(inputs, output);
for (int i = 0; i < 5; i++) {
memset(recurrent, 0, recurrent_size*sizeof(float));
uint64_t start_time = nanos_since_boot();
thneed->execute(inputs, output);
uint64_t elapsed_time = nanos_since_boot() - start_time;
printf("ran model in %.2f ms\n", float(elapsed_time)/1e6);
}
if (memcmp(output, outputs_golden, output_size*sizeof(float)) == 0) {
printf("thneed selftest passed\n");

View File

@ -0,0 +1,129 @@
__kernel void convolution_horizontal_reduced_reads(
read_only image2d_t input,
short startPackedInputChannel,
short numPackedInputChannelsForGroup, short totalNumPackedInputChannels,
short packedOuputChannelOffset, short totalNumPackedOutputChannels,
read_only image2d_t weights, __constant float *biases,
short filterSizeX, short filterSizeY,
write_only image2d_t output,
short paddingX, short paddingY, short strideX, short strideY,
short dilationX, short dilationY,
short neuron, float a, float b, float min_clamp, float max_clamp,
__constant float *parameters, __constant float *batchNormBiases,
short numOutputColumns) {
// init
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
short packedOutputChannel = get_global_id(0);
short startOutputColumn = mul24((short)get_global_id(1), 4);
short outputRow = get_global_id(2);
short startX = mad24(mad24(startOutputColumn, strideX, -paddingX),
totalNumPackedInputChannels, startPackedInputChannel);
short strideWithChannels = mul24(strideX, totalNumPackedInputChannels);
float4 outputValues[4];
for (short i = 0; i < 4; ++i) {
outputValues[i] = (float4)(0, 0, 0, 0);
}
int2 inputLocation;
inputLocation.y = mad24(outputRow, strideY, -paddingY);
int2 weightLocation;
weightLocation.x = 0;
weightLocation.y = packedOutputChannel;
// convolution
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
for (short packedInputChannel = 0;
packedInputChannel < numPackedInputChannelsForGroup;
++packedInputChannel) {
short startXForChannel = startX + packedInputChannel;
for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
float4 weightValues[4];
for (short outChIdx = 0; outChIdx < 4; ++outChIdx) {
weightValues[outChIdx] = read_imagef(weights, smp, weightLocation);
++weightLocation.x;
}
short dilatedStepX = mul24(totalNumPackedInputChannels, dilationX);
inputLocation.x = mad24(rfColumn, dilatedStepX, startXForChannel);
float4 inputValues[4];
for (short i = 0; i < 4; ++i) {
inputValues[i] = read_imagef(input, smp, inputLocation);
inputLocation.x += strideWithChannels;
}
for (short i = 0; i < 4; ++i) {
float4 curOutputValues = outputValues[i];
curOutputValues.x += inputValues[i].x * weightValues[0].x;
curOutputValues.x += inputValues[i].y * weightValues[0].y;
curOutputValues.x += inputValues[i].z * weightValues[0].z;
curOutputValues.x += inputValues[i].w * weightValues[0].w;
curOutputValues.y += inputValues[i].x * weightValues[1].x;
curOutputValues.y += inputValues[i].y * weightValues[1].y;
curOutputValues.y += inputValues[i].z * weightValues[1].z;
curOutputValues.y += inputValues[i].w * weightValues[1].w;
curOutputValues.z += inputValues[i].x * weightValues[2].x;
curOutputValues.z += inputValues[i].y * weightValues[2].y;
curOutputValues.z += inputValues[i].z * weightValues[2].z;
curOutputValues.z += inputValues[i].w * weightValues[2].w;
curOutputValues.w += inputValues[i].x * weightValues[3].x;
curOutputValues.w += inputValues[i].y * weightValues[3].y;
curOutputValues.w += inputValues[i].z * weightValues[3].z;
curOutputValues.w += inputValues[i].w * weightValues[3].w;
outputValues[i] = curOutputValues;
}
}
}
inputLocation.y += dilationY;
}
// bias
packedOutputChannel += packedOuputChannelOffset;
short outputChannel = mul24(packedOutputChannel, 4);
float4 biasValues = vload4(0, biases + outputChannel);
for (short i = 0; i < 4; ++i) {
outputValues[i] += biasValues;
}
// activation
switch (neuron) {
case 1:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f);
}
break;
case 2:
for (short i = 0; i < 4; ++i) {
outputValues[i] = a * tanh(b * outputValues[i]);
}
break;
case 3:
for (short i = 0; i < 4; ++i) {
outputValues[i] = native_recip(1.0f + native_exp(-a * outputValues[i] + b));
}
break;
case 4:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], min_clamp);
outputValues[i] = min(outputValues[i], max_clamp);
}
break;
case 5:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f) + a * (native_exp(min(outputValues[i], 0.0f)) - 1.0f);
}
break;
}
// output
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 4; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}
++outputColumn;
}
}

View File

@ -0,0 +1,140 @@
__kernel void convolution_horizontal_reduced_reads_1x1(
read_only image2d_t input,
short startPackedInputChannel,
short numPackedInputChannelsForGroup, short totalNumPackedInputChannels,
short packedOuputChannelOffset, short totalNumPackedOutputChannels,
read_only image2d_t weights, __constant float *biases,
short filterSizeX, short filterSizeY,
write_only image2d_t output,
short paddingX, short paddingY, short strideX, short strideY,
short neuron, float a, float b, float min_clamp, float max_clamp,
__constant float *parameters, __constant float *batchNormBiases,
short numOutputColumns,
short doAccumulate, read_only image2d_t accumulator) {
// init
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
short packedOutputChannel = get_global_id(0);
short startOutputColumn = mul24((short)get_global_id(1), 4);
short outputRow = get_global_id(2);
short endPackedInputChannel = startPackedInputChannel + numPackedInputChannelsForGroup;
short startX = mad24(mad24(startOutputColumn, strideX, -paddingX),
totalNumPackedInputChannels, startPackedInputChannel);
short strideWithChannels = mul24(strideX, totalNumPackedInputChannels);
float4 outputValues[4];
for (short i = 0; i < 4; ++i) {
outputValues[i] = (float4)(0, 0, 0, 0);
}
int2 inputLocation;
inputLocation.y = mad24(outputRow, strideY, -paddingY);
int2 weightLocation;
weightLocation.x = 0;
weightLocation.y = packedOutputChannel;
// convolution
for (short packedInputChannel = startPackedInputChannel;
packedInputChannel < endPackedInputChannel; ++packedInputChannel) {
float4 weightValues[4];
for (short outChIdx = 0; outChIdx < 4; ++outChIdx) {
weightValues[outChIdx] = read_imagef(weights, smp, weightLocation);
++weightLocation.x;
}
inputLocation.x = startX + packedInputChannel;
float4 inputValues[4];
for (short i = 0; i < 4; ++i) {
inputValues[i] = read_imagef(input, smp, inputLocation);
inputLocation.x += strideWithChannels;
}
for (short i = 0; i < 4; ++i) {
float4 curOutputValues = outputValues[i];
curOutputValues.x += inputValues[i].x * weightValues[0].x;
curOutputValues.x += inputValues[i].y * weightValues[0].y;
curOutputValues.x += inputValues[i].z * weightValues[0].z;
curOutputValues.x += inputValues[i].w * weightValues[0].w;
curOutputValues.y += inputValues[i].x * weightValues[1].x;
curOutputValues.y += inputValues[i].y * weightValues[1].y;
curOutputValues.y += inputValues[i].z * weightValues[1].z;
curOutputValues.y += inputValues[i].w * weightValues[1].w;
curOutputValues.z += inputValues[i].x * weightValues[2].x;
curOutputValues.z += inputValues[i].y * weightValues[2].y;
curOutputValues.z += inputValues[i].z * weightValues[2].z;
curOutputValues.z += inputValues[i].w * weightValues[2].w;
curOutputValues.w += inputValues[i].x * weightValues[3].x;
curOutputValues.w += inputValues[i].y * weightValues[3].y;
curOutputValues.w += inputValues[i].z * weightValues[3].z;
curOutputValues.w += inputValues[i].w * weightValues[3].w;
outputValues[i] = curOutputValues;
}
}
// bias
packedOutputChannel += packedOuputChannelOffset;
short outputChannel = mul24(packedOutputChannel, 4);
float4 biasValues = vload4(0, biases + outputChannel);
for (short i = 0; i < 4; ++i) {
outputValues[i] += biasValues;
}
// accumulate
if (doAccumulate) {
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 4; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
if (outputColumn < numOutputColumns) {
outputValues[i] += read_imagef(accumulator, smp, outputLocation);
}
++outputColumn;
}
}
// activation
switch (neuron) {
case 1:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f);
}
break;
case 2:
for (short i = 0; i < 4; ++i) {
outputValues[i] = a * tanh(b * outputValues[i]);
}
break;
case 3:
for (short i = 0; i < 4; ++i) {
outputValues[i] = native_recip(1.0f + native_exp(-a * outputValues[i] + b));
}
break;
case 4:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], min_clamp);
outputValues[i] = min(outputValues[i], max_clamp);
}
break;
case 5:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f) + a * (native_exp(min(outputValues[i], 0.0f)) - 1.0f);
}
break;
}
// output
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 4; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}
++outputColumn;
}
}

View File

@ -0,0 +1,130 @@
__kernel void convolution_horizontal_reduced_reads_5_outputs(
read_only image2d_t input,
short startPackedInputChannel,
short numPackedInputChannelsForGroup, short totalNumPackedInputChannels,
short packedOuputChannelOffset, short totalNumPackedOutputChannels,
read_only image2d_t weights, __constant float *biases,
short filterSizeX, short filterSizeY,
write_only image2d_t output,
short paddingX, short paddingY, short strideX, short strideY,
short neuron, float a, float b, float min_clamp, float max_clamp,
__constant float *parameters, __constant float *batchNormBiases,
short numOutputColumns) {
// init
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
short packedOutputChannel = get_global_id(0);
short startOutputColumn = mul24((short)get_global_id(1), 5);
short outputRow = get_global_id(2);
short startX = mad24(mad24(startOutputColumn, strideX, -paddingX),
totalNumPackedInputChannels, startPackedInputChannel);
short strideWithChannels = mul24(strideX, totalNumPackedInputChannels);
float4 outputValues[5];
for (short i = 0; i < 5; ++i) {
outputValues[i] = (float4)(0, 0, 0, 0);
}
int2 inputLocation;
inputLocation.y = mad24(outputRow, strideY, -paddingY);
int2 weightLocation;
weightLocation.x = 0;
weightLocation.y = packedOutputChannel;
// convolution
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
for (short packedInputChannel = 0;
packedInputChannel < numPackedInputChannelsForGroup;
++packedInputChannel) {
short startXForChannel = startX + packedInputChannel;
for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
float4 weightValues[4];
for (short outChIdx = 0; outChIdx < 4; ++outChIdx) {
weightValues[outChIdx] = read_imagef(weights, smp, weightLocation);
++weightLocation.x;
}
inputLocation.x =
mad24(rfColumn, totalNumPackedInputChannels, startXForChannel);
float4 inputValues[5];
for (short i = 0; i < 5; ++i) {
inputValues[i] = read_imagef(input, smp, inputLocation);
inputLocation.x += strideWithChannels;
}
for (short i = 0; i < 5; ++i) {
float4 curOutputValues = outputValues[i];
curOutputValues.x += inputValues[i].x * weightValues[0].x;
curOutputValues.x += inputValues[i].y * weightValues[0].y;
curOutputValues.x += inputValues[i].z * weightValues[0].z;
curOutputValues.x += inputValues[i].w * weightValues[0].w;
curOutputValues.y += inputValues[i].x * weightValues[1].x;
curOutputValues.y += inputValues[i].y * weightValues[1].y;
curOutputValues.y += inputValues[i].z * weightValues[1].z;
curOutputValues.y += inputValues[i].w * weightValues[1].w;
curOutputValues.z += inputValues[i].x * weightValues[2].x;
curOutputValues.z += inputValues[i].y * weightValues[2].y;
curOutputValues.z += inputValues[i].z * weightValues[2].z;
curOutputValues.z += inputValues[i].w * weightValues[2].w;
curOutputValues.w += inputValues[i].x * weightValues[3].x;
curOutputValues.w += inputValues[i].y * weightValues[3].y;
curOutputValues.w += inputValues[i].z * weightValues[3].z;
curOutputValues.w += inputValues[i].w * weightValues[3].w;
outputValues[i] = curOutputValues;
}
}
}
++inputLocation.y;
}
// bias
packedOutputChannel += packedOuputChannelOffset;
short outputChannel = mul24(packedOutputChannel, 4);
float4 biasValues = vload4(0, biases + outputChannel);
for (short i = 0; i < 5; ++i) {
outputValues[i] += biasValues;
}
// activation
switch (neuron) {
case 1:
for (short i = 0; i < 5; ++i) {
outputValues[i] = max(outputValues[i], 0.0f);
}
break;
case 2:
for (short i = 0; i < 5; ++i) {
outputValues[i] = a * tanh(b * outputValues[i]);
}
break;
case 3:
for (short i = 0; i < 5; ++i) {
outputValues[i] = native_recip(1.0f + native_exp(-a * outputValues[i] + b));
}
break;
case 4:
for (short i = 0; i < 5; ++i) {
outputValues[i] = max(outputValues[i], min_clamp);
outputValues[i] = min(outputValues[i], max_clamp);
}
break;
case 5:
for (short i = 0; i < 5; ++i) {
outputValues[i] = max(outputValues[i], 0.0f) + a * (native_exp(min(outputValues[i], 0.0f)) - 1.0f);
}
break;
}
// output
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 5; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}
++outputColumn;
}
}

View File

@ -0,0 +1,101 @@
__kernel void convolution_horizontal_reduced_reads_depthwise(
read_only image2d_t input,
short totalNumPackedChannels,
read_only image2d_t weights, __constant float *biases,
short filterSizeX, short filterSizeY,
write_only image2d_t output,
short paddingX, short paddingY, short strideX, short strideY,
short dilationX, short dilationY,
short neuron, float a, float b, float min_clamp, float max_clamp,
short numOutputColumns) {
// init
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
short packedChannel = get_global_id(0);
short startOutputColumn = mul24((short)get_global_id(1), 4);
short outputRow = get_global_id(2);
short startXForChannel = mad24(mad24(startOutputColumn, strideX, -paddingX),
totalNumPackedChannels, packedChannel);
short strideWithChannels = mul24(strideX, totalNumPackedChannels);
float4 outputValues[4];
for (short i = 0; i < 4; ++i) {
outputValues[i] = (float4)(0, 0, 0, 0);
}
int2 inputLocation;
inputLocation.y = mad24(outputRow, strideY, -paddingY);
int2 weightLocation;
weightLocation.x = 0;
weightLocation.y = packedChannel;
// convolution
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
short dilatedStepX = mul24(totalNumPackedChannels, dilationX);
inputLocation.x = mad24(rfColumn, dilatedStepX, startXForChannel);
float4 inputValues[4];
for (short i = 0; i < 4; ++i) {
inputValues[i] = read_imagef(input, smp, inputLocation);
inputLocation.x += strideWithChannels;
}
float4 weightValues = read_imagef(weights, smp, weightLocation);
++weightLocation.x;
outputValues[0] += inputValues[0] * weightValues;
outputValues[1] += inputValues[1] * weightValues;
outputValues[2] += inputValues[2] * weightValues;
outputValues[3] += inputValues[3] * weightValues;
}
inputLocation.y += dilationY;
}
// bias
short outputChannel = mul24(packedChannel, 4);
float4 biasValues = vload4(0, biases + outputChannel);
for (short i = 0; i < 4; ++i) {
outputValues[i] += biasValues;
}
// activation
switch (neuron) {
case 1:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f);
}
break;
case 2:
for (short i = 0; i < 4; ++i) {
outputValues[i] = a * tanh(b * outputValues[i]);
}
break;
case 3:
for (short i = 0; i < 4; ++i) {
outputValues[i] = native_recip(1.0f + native_exp(-a * outputValues[i] + b));
}
break;
case 4:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], min_clamp);
outputValues[i] = min(outputValues[i], max_clamp);
}
break;
case 5:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f) + a * (native_exp(min(outputValues[i], 0.0f)) - 1.0f);
}
break;
}
// output
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 4; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedChannels, packedChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}
++outputColumn;
}
}

View File

@ -0,0 +1,103 @@
__kernel void convolution_horizontal_reduced_reads_depthwise_stride_1(
read_only image2d_t input,
short totalNumPackedChannels,
read_only image2d_t weights, __constant float *biases,
short filterSizeX, short filterSizeY,
write_only image2d_t output,
short paddingX, short paddingY, short strideX, short strideY,
short neuron, float a, float b, float min_clamp, float max_clamp,
short numOutputColumns) {
// init
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
short packedChannel = get_global_id(0);
short startOutputColumn = mul24((short)get_global_id(1), 4);
short outputRow = get_global_id(2);
short startXForChannel = mad24(mad24(startOutputColumn, strideX, -paddingX),
totalNumPackedChannels, packedChannel);
float4 outputValues[4];
for (short i = 0; i < 4; ++i) {
outputValues[i] = (float4)(0, 0, 0, 0);
}
int2 inputLocation;
inputLocation.y = mad24(outputRow, strideY, -paddingY);
int2 weightLocation;
weightLocation.x = 0;
weightLocation.y = packedChannel;
// convolution
for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
float4 inputValues[4];
inputLocation.x = startXForChannel;
for (short i = 1; i < 4; ++i) {
inputValues[i] = read_imagef(input, smp, inputLocation);
inputLocation.x += totalNumPackedChannels;
}
for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {
inputValues[0] = inputValues[1];
inputValues[1] = inputValues[2];
inputValues[2] = inputValues[3];
inputValues[3] = read_imagef(input, smp, inputLocation);
inputLocation.x += totalNumPackedChannels;
float4 weightValues = read_imagef(weights, smp, weightLocation);
++weightLocation.x;
outputValues[0] += inputValues[0] * weightValues;
outputValues[1] += inputValues[1] * weightValues;
outputValues[2] += inputValues[2] * weightValues;
outputValues[3] += inputValues[3] * weightValues;
}
++inputLocation.y;
}
// bias
short outputChannel = mul24(packedChannel, 4);
float4 biasValues = vload4(0, biases + outputChannel);
for (short i = 0; i < 4; ++i) {
outputValues[i] += biasValues;
}
// activation
switch (neuron) {
case 1:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f);
}
break;
case 2:
for (short i = 0; i < 4; ++i) {
outputValues[i] = a * tanh(b * outputValues[i]);
}
break;
case 3:
for (short i = 0; i < 4; ++i) {
outputValues[i] = native_recip(1.0f + native_exp(-a * outputValues[i] + b));
}
break;
case 4:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], min_clamp);
outputValues[i] = min(outputValues[i], max_clamp);
}
break;
case 5:
for (short i = 0; i < 4; ++i) {
outputValues[i] = max(outputValues[i], 0.0f) + a * (native_exp(min(outputValues[i], 0.0f)) - 1.0f);
}
break;
}
// output
int2 outputLocation;
short outputColumn = startOutputColumn;
outputLocation.y = outputRow;
for (short i = 0; i < 4; ++i) {
outputLocation.x = mad24(outputColumn, totalNumPackedChannels, packedChannel);
if (outputColumn < numOutputColumns) {
write_imagef(output, outputLocation, outputValues[i]);
}
++outputColumn;
}
}

View File

@ -0,0 +1,259 @@
#include <map>
#include <string>
#include <string.h>
#include <assert.h>
#include "thneed.h"
extern map<cl_program, string> g_program_source;
static int is_same_size_image(cl_mem a, cl_mem b) {
size_t a_width, a_height, a_depth, a_array_size, a_row_pitch, a_slice_pitch;
clGetImageInfo(a, CL_IMAGE_WIDTH, sizeof(a_width), &a_width, NULL);
clGetImageInfo(a, CL_IMAGE_HEIGHT, sizeof(a_height), &a_height, NULL);
clGetImageInfo(a, CL_IMAGE_DEPTH, sizeof(a_depth), &a_depth, NULL);
clGetImageInfo(a, CL_IMAGE_ARRAY_SIZE, sizeof(a_array_size), &a_array_size, NULL);
clGetImageInfo(a, CL_IMAGE_ROW_PITCH, sizeof(a_row_pitch), &a_row_pitch, NULL);
clGetImageInfo(a, CL_IMAGE_SLICE_PITCH, sizeof(a_slice_pitch), &a_slice_pitch, NULL);
size_t b_width, b_height, b_depth, b_array_size, b_row_pitch, b_slice_pitch;
clGetImageInfo(b, CL_IMAGE_WIDTH, sizeof(b_width), &b_width, NULL);
clGetImageInfo(b, CL_IMAGE_HEIGHT, sizeof(b_height), &b_height, NULL);
clGetImageInfo(b, CL_IMAGE_DEPTH, sizeof(b_depth), &b_depth, NULL);
clGetImageInfo(b, CL_IMAGE_ARRAY_SIZE, sizeof(b_array_size), &b_array_size, NULL);
clGetImageInfo(b, CL_IMAGE_ROW_PITCH, sizeof(b_row_pitch), &b_row_pitch, NULL);
clGetImageInfo(b, CL_IMAGE_SLICE_PITCH, sizeof(b_slice_pitch), &b_slice_pitch, NULL);
return (a_width == b_width) && (a_height == b_height) &&
(a_depth == b_depth) && (a_array_size == b_array_size) &&
(a_row_pitch == b_row_pitch) && (a_slice_pitch == b_slice_pitch);
}
static cl_mem make_image_like(cl_context context, cl_mem val) {
cl_image_format format;
size_t width, height, row_pitch;
clGetImageInfo(val, CL_IMAGE_FORMAT, sizeof(format), &format, NULL);
assert(format.image_channel_order == CL_RGBA);
assert(format.image_channel_data_type == CL_HALF_FLOAT);
clGetImageInfo(val, CL_IMAGE_WIDTH, sizeof(width), &width, NULL);
clGetImageInfo(val, CL_IMAGE_HEIGHT, sizeof(height), &height, NULL);
clGetImageInfo(val, CL_IMAGE_ROW_PITCH, sizeof(row_pitch), &row_pitch, NULL);
cl_image_desc desc = {0};
desc.image_type = CL_MEM_OBJECT_IMAGE2D;
desc.image_width = width;
desc.image_height = height;
desc.image_row_pitch = row_pitch;
cl_mem buf = clCreateBuffer(context, CL_MEM_READ_WRITE, row_pitch*height, NULL, NULL);
assert(buf != NULL);
desc.buffer = buf;
cl_int err;
cl_mem tmp = clCreateImage(context, CL_MEM_READ_WRITE, &format, &desc, NULL, &err);
//printf("got %d for image %zux%zu %zu\n", err, width, height, row_pitch);
assert(tmp != NULL);
return tmp;
}
// convolution_horizontal_reduced_reads_1x1 is 66% of the model runtime
// make that faster and the model gets faster
// this cuts ~2 ms off the model runtime right now
int Thneed::optimize() {
const char *kernel_path = getenv("KERNEL_PATH");
if (!kernel_path) { kernel_path = "/data/openpilot/selfdrive/modeld/thneed/kernels"; printf("no KERNEL_PATH set, defaulting to %s\n", kernel_path); }
// load custom kernels
map<string, cl_program> g_programs;
for (auto &k : kq) {
// replace program?
if (g_programs.find(k->name) == g_programs.end()) {
char fn[0x100];
snprintf(fn, sizeof(fn), "%s/%s.cl", kernel_path, k->name.c_str());
FILE *g = fopen(fn, "rb");
if (g != NULL) {
char *src[0x10000];
const char *srcs[1]; srcs[0] = (const char *)src;
memset(src, 0, sizeof(src));
size_t length = fread(src, 1, sizeof(src), g);
fclose(g);
printf("building kernel %s\n", k->name.c_str());
k->program = clCreateProgramWithSource(context, 1, srcs, &length, NULL);
int err = clBuildProgram(k->program, 1, &device_id, "", NULL, NULL);
if (err != 0) {
printf("got err %d\n", err);
size_t err_length;
char buffer[2048];
clGetProgramBuildInfo(k->program, device_id, CL_PROGRAM_BUILD_LOG, sizeof(buffer), buffer, &err_length);
buffer[err_length] = '\0';
printf("%s\n", buffer);
}
assert(err == 0);
// save in cache
g_programs[k->name] = k->program;
g_program_source[k->program] = string((char *)src, length);
} else {
g_programs[k->name] = NULL;
}
} else {
// cached replacement
if (g_programs[k->name] != NULL) {
k->program = g_programs[k->name];
}
}
// hack in accumulator to convolution_horizontal_reduced_reads_1x1
if (k->name == "convolution_horizontal_reduced_reads_1x1") {
k->arg_names.push_back("doAccumulate");
short doAccumulate = 0;
k->args.push_back(string((char *)&doAccumulate, sizeof(doAccumulate)));
k->args_size.push_back(2);
k->arg_names.push_back("accumulator");
k->args.push_back(k->args[k->get_arg_num("output")]);
k->args_size.push_back(8);
k->num_args += 2;
}
// assert that parameters + batchNormBiases are not used
// since they aren't supported in custom replacement kernels
if (k->name == "convolution_horizontal_reduced_reads_1x1" ||
k->name == "convolution_horizontal_reduced_reads" ||
k->name == "convolution_horizontal_reduced_reads_5_outputs") {
string p1 = k->args[k->get_arg_num("parameters")];
string p2 = k->args[k->get_arg_num("batchNormBiases")];
assert(p1.length() == 8 && *((uint64_t*)p1.data()) == 0);
assert(p2.length() == 8 && *((uint64_t*)p2.data()) == 0);
}
}
// optimizer
size_t start_size;
do {
start_size = kq.size();
// get optimizations
map<string, string> replacements;
for (int i = 0; i < kq.size(); i++) {
// fusing elementwise_sum + activate_image will save 3 enqueues
// delete useless copy layers
// saves ~0.7 ms
if (kq[i]->name == "concatenation" || kq[i]->name == "flatten") {
string in = kq[i]->args[kq[i]->get_arg_num("input")];
string out = kq[i]->args[kq[i]->get_arg_num("output")];
if (is_same_size_image(*(cl_mem*)in.data(), *(cl_mem*)out.data())) {
cl_mem tmp = make_image_like(context, *(cl_mem *)in.data());
replacements[in] = string((char *)&tmp, sizeof(tmp));
replacements[out] = string((char *)&tmp, sizeof(tmp));
kq.erase(kq.begin()+i); --i;
}
}
// NOTE: if activations/accumulation are done in the wrong order, this will be wrong
// fuse activations into convs and fc_Wtx
// saves ~1.5 ms
// NOTE: this changes the outputs because of rounding, should be better now!
if (i != 0 && kq[i]->name == "activate_image") {
if (kq[i-1]->name == "convolution_horizontal_reduced_reads_1x1" ||
kq[i-1]->name == "convolution_horizontal_reduced_reads_5_outputs" ||
kq[i-1]->name == "convolution_horizontal_reduced_reads" ||
kq[i-1]->name == "convolution_horizontal_reduced_reads_depthwise" ||
kq[i-1]->name == "convolution_horizontal_reduced_reads_depthwise_stride_1" ||
kq[i-1]->name == "fc_Wtx") {
string lastout = kq[i-1]->args[kq[i-1]->get_arg_num("output")];
string in = kq[i]->args[kq[i]->get_arg_num("input")];
string out = kq[i]->args[kq[i]->get_arg_num("output")];
if (lastout == in) {
short neuron = *(int*)kq[i]->args[kq[i]->get_arg_num("neuron")].data();
kq[i-1]->args[kq[i-1]->get_arg_num("neuron")] = string((char *)&neuron, sizeof(neuron));
cl_mem tmp = make_image_like(context, *(cl_mem *)lastout.data());
replacements[in] = string((char *)&tmp, sizeof(tmp));
replacements[out] = string((char *)&tmp, sizeof(tmp));
kq.erase(kq.begin()+i); --i;
}
}
}
// fuse accumulation into convs and fc_Wtx
if (i != 0 && kq[i]->name == "elementwise_sum") {
if (kq[i-1]->name == "convolution_horizontal_reduced_reads_1x1" ||
kq[i-1]->name == "fc_Wtx") {
string lastout = kq[i-1]->args[kq[i-1]->get_arg_num("output")];
string a = kq[i]->args[kq[i]->get_arg_num("a")];
string b = kq[i]->args[kq[i]->get_arg_num("b")];
string out = kq[i]->args[kq[i]->get_arg_num("output")];
if (lastout == a) {
kq[i-1]->args[kq[i-1]->get_arg_num("accumulator")] = b;
} else if (lastout == b) {
kq[i-1]->args[kq[i-1]->get_arg_num("accumulator")] = a;
} else {
continue;
}
cl_mem tmp = make_image_like(context, *(cl_mem *)lastout.data());
replacements[lastout] = string((char *)&tmp, sizeof(tmp));
replacements[out] = string((char *)&tmp, sizeof(tmp));
short doAccumulate = 1;
kq[i-1]->args[kq[i-1]->get_arg_num("doAccumulate")] = string((char *)&doAccumulate, sizeof(doAccumulate));
kq.erase(kq.begin()+i); --i;
}
}
}
// remap inputs and outputs, and clear the kernels
for (int i = 0; i < kq.size(); i++) {
kq[i]->kernel = NULL;
for (int j = 0; j < kq[i]->num_args; j++) {
if (replacements.find(kq[i]->args[j]) != replacements.end()) {
kq[i]->args[j] = replacements[kq[i]->args[j]];
}
}
}
printf("optimize %lu -> %lu\n", start_size, kq.size());
} while (kq.size() != start_size);
size_t work_group_size = 0;
clGetDeviceInfo(device_id, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(work_group_size), &work_group_size, NULL);
printf("max work group size %lu\n", work_group_size);
// local work group optimizer
for (auto &k : kq) {
// only do it for convs, since others might share memory
if (k->name.rfind("convolution_", 0) == 0) {
int best = -1;
if (k->local_work_size[0] * k->local_work_size[1] * k->local_work_size[2] < work_group_size/2) {
uint64_t base_time = k->benchmark();
uint64_t best_time = base_time;
for (int i = 0; i < 3; i++) {
k->local_work_size[i] *= 2;
uint64_t this_time = k->benchmark();
if (this_time < best_time) {
best = i;
best_time = this_time;
}
k->local_work_size[i] /= 2;
}
if (best != -1) {
k->local_work_size[best] *= 2;
//printf("%s %.2f ms doubled %d to %.2f ms\n", k->name.c_str(), base_time/1e6, best, best_time/1e6);
}
}
}
}
return 0;
}

View File

@ -12,7 +12,7 @@
#include "selfdrive/common/clutil.h"
#include "selfdrive/common/timing.h"
//#define RUN_DISASSEMBLER
//#define RUN_OPTIMIZER
#define RUN_OPTIMIZER
Thneed *g_thneed = NULL;
int g_fd = -1;
@ -528,6 +528,23 @@ cl_int CLQueuedKernel::exec() {
kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL);
}
uint64_t CLQueuedKernel::benchmark() {
uint64_t ret = 0;
int old_record = thneed->record;
thneed->record = 0;
clFinish(thneed->command_queue);
// TODO: benchmarking at a lower level will make this more accurate
for (int i = 0; i < 10; i++) {
uint64_t sb = nanos_since_boot();
exec();
clFinish(thneed->command_queue);
uint64_t et = nanos_since_boot() - sb;
if (ret == 0 || et < ret) ret = et;
}
thneed->record = old_record;
return ret;
}
void CLQueuedKernel::debug_print(bool verbose) {
printf("%p %56s -- ", kernel, name.c_str());
for (int i = 0; i < work_dim; i++) {

View File

@ -44,6 +44,7 @@ class CLQueuedKernel {
const size_t *_global_work_size,
const size_t *_local_work_size);
cl_int exec();
uint64_t benchmark();
void debug_print(bool verbose);
int get_arg_num(const char *search_arg_name);
cl_program program;

View File

@ -1 +1 @@
7d3ad941bc4ba4c923af7a1d7b48544bfc0d3e13
0c94f7c258bcdabc34c7f7be6cb6c2502afbb339