146 lines
3.8 KiB
C++
146 lines
3.8 KiB
C++
#include <c10/core/thread_pool.h>
|
|
#include <c10/util/Logging.h>
|
|
|
|
namespace c10 {
|
|
|
|
// Constructs the pool and immediately launches its worker threads.
//
// pool_size:    number of worker threads; a negative value selects
//               defaultNumThreads().
// numa_node_id: stored in numa_node_id_; not otherwise used in this function.
// init_thread:  optional callable. When non-empty, every worker invokes it
//               once on its own thread before entering main_loop().
//
// The pool starts out "complete" (no outstanding work) with every worker
// counted as available.
ThreadPool::ThreadPool(
    int pool_size,
    int numa_node_id,
    const std::function<void()>& init_thread)
    : threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
      running_(true),
      complete_(true),
      available_(threads_.size()),
      total_(threads_.size()),
      numa_node_id_(numa_node_id) {
  std::size_t worker = 0;
  for (auto& slot : threads_) {
    // init_thread is captured by value so each worker holds its own copy and
    // does not depend on the lifetime of the caller's argument.
    slot = std::thread([this, worker, init_thread]() {
      if (init_thread) {
        init_thread();
      }
      this->main_loop(worker);
    });
    ++worker;
  }
}
|
|
|
|
// Shuts the pool down: clears running_ while holding the mutex, wakes every
// worker blocked on condition_, then joins each worker thread.
ThreadPool::~ThreadPool() {
  {
    // The flag is flipped and the workers are notified under the same lock
    // that the workers hold while checking their wait predicate.
    std::lock_guard<std::mutex> guard(mutex_);
    running_ = false;
    condition_.notify_all();
  }

  for (std::size_t idx = 0; idx < threads_.size(); ++idx) {
    try {
      threads_[idx].join();
    } catch (const std::exception&) {
      // Best effort: a failed join is ignored so the destructor never throws.
    }
  }
}
|
|
|
|
size_t ThreadPool::size() const {
|
|
return threads_.size();
|
|
}
|
|
|
|
size_t ThreadPool::numAvailable() const {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
return available_;
|
|
}
|
|
|
|
bool ThreadPool::inThreadPool() const {
|
|
for (auto& thread : threads_) {
|
|
if (thread.get_id() == std::this_thread::get_id()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void ThreadPool::run(std::function<void()> func) {
|
|
if (threads_.empty()) {
|
|
throw std::runtime_error("No threads to run a task");
|
|
}
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
// Set task and signal condition variable so that a worker thread will
|
|
// wake up and use the task.
|
|
tasks_.emplace(std::move(func));
|
|
complete_ = false;
|
|
condition_.notify_one();
|
|
}
|
|
|
|
void ThreadPool::waitWorkComplete() {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
completed_.wait(lock, [&]() { return complete_; });
|
|
}
|
|
|
|
void ThreadPool::main_loop(std::size_t index) {
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
while (running_) {
|
|
// Wait on condition variable while the task is empty and
|
|
// the pool is still running.
|
|
condition_.wait(lock, [&]() { return !tasks_.empty() || !running_; });
|
|
// If pool is no longer running, break out of loop.
|
|
if (!running_) {
|
|
break;
|
|
}
|
|
|
|
// Copy task locally and remove from the queue. This is
|
|
// done within its own scope so that the task object is
|
|
// destructed immediately after running the task. This is
|
|
// useful in the event that the function contains
|
|
// shared_ptr arguments bound via bind.
|
|
{
|
|
task_element_t tasks = std::move(tasks_.front());
|
|
tasks_.pop();
|
|
// Decrement count, indicating thread is no longer available.
|
|
--available_;
|
|
|
|
lock.unlock();
|
|
|
|
// Run the task.
|
|
try {
|
|
if (tasks.run_with_id) {
|
|
tasks.with_id(index);
|
|
} else {
|
|
tasks.no_id();
|
|
}
|
|
} catch (const std::exception& e) {
|
|
LOG(ERROR) << "Exception in thread pool task: " << e.what();
|
|
} catch (...) {
|
|
LOG(ERROR) << "Exception in thread pool task: unknown";
|
|
}
|
|
|
|
// Destruct tasks before taking the lock. As tasks
|
|
// are user provided std::function, they can run
|
|
// arbitrary code during destruction, including code
|
|
// that can reentrantly call into ThreadPool (which would
|
|
// cause a deadlock if we were holding the lock).
|
|
}
|
|
|
|
// Update status of empty, maybe
|
|
// Need to recover the lock first
|
|
lock.lock();
|
|
|
|
// Increment count, indicating thread is available.
|
|
++available_;
|
|
if (tasks_.empty() && available_ == total_) {
|
|
complete_ = true;
|
|
completed_.notify_one();
|
|
}
|
|
|
|
// Deliberately hold the lock on the backedge, so this thread has an
|
|
// opportunity to acquire a new task before another thread acquires
|
|
// the lock.
|
|
} // while running_
|
|
}
|
|
|
|
// Defines the shared registry named ThreadPoolRegistry for TaskThreadPoolBase.
// NOTE(review): the trailing (int, int, bool) are presumably the argument
// types passed to registered creators -- confirm against the matching
// C10_DECLARE_SHARED_REGISTRY declaration.
C10_DEFINE_SHARED_REGISTRY(ThreadPoolRegistry, TaskThreadPoolBase, int, int, bool);
|
|
} // namespace c10
|