You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

183 lines
4.9 KiB

  1. package s3api
  2. import (
  3. "fmt"
  4. "github.com/gorilla/mux"
  5. "github.com/seaweedfs/seaweedfs/weed/filer"
  6. "github.com/seaweedfs/seaweedfs/weed/glog"
  7. "github.com/seaweedfs/seaweedfs/weed/pb"
  8. "github.com/seaweedfs/seaweedfs/weed/pb/filer_pb"
  9. "github.com/seaweedfs/seaweedfs/weed/pb/s3_pb"
  10. "github.com/seaweedfs/seaweedfs/weed/s3api/s3_constants"
  11. "github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
  12. "net/http"
  13. "sync"
  14. "sync/atomic"
  15. )
  // CircuitBreaker caps, per action, how many requests and how many request
// body bytes may be in flight at the same time, both globally and per bucket.
type CircuitBreaker struct {
	// RWMutex guards the counters map itself; the individual counters it
	// points to are updated with sync/atomic.
	sync.RWMutex
	// Enabled is the global switch: when false, handlers wrapped by Limit are
	// called directly without any checks.
	Enabled bool
	// counters holds the current in-flight total for each limit key.
	counters map[string]*int64
	// limitations maps a limit key to its configured maximum; keys that are
	// absent are not limited.
	limitations map[string]int64
}
  22. func NewCircuitBreaker(option *S3ApiServerOption) *CircuitBreaker {
  23. cb := &CircuitBreaker{
  24. counters: make(map[string]*int64),
  25. limitations: make(map[string]int64),
  26. }
  27. err := pb.WithFilerClient(false, 0, option.Filer, option.GrpcDialOption, func(client filer_pb.SeaweedFilerClient) error {
  28. content, err := filer.ReadInsideFiler(client, s3_constants.CircuitBreakerConfigDir, s3_constants.CircuitBreakerConfigFile)
  29. if err != nil {
  30. return fmt.Errorf("read S3 circuit breaker config: %v", err)
  31. }
  32. return cb.LoadS3ApiConfigurationFromBytes(content)
  33. })
  34. if err != nil {
  35. glog.Infof("s3 circuit breaker not configured: %v", err)
  36. }
  37. return cb
  38. }
  39. func (cb *CircuitBreaker) LoadS3ApiConfigurationFromBytes(content []byte) error {
  40. cbCfg := &s3_pb.S3CircuitBreakerConfig{}
  41. if err := filer.ParseS3ConfigurationFromBytes(content, cbCfg); err != nil {
  42. glog.Warningf("unmarshal error: %v", err)
  43. return fmt.Errorf("unmarshal error: %v", err)
  44. }
  45. if err := cb.loadCircuitBreakerConfig(cbCfg); err != nil {
  46. return err
  47. }
  48. return nil
  49. }
  50. func (cb *CircuitBreaker) loadCircuitBreakerConfig(cfg *s3_pb.S3CircuitBreakerConfig) error {
  51. //global
  52. globalEnabled := false
  53. globalOptions := cfg.Global
  54. limitations := make(map[string]int64)
  55. if globalOptions != nil && globalOptions.Enabled && len(globalOptions.Actions) > 0 {
  56. globalEnabled = globalOptions.Enabled
  57. for action, limit := range globalOptions.Actions {
  58. limitations[action] = limit
  59. }
  60. }
  61. cb.Enabled = globalEnabled
  62. //buckets
  63. for bucket, cbOptions := range cfg.Buckets {
  64. if cbOptions.Enabled {
  65. for action, limit := range cbOptions.Actions {
  66. limitations[s3_constants.Concat(bucket, action)] = limit
  67. }
  68. }
  69. }
  70. cb.limitations = limitations
  71. return nil
  72. }
  73. func (cb *CircuitBreaker) Limit(f http.HandlerFunc, action string) (http.HandlerFunc, Action) {
  74. return func(w http.ResponseWriter, r *http.Request) {
  75. if !cb.Enabled {
  76. f(w, r)
  77. return
  78. }
  79. vars := mux.Vars(r)
  80. bucket := vars["bucket"]
  81. rollback, errCode := cb.limit(r, bucket, action)
  82. defer func() {
  83. for _, rf := range rollback {
  84. rf()
  85. }
  86. }()
  87. if errCode == s3err.ErrNone {
  88. f(w, r)
  89. return
  90. }
  91. s3err.WriteErrorResponse(w, r, errCode)
  92. }, Action(action)
  93. }
  94. func (cb *CircuitBreaker) limit(r *http.Request, bucket string, action string) (rollback []func(), errCode s3err.ErrorCode) {
  95. //bucket simultaneous request count
  96. bucketCountRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(bucket, action, s3_constants.LimitTypeCount), 1, s3err.ErrTooManyRequest)
  97. if bucketCountRollBack != nil {
  98. rollback = append(rollback, bucketCountRollBack)
  99. }
  100. if errCode != s3err.ErrNone {
  101. return
  102. }
  103. //bucket simultaneous request content bytes
  104. bucketContentLengthRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(bucket, action, s3_constants.LimitTypeBytes), r.ContentLength, s3err.ErrRequestBytesExceed)
  105. if bucketContentLengthRollBack != nil {
  106. rollback = append(rollback, bucketContentLengthRollBack)
  107. }
  108. if errCode != s3err.ErrNone {
  109. return
  110. }
  111. //global simultaneous request count
  112. globalCountRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(action, s3_constants.LimitTypeCount), 1, s3err.ErrTooManyRequest)
  113. if globalCountRollBack != nil {
  114. rollback = append(rollback, globalCountRollBack)
  115. }
  116. if errCode != s3err.ErrNone {
  117. return
  118. }
  119. //global simultaneous request content bytes
  120. globalContentLengthRollBack, errCode := cb.loadCounterAndCompare(s3_constants.Concat(action, s3_constants.LimitTypeBytes), r.ContentLength, s3err.ErrRequestBytesExceed)
  121. if globalContentLengthRollBack != nil {
  122. rollback = append(rollback, globalContentLengthRollBack)
  123. }
  124. if errCode != s3err.ErrNone {
  125. return
  126. }
  127. return
  128. }
  129. func (cb *CircuitBreaker) loadCounterAndCompare(key string, inc int64, errCode s3err.ErrorCode) (f func(), e s3err.ErrorCode) {
  130. e = s3err.ErrNone
  131. if max, ok := cb.limitations[key]; ok {
  132. cb.RLock()
  133. counter, exists := cb.counters[key]
  134. cb.RUnlock()
  135. if !exists {
  136. cb.Lock()
  137. counter, exists = cb.counters[key]
  138. if !exists {
  139. var newCounter int64
  140. counter = &newCounter
  141. cb.counters[key] = counter
  142. }
  143. cb.Unlock()
  144. }
  145. current := atomic.LoadInt64(counter)
  146. if current+inc > max {
  147. e = errCode
  148. return
  149. } else {
  150. current := atomic.AddInt64(counter, inc)
  151. f = func() {
  152. atomic.AddInt64(counter, -inc)
  153. }
  154. if current > max {
  155. e = errCode
  156. return
  157. }
  158. }
  159. }
  160. return
  161. }