// pytorch/binaries/intra_inter_benchmark.cc
//
// 167 lines
// 4.9 KiB
// C++

#include "ATen/ATen.h"
#include "ATen/Parallel.h"
#include "c10/util/Flags.h"
#include "caffe2/core/init.h"

#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <ctime>
#include <iostream>
#include <limits>
#include <mutex>
#include <thread>
#include <unordered_set>
#include <vector>
// Command-line flags controlling the benchmark.

// Size of the task tree: 2^iter_pow leaf tasks per benchmark run, each of
// which executes sub_iter subtasks.
C10_DEFINE_int(iter_pow, 10, "Number of tasks, 2^N");
C10_DEFINE_int(sub_iter, 1024, "Number of subtasks");
C10_DEFINE_int(warmup_iter_pow, 3, "Number of warmup tasks, 2^N");

// Thread pool sizes; values <= 0 leave the library defaults untouched.
C10_DEFINE_int(inter_op_threads, 0, "Number of inter-op threads");
C10_DEFINE_int(intra_op_threads, 0, "Number of intra-op threads");

// Both operand tensors are tensor_dim x tensor_dim.
C10_DEFINE_int(tensor_dim, 50, "Tensor dim");
// Terminated with ';' like every sibling definition so the line does not
// depend on how the macro happens to expand.
C10_DEFINE_int(benchmark_iter, 10, "Number of times to run benchmark");
C10_DEFINE_bool(
    extra_stats,
    false,
    "Collect extra stats; warning: skews results");
C10_DEFINE_string(task_type, "add", "Tensor operation: add or mm");
namespace {

// --- Completion tracking -------------------------------------------------
// Number of subtasks finished so far in the current run; bumped by workers.
std::atomic<int> counter(0);
// Number of subtasks the current run is expected to finish.
int overall_tasks{0};
// Signalled (under tasks_mutex) by the worker that finishes the last subtask.
std::condition_variable cv;
std::mutex tasks_mutex;

// --- Workload selection --------------------------------------------------
// true: subtasks call mm(); false: subtasks call add().
bool run_mm{false};

// --- Optional extra statistics -------------------------------------------
// Ids of the threads that executed subtasks; guarded by stats_mutex while
// workers insert into it.
std::mutex stats_mutex;
std::unordered_set<std::thread::id> tids;

} // namespace
// Blocks the calling thread until `counter` is no longer below
// `overall_tasks`. The predicate is re-evaluated under `tasks_mutex` every
// time the condition variable wakes up.
void wait() {
  std::unique_lock<std::mutex> guard(tasks_mutex);
  cv.wait(guard, [] { return !(counter < overall_tasks); });
}
// Builds a binary tree of tasks of depth `end_level`.
//
// Every node above `end_level` hands two children to at::launch. A node at
// `end_level` is a leaf: it runs FLAGS_sub_iter subtasks through
// at::parallel_for (grain size 1), each subtask being one `left.mm(right)` or
// `left.add(right)` depending on `run_mm`. Each finished subtask increments
// `counter`; whoever brings it to `overall_tasks` notifies `cv`.
//
// `left` and `right` are captured by reference by the launched closures, so
// the caller has to keep them alive until all tasks have finished.
void _launch_tasks_tree(
    int level, int end_level, at::Tensor& left, at::Tensor& right) {
  if (level != end_level) {
    // Interior node: fan out into two children one level deeper.
    for (int child = 0; child < 2; ++child) {
      at::launch([&left, &right, level, end_level]() {
        _launch_tasks_tree(level + 1, end_level, left, right);
      });
    }
    return;
  }

  // Leaf node: body executed for each [begin, end) chunk of subtasks.
  const auto run_subtasks = [&left, &right](int64_t begin, int64_t end) {
    if (FLAGS_extra_stats) {
      // Record which thread picked up this chunk.
      std::unique_lock<std::mutex> stats_guard(stats_mutex);
      tids.insert(std::this_thread::get_id());
    }
    for (auto idx = begin; idx < end; ++idx) {
      if (run_mm) {
        left.mm(right);
      } else {
        left.add(right);
      }
      const auto finished = ++counter;
      if (finished == overall_tasks) {
        // Last subtask of the run: wake the thread blocked in wait().
        std::unique_lock<std::mutex> tasks_guard(tasks_mutex);
        cv.notify_one();
      }
    }
  };
  at::parallel_for(0, FLAGS_sub_iter, 1, run_subtasks);
}
// Runs one complete round: launches a task tree with 2^iter_pow leaves, each
// executing FLAGS_sub_iter subtasks on `left`/`right`, and blocks until every
// subtask has been counted.
//
// Fails a TORCH_CHECK if `iter_pow` is outside [0, 30] or if the total number
// of subtasks does not fit the int completion target.
void launch_tasks_and_wait(at::Tensor& left, at::Tensor& right, int iter_pow) {
  // A negative depth would never match a tree level, so the recursion would
  // spawn tasks forever; 31 or more cannot be represented in the int target.
  TORCH_CHECK(
      iter_pow >= 0 && iter_pow < 31,
      "iter_pow must be in [0, 30], got ",
      iter_pow);

  // Integer arithmetic instead of pow(): exact, and no double -> int
  // narrowing. Computed in 64 bits so overflow can be detected.
  const int64_t total_tasks =
      (int64_t{1} << iter_pow) * static_cast<int64_t>(FLAGS_sub_iter);
  TORCH_CHECK(
      total_tasks <= std::numeric_limits<int>::max(),
      "2^iter_pow * sub_iter = ",
      total_tasks,
      " does not fit in int");

  overall_tasks = static_cast<int>(total_tasks);
  counter = 0;
  _launch_tasks_tree(0, iter_pow, left, right);
  wait();
}
// Forgets all thread ids recorded so far. Does not take stats_mutex.
void reset_extra_stats() {
tids.clear();
}
// Writes the number of distinct thread ids currently recorded in `tids` to
// stdout.
void print_extra_stats() {
  const auto distinct_threads = tids.size();
  std::cout << "# threads: " << distinct_threads << std::endl;
}
// Prints the number of samples, their mean and their population standard
// deviation (divisor N) to stdout.
//
// `runtimes` must be non-empty; this is enforced by an internal assert.
void print_runtime_stats(const std::vector<float>& runtimes) {
  TORCH_INTERNAL_ASSERT(!runtimes.empty());

  // Accumulate in double: E[x^2] - E[x]^2 subtracts two nearly equal numbers,
  // and float accumulators lose most of the significant digits there.
  double sum = 0.0;
  double sqr_sum = 0.0;
  for (const float runtime : runtimes) {
    sum += runtime;
    sqr_sum += static_cast<double>(runtime) * runtime;
  }

  const size_t N = runtimes.size();
  const double mean = sum / N;
  // Rounding can still leave the variance marginally below zero when all
  // samples are (nearly) identical; clamp so sqrt never produces NaN.
  const double variance = sqr_sum / N - mean * mean;
  const double sd = variance > 0.0 ? std::sqrt(variance) : 0.0;

  std::cout << "N = " << N << ", mean = " << static_cast<float>(mean)
            << ", sd = " << static_cast<float>(sd) << std::endl;
}
// Benchmark driver.
//
// Parses the flags, configures the thread pools, runs one warmup round with
// 2^warmup_iter_pow tasks, then FLAGS_benchmark_iter timed rounds with
// 2^iter_pow tasks each, printing the runtime of every round and finally the
// summary statistics.
//
// Returns -1 if the command line cannot be parsed, 0 on success. Invalid flag
// values fail a TORCH_CHECK.
int main(int argc, char** argv) {
  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cout << "Failed to parse command line flags" << std::endl;
    return -1;
  }

  // Reject values that would otherwise hang or fail obscurely later on: a
  // negative power never terminates the task-tree recursion, and zero
  // benchmark rounds leave no samples to summarise.
  TORCH_CHECK(
      FLAGS_iter_pow >= 0 && FLAGS_iter_pow < 31,
      "--iter_pow must be in [0, 30], got ",
      FLAGS_iter_pow);
  TORCH_CHECK(
      FLAGS_warmup_iter_pow >= 0 && FLAGS_warmup_iter_pow < 31,
      "--warmup_iter_pow must be in [0, 30], got ",
      FLAGS_warmup_iter_pow);
  TORCH_CHECK(
      FLAGS_benchmark_iter > 0,
      "--benchmark_iter must be positive, got ",
      FLAGS_benchmark_iter);

  caffe2::unsafeRunCaffe2InitFunction("registerThreadPools");
  at::init_num_threads();

  // Non-positive values keep the library defaults.
  if (FLAGS_inter_op_threads > 0) {
    at::set_num_interop_threads(FLAGS_inter_op_threads);
  }
  if (FLAGS_intra_op_threads > 0) {
    at::set_num_threads(FLAGS_intra_op_threads);
  }

  TORCH_CHECK(FLAGS_task_type == "add" || FLAGS_task_type == "mm");
  run_mm = FLAGS_task_type == "mm";

  auto left = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);
  auto right = at::ones({FLAGS_tensor_dim, FLAGS_tensor_dim}, at::kFloat);

  // steady_clock is monotonic, so measured intervals cannot be distorted by
  // wall-clock adjustments.
  using bench_clock = std::chrono::steady_clock;
  using ms = std::chrono::milliseconds;
  // Whole milliseconds elapsed since `since`, as a float.
  const auto elapsed_ms = [](bench_clock::time_point since) {
    return static_cast<float>(
        std::chrono::duration_cast<ms>(bench_clock::now() - since).count());
  };

  // Task counts are kept as integers so large values are not printed in
  // scientific notation.
  const long long warmup_tasks = 1LL << FLAGS_warmup_iter_pow;
  const long long bench_tasks = 1LL << FLAGS_iter_pow;

  std::cout << "Launching " << warmup_tasks << " warmup tasks" << std::endl;
  auto start_time = bench_clock::now();
  launch_tasks_and_wait(left, right, FLAGS_warmup_iter_pow);
  auto duration = elapsed_ms(start_time);
  std::cout << "Warmup time: " << duration << " ms." << std::endl;

  std::cout << "Launching " << bench_tasks << " tasks with "
            << FLAGS_sub_iter << " subtasks each, using "
            << at::get_num_interop_threads() << " inter-op threads and "
            << at::get_num_threads() << " intra-op threads, "
            << "tensor dim: " << FLAGS_tensor_dim
            << ", task type: " << FLAGS_task_type << std::endl;

  std::vector<float> runtimes;
  runtimes.reserve(static_cast<size_t>(FLAGS_benchmark_iter));
  for (auto bench_iter = 0; bench_iter < FLAGS_benchmark_iter; ++bench_iter) {
    reset_extra_stats();
    start_time = bench_clock::now();
    launch_tasks_and_wait(left, right, FLAGS_iter_pow);
    duration = elapsed_ms(start_time);
    runtimes.push_back(duration);

    if (FLAGS_extra_stats) {
      print_extra_stats();
    }
    std::cout << "Runtime: " << duration << " ms." << std::endl;
  }

  print_runtime_stats(runtimes);
  return 0;
}