154 lines
3.8 KiB

  1. package resource_pool
  2. import (
  3. "fmt"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. )
  // Semaphore is a counting semaphore that hands out interchangeable permits.
type Semaphore interface {
	// Acquire takes one permit, blocking for as long as none is available.
	Acquire()

	// TryAcquire is like Acquire but waits for at most timeout. It returns
	// true if a permit was taken and false if the wait timed out.
	TryAcquire(timeout time.Duration) bool

	// Release hands one permit back, increasing the number of available
	// permits by one.
	Release()
}
  17. // A simple counting Semaphore.
  18. type boundedSemaphore struct {
  19. slots chan struct{}
  20. }
  21. // Create a bounded semaphore. The count parameter must be a positive number.
  22. // NOTE: The bounded semaphore will panic if the user tries to Release
  23. // beyond the specified count.
  24. func NewBoundedSemaphore(count uint) Semaphore {
  25. sem := &boundedSemaphore{
  26. slots: make(chan struct{}, int(count)),
  27. }
  28. for i := 0; i < cap(sem.slots); i++ {
  29. sem.slots <- struct{}{}
  30. }
  31. return sem
  32. }
  33. // Acquire returns on successful acquisition.
  34. func (sem *boundedSemaphore) Acquire() {
  35. <-sem.slots
  36. }
  37. // TryAcquire returns true if it acquires a resource slot within the
  38. // timeout, false otherwise.
  39. func (sem *boundedSemaphore) TryAcquire(timeout time.Duration) bool {
  40. if timeout > 0 {
  41. // Wait until we get a slot or timeout expires.
  42. tm := time.NewTimer(timeout)
  43. defer tm.Stop()
  44. select {
  45. case <-sem.slots:
  46. return true
  47. case <-tm.C:
  48. // Timeout expired. In very rare cases this might happen even if
  49. // there is a slot available, e.g. GC pause after we create the timer
  50. // and select randomly picked this one out of the two available channels.
  51. // We should do one final immediate check below.
  52. }
  53. }
  54. // Return true if we have a slot available immediately and false otherwise.
  55. select {
  56. case <-sem.slots:
  57. return true
  58. default:
  59. return false
  60. }
  61. }
  62. // Release the acquired semaphore. You must not release more than you
  63. // have acquired.
  64. func (sem *boundedSemaphore) Release() {
  65. select {
  66. case sem.slots <- struct{}{}:
  67. default:
  68. // slots is buffered. If a send blocks, it indicates a programming
  69. // error.
  70. panic(fmt.Errorf("too many releases for boundedSemaphore"))
  71. }
  72. }
  73. // This returns an unbound counting semaphore with the specified initial count.
  74. // The semaphore counter can be arbitrary large (i.e., Release can be called
  75. // unlimited amount of times).
  76. //
  77. // NOTE: In general, users should use bounded semaphore since it is more
  78. // efficient than unbounded semaphore.
  79. func NewUnboundedSemaphore(initialCount int) Semaphore {
  80. res := &unboundedSemaphore{
  81. counter: int64(initialCount),
  82. }
  83. res.cond.L = &res.lock
  84. return res
  85. }
  86. type unboundedSemaphore struct {
  87. lock sync.Mutex
  88. cond sync.Cond
  89. counter int64
  90. }
  91. func (s *unboundedSemaphore) Release() {
  92. s.lock.Lock()
  93. s.counter += 1
  94. if s.counter > 0 {
  95. // Not broadcasting here since it's unlike we can satisfy all waiting
  96. // goroutines. Instead, we will Signal again if there are left over
  97. // quota after Acquire, in case of lost wakeups.
  98. s.cond.Signal()
  99. }
  100. s.lock.Unlock()
  101. }
  102. func (s *unboundedSemaphore) Acquire() {
  103. s.lock.Lock()
  104. for s.counter < 1 {
  105. s.cond.Wait()
  106. }
  107. s.counter -= 1
  108. if s.counter > 0 {
  109. s.cond.Signal()
  110. }
  111. s.lock.Unlock()
  112. }
  113. func (s *unboundedSemaphore) TryAcquire(timeout time.Duration) bool {
  114. done := make(chan bool, 1)
  115. // Gate used to communicate between the threads and decide what the result
  116. // is. If the main thread decides, we have timed out, otherwise we succeed.
  117. decided := new(int32)
  118. atomic.StoreInt32(decided, 0)
  119. go func() {
  120. s.Acquire()
  121. if atomic.SwapInt32(decided, 1) == 0 {
  122. // Acquire won the race
  123. done <- true
  124. } else {
  125. // If we already decided the result, and this thread did not win
  126. s.Release()
  127. }
  128. }()
  129. select {
  130. case <-done:
  131. return true
  132. case <-time.After(timeout):
  133. if atomic.SwapInt32(decided, 1) == 1 {
  134. // The other thread already decided the result
  135. return true
  136. }
  137. return false
  138. }
  139. }