add thneed self test (#1535)

* add thneed self test

* don't do the memset in thneed, shouldn't matter though

Co-authored-by: Comma Device <device@comma.ai>
albatross
George Hotz 2020-05-18 11:34:29 -07:00 committed by GitHub
parent 52fe671c53
commit 6c0ad1e675
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 5 deletions

View File

@ -9,8 +9,9 @@ void PrintErrorStringAndExit() {
std::exit(EXIT_FAILURE);
}
SNPEModel::SNPEModel(const char *path, float *loutput, size_t output_size, int runtime) {
SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime) {
output = loutput;
output_size = loutput_size;
#ifdef QCOM
if (runtime==USE_GPU_RUNTIME) {
Runtime = zdl::DlSystem::Runtime_t::GPU;
@ -102,6 +103,7 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t output_size, int r
void SNPEModel::addRecurrent(float *state, int state_size) {
recurrent = state;
recurrent_size = state_size;
recurrentBuffer = this->addExtra(state, state_size, 3);
}
@ -134,21 +136,37 @@ std::unique_ptr<zdl::DlSystem::IUserBuffer> SNPEModel::addExtra(float *state, in
void SNPEModel::execute(float *net_input_buf, int buf_size) {
#ifdef USE_THNEED
if (Runtime == zdl::DlSystem::Runtime_t::GPU) {
float *inputs[4] = {recurrent, trafficConvention, desire, net_input_buf};
if (thneed == NULL) {
assert(inputBuffer->setBufferAddress(net_input_buf));
if (!snpe->execute(inputMap, outputMap)) {
PrintErrorStringAndExit();
}
memset(recurrent, 0, recurrent_size*sizeof(float));
thneed = new Thneed();
//thneed->record = 3;
if (!snpe->execute(inputMap, outputMap)) {
PrintErrorStringAndExit();
}
thneed->stop();
//thneed->record = 2;
printf("thneed cached\n");
// doing self test
float *outputs_golden = (float *)malloc(output_size*sizeof(float));
memcpy(outputs_golden, output, output_size*sizeof(float));
memset(output, 0, output_size*sizeof(float));
memset(recurrent, 0, recurrent_size*sizeof(float));
thneed->execute(inputs, output);
if (memcmp(output, outputs_golden, output_size*sizeof(float)) == 0) {
printf("thneed selftest passed\n");
} else {
for (int i = 0; i < output_size; i++) {
printf("mismatch %3d: %f %f\n", i, output[i], outputs_golden[i]);
}
assert(false);
}
free(outputs_golden);
} else {
float *inputs[4] = {recurrent, trafficConvention, desire, net_input_buf};
thneed->execute(inputs, output);
}
} else {

View File

@ -23,7 +23,7 @@
class SNPEModel : public RunModel {
public:
SNPEModel(const char *path, float *loutput, size_t output_size, int runtime);
SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime);
~SNPEModel() {
if (model_data) free(model_data);
}
@ -51,10 +51,12 @@ private:
zdl::DlSystem::UserBufferMap outputMap;
std::unique_ptr<zdl::DlSystem::IUserBuffer> outputBuffer;
float *output;
size_t output_size;
// recurrent and desire
std::unique_ptr<zdl::DlSystem::IUserBuffer> addExtra(float *state, int state_size, int idx);
float *recurrent;
size_t recurrent_size;
std::unique_ptr<zdl::DlSystem::IUserBuffer> recurrentBuffer;
float *trafficConvention;
std::unique_ptr<zdl::DlSystem::IUserBuffer> trafficConventionBuffer;