You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
2.8 KiB

  1. #pragma once
  2. #include "bounded_queue.hpp"
  3. #include "make_unique.hpp"
  4. #include <signal.h>
  5. #include <tuple>
  6. #include <atomic>
  7. #include <vector>
  8. #include <thread>
  9. #include <memory>
  10. #include <future>
  11. #include <utility>
  12. #include <functional>
  13. #include <type_traits>
  14. class BoundedThreadPool
  15. {
  16. private:
  17. using Proc = std::function<void(void)>;
  18. using Queue = BoundedQueue<Proc>;
  19. using Queues = std::vector<std::unique_ptr<Queue>>;
  20. public:
  21. explicit
  22. BoundedThreadPool(const std::size_t thread_count_ = std::thread::hardware_concurrency(),
  23. const std::size_t queue_depth_ = 1)
  24. : _queues(),
  25. _count(thread_count_)
  26. {
  27. for(std::size_t i = 0; i < thread_count_; i++)
  28. _queues.emplace_back(std::make_unique<Queue>(queue_depth_));
  29. auto worker = [this](std::size_t i)
  30. {
  31. while(true)
  32. {
  33. Proc f;
  34. for(std::size_t n = 0; n < (_count * K); ++n)
  35. {
  36. if(_queues[(i + n) % _count]->pop(f))
  37. break;
  38. }
  39. if(!f && !_queues[i]->pop(f))
  40. break;
  41. f();
  42. }
  43. };
  44. sigset_t oldset;
  45. sigset_t newset;
  46. sigfillset(&newset);
  47. pthread_sigmask(SIG_BLOCK,&newset,&oldset);
  48. _threads.reserve(thread_count_);
  49. for(std::size_t i = 0; i < thread_count_; ++i)
  50. _threads.emplace_back(worker, i);
  51. pthread_sigmask(SIG_SETMASK,&oldset,NULL);
  52. }
  53. ~BoundedThreadPool()
  54. {
  55. for(auto &queue : _queues)
  56. queue->unblock();
  57. for(auto &thread : _threads)
  58. pthread_cancel(thread.native_handle());
  59. for(auto &thread : _threads)
  60. thread.join();
  61. }
  62. template<typename F>
  63. void
  64. enqueue_work(F&& f_)
  65. {
  66. auto i = _index++;
  67. for(std::size_t n = 0; n < (_count * K); ++n)
  68. {
  69. if(_queues[(i + n) % _count]->push(f_))
  70. return;
  71. }
  72. _queues[i % _count]->push(std::move(f_));
  73. }
  74. template<typename F>
  75. [[nodiscard]]
  76. std::future<typename std::result_of<F()>::type>
  77. enqueue_task(F&& f_)
  78. {
  79. using TaskReturnType = typename std::result_of<F()>::type;
  80. using Promise = std::promise<TaskReturnType>;
  81. auto i = _index++;
  82. auto promise = std::make_shared<Promise>();
  83. auto future = promise->get_future();
  84. auto work = [=]() {
  85. auto rv = f_();
  86. promise->set_value(rv);
  87. };
  88. for(std::size_t n = 0; n < (_count * K); ++n)
  89. {
  90. if(_queues[(i + n) % _count]->push(work))
  91. return future;
  92. }
  93. _queues[i % _count]->push(std::move(work));
  94. return future;
  95. }
  96. public:
  97. std::vector<pthread_t>
  98. threads()
  99. {
  100. std::vector<pthread_t> rv;
  101. for(auto &thread : _threads)
  102. rv.push_back(thread.native_handle());
  103. return rv;
  104. }
  105. private:
  106. Queues _queues;
  107. private:
  108. std::vector<std::thread> _threads;
  109. private:
  110. const std::size_t _count;
  111. std::atomic_uint _index;
  112. static const unsigned int K = 2;
  113. };