You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
2.4 KiB

  1. #pragma once
  2. #include "unbounded_queue.hpp"
  3. #include <tuple>
  4. #include <atomic>
  5. #include <vector>
  6. #include <thread>
  7. #include <memory>
  8. #include <future>
  9. #include <utility>
  10. #include <functional>
  11. #include <type_traits>
  12. class ThreadPool
  13. {
  14. public:
  15. explicit
  16. ThreadPool(const std::size_t thread_count_ = std::thread::hardware_concurrency())
  17. : _queues(thread_count_),
  18. _count(thread_count_)
  19. {
  20. auto worker = [this](std::size_t i)
  21. {
  22. while(true)
  23. {
  24. Proc f;
  25. for(std::size_t n = 0; n < (_count * K); ++n)
  26. {
  27. if(_queues[(i + n) % _count].try_pop(f))
  28. break;
  29. }
  30. if(!f && !_queues[i].pop(f))
  31. break;
  32. f();
  33. }
  34. };
  35. _threads.reserve(thread_count_);
  36. for(std::size_t i = 0; i < thread_count_; ++i)
  37. _threads.emplace_back(worker, i);
  38. }
  39. ~ThreadPool()
  40. {
  41. for(auto& queue : _queues)
  42. queue.unblock();
  43. for(auto& thread : _threads)
  44. thread.join();
  45. }
  46. template<typename F>
  47. void
  48. enqueue_work(F&& f_)
  49. {
  50. auto i = _index++;
  51. for(std::size_t n = 0; n < (_count * K); ++n)
  52. {
  53. if(_queues[(i + n) % _count].try_push(f_))
  54. return;
  55. }
  56. _queues[i % _count].push(std::move(f_));
  57. }
  58. template<typename F>
  59. [[nodiscard]]
  60. std::future<typename std::result_of<F()>::type>
  61. enqueue_task(F&& f_)
  62. {
  63. using TaskReturnType = typename std::result_of<F()>::type;
  64. using Promise = std::promise<TaskReturnType>;
  65. auto i = _index++;
  66. auto promise = std::make_shared<Promise>();
  67. auto future = promise->get_future();
  68. auto work = [=]() {
  69. auto rv = f_();
  70. promise->set_value(rv);
  71. };
  72. for(std::size_t n = 0; n < (_count * K); ++n)
  73. {
  74. if(_queues[(i + n) % _count].try_push(work))
  75. return future;
  76. }
  77. _queues[i % _count].push(std::move(work));
  78. return future;
  79. }
  80. public:
  81. std::vector<pthread_t>
  82. threads()
  83. {
  84. std::vector<pthread_t> rv;
  85. for(auto &thread : _threads)
  86. rv.push_back(thread.native_handle());
  87. return rv;
  88. }
  89. private:
  90. using Proc = std::function<void(void)>;
  91. using Queue = UnboundedQueue<Proc>;
  92. using Queues = std::vector<Queue>;
  93. Queues _queues;
  94. private:
  95. std::vector<std::thread> _threads;
  96. private:
  97. const std::size_t _count;
  98. std::atomic_uint _index;
  99. static const unsigned int K = 2;
  100. };