You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

174 lines
4.7 KiB

  1. package s3api
import (
	"fmt"
	"net/http"
	"sync"

	"github.com/chrislusf/seaweedfs/weed/config"
	"github.com/chrislusf/seaweedfs/weed/filer"
	"github.com/chrislusf/seaweedfs/weed/glog"
	"github.com/chrislusf/seaweedfs/weed/pb"
	"github.com/chrislusf/seaweedfs/weed/pb/filer_pb"
	"github.com/chrislusf/seaweedfs/weed/pb/s3_pb"
	"github.com/chrislusf/seaweedfs/weed/s3api/s3err"
	"github.com/gorilla/mux"
	"go.uber.org/atomic"
)
  15. type CircuitBreaker struct {
  16. Enabled bool
  17. counters map[string]*atomic.Int64
  18. limitations map[string]int64
  19. }
  20. func NewCircuitBreaker(option *S3ApiServerOption) *CircuitBreaker {
  21. cb := &CircuitBreaker{
  22. counters: make(map[string]*atomic.Int64),
  23. limitations: make(map[string]int64),
  24. }
  25. _ = pb.WithFilerClient(false, option.Filer, option.GrpcDialOption, func(client filer_pb.SeaweedFilerClient) error {
  26. content, err := filer.ReadInsideFiler(client, config.CircuitBreakerConfigDir, config.CircuitBreakerConfigFile)
  27. if err == nil {
  28. err = cb.LoadS3ApiConfigurationFromBytes(content)
  29. }
  30. if err != nil {
  31. glog.Warningf("load s3 circuit breaker config from filer: %v", err)
  32. } else {
  33. glog.V(2).Infof("load s3 circuit breaker config complete: %v", cb)
  34. }
  35. return err
  36. })
  37. return cb
  38. }
  39. func (cb *CircuitBreaker) LoadS3ApiConfigurationFromBytes(content []byte) error {
  40. cbCfg := &s3_pb.S3CircuitBreakerConfig{}
  41. if err := filer.ParseS3ConfigurationFromBytes(content, cbCfg); err != nil {
  42. glog.Warningf("unmarshal error: %v", err)
  43. return fmt.Errorf("unmarshal error: %v", err)
  44. }
  45. if err := cb.loadCbCfg(cbCfg); err != nil {
  46. return err
  47. }
  48. return nil
  49. }
  50. func (cb *CircuitBreaker) loadCbCfg(cfg *s3_pb.S3CircuitBreakerConfig) error {
  51. //global
  52. globalEnabled := false
  53. globalOptions := cfg.Global
  54. limitations := make(map[string]int64)
  55. if globalOptions != nil && globalOptions.Enabled && len(globalOptions.Actions) > 0 {
  56. globalEnabled = globalOptions.Enabled
  57. for action, limit := range globalOptions.Actions {
  58. limitations[action] = limit
  59. }
  60. }
  61. cb.Enabled = globalEnabled
  62. //buckets
  63. for bucket, cbOptions := range cfg.Buckets {
  64. if cbOptions.Enabled {
  65. for action, limit := range cbOptions.Actions {
  66. limitations[config.Concat(bucket, action)] = limit
  67. }
  68. }
  69. }
  70. cb.limitations = limitations
  71. return nil
  72. }
  73. func (cb *CircuitBreaker) Check(f func(w http.ResponseWriter, r *http.Request), action string) (http.HandlerFunc, Action) {
  74. return func(w http.ResponseWriter, r *http.Request) {
  75. if !cb.Enabled {
  76. f(w, r)
  77. return
  78. }
  79. vars := mux.Vars(r)
  80. bucket := vars["bucket"]
  81. rollback, errCode := cb.check(r, bucket, action)
  82. defer func() {
  83. for _, rf := range rollback {
  84. rf()
  85. }
  86. }()
  87. if errCode == s3err.ErrNone {
  88. f(w, r)
  89. return
  90. }
  91. s3err.WriteErrorResponse(w, r, errCode)
  92. }, Action(action)
  93. }
  94. func (cb *CircuitBreaker) check(r *http.Request, bucket string, action string) (rollback []func(), errCode s3err.ErrorCode) {
  95. //bucket simultaneous request count
  96. bucketCountRollBack, errCode := cb.loadAndCompare(bucket, action, config.LimitTypeCount, 1, s3err.ErrTooManyRequest)
  97. if bucketCountRollBack != nil {
  98. rollback = append(rollback, bucketCountRollBack)
  99. }
  100. if errCode != s3err.ErrNone {
  101. return
  102. }
  103. //bucket simultaneous request content bytes
  104. bucketContentLengthRollBack, errCode := cb.loadAndCompare(bucket, action, config.LimitTypeBytes, r.ContentLength, s3err.ErrRequestBytesExceed)
  105. if bucketContentLengthRollBack != nil {
  106. rollback = append(rollback, bucketContentLengthRollBack)
  107. }
  108. if errCode != s3err.ErrNone {
  109. return
  110. }
  111. //global simultaneous request count
  112. globalCountRollBack, errCode := cb.loadAndCompare("", action, config.LimitTypeCount, 1, s3err.ErrTooManyRequest)
  113. if globalCountRollBack != nil {
  114. rollback = append(rollback, globalCountRollBack)
  115. }
  116. if errCode != s3err.ErrNone {
  117. return
  118. }
  119. //global simultaneous request content bytes
  120. globalContentLengthRollBack, errCode := cb.loadAndCompare("", action, config.LimitTypeBytes, r.ContentLength, s3err.ErrRequestBytesExceed)
  121. if globalContentLengthRollBack != nil {
  122. rollback = append(rollback, globalContentLengthRollBack)
  123. }
  124. if errCode != s3err.ErrNone {
  125. return
  126. }
  127. return
  128. }
  129. func (cb CircuitBreaker) loadAndCompare(bucket, action, limitType string, inc int64, errCode s3err.ErrorCode) (f func(), e s3err.ErrorCode) {
  130. key := config.Concat(bucket, action, limitType)
  131. e = s3err.ErrNone
  132. if max, ok := cb.limitations[key]; ok {
  133. counter, exists := cb.counters[key]
  134. if !exists {
  135. counter = atomic.NewInt64(0)
  136. cb.counters[key] = counter
  137. }
  138. current := counter.Load()
  139. if current+inc > max {
  140. e = errCode
  141. return
  142. } else {
  143. counter.Add(inc)
  144. f = func() {
  145. counter.Sub(inc)
  146. }
  147. current = counter.Load()
  148. if current+inc > max {
  149. e = errCode
  150. return
  151. }
  152. }
  153. }
  154. return
  155. }