You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

201 lines
5.4 KiB

  1. package client
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "fmt"
  6. util "github.com/seaweedfs/seaweedfs/weed/util"
  7. "github.com/spf13/viper"
  8. "io"
  9. "net/http"
  10. "net/url"
  11. "os"
  12. "strings"
  13. "sync"
  14. )
  15. var (
  16. loadSecurityConfigOnce sync.Once
  17. )
  18. type HTTPClient struct {
  19. Client *http.Client
  20. Transport *http.Transport
  21. expectHttpsScheme bool
  22. }
  23. func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) {
  24. req.URL.Scheme = httpClient.GetHttpScheme()
  25. return httpClient.Client.Do(req)
  26. }
  27. func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) {
  28. url, err = httpClient.NormalizeHttpScheme(url)
  29. if err != nil {
  30. return nil, err
  31. }
  32. return httpClient.Client.Get(url)
  33. }
  34. func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) {
  35. url, err = httpClient.NormalizeHttpScheme(url)
  36. if err != nil {
  37. return nil, err
  38. }
  39. return httpClient.Client.Post(url, contentType, body)
  40. }
  41. func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) {
  42. url, err = httpClient.NormalizeHttpScheme(url)
  43. if err != nil {
  44. return nil, err
  45. }
  46. return httpClient.Client.PostForm(url, data)
  47. }
  48. func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) {
  49. url, err = httpClient.NormalizeHttpScheme(url)
  50. if err != nil {
  51. return nil, err
  52. }
  53. return httpClient.Client.Head(url)
  54. }
  55. func (httpClient *HTTPClient) CloseIdleConnections() {
  56. httpClient.Client.CloseIdleConnections()
  57. }
  58. func (httpClient *HTTPClient) GetClientTransport() *http.Transport {
  59. return httpClient.Transport
  60. }
  61. func (httpClient *HTTPClient) GetHttpScheme() string {
  62. if httpClient.expectHttpsScheme {
  63. return "https"
  64. }
  65. return "http"
  66. }
  67. func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) {
  68. expectedScheme := httpClient.GetHttpScheme()
  69. if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) {
  70. return expectedScheme + "://" + rawURL, nil
  71. }
  72. parsedURL, err := url.Parse(rawURL)
  73. if err != nil {
  74. return "", err
  75. }
  76. if expectedScheme != parsedURL.Scheme {
  77. parsedURL.Scheme = expectedScheme
  78. }
  79. return parsedURL.String(), nil
  80. }
  81. func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) {
  82. httpClient := HTTPClient{}
  83. httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName)
  84. var tlsConfig *tls.Config = nil
  85. if httpClient.expectHttpsScheme {
  86. clientCertPair, err := getClientCertPair(clientName)
  87. if err != nil {
  88. return nil, err
  89. }
  90. clientCaCert, clientCaCertName, err := getClientCaCert(clientName)
  91. if err != nil {
  92. return nil, err
  93. }
  94. if clientCertPair != nil || len(clientCaCert) != 0 {
  95. caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName)
  96. if err != nil {
  97. return nil, err
  98. }
  99. tlsConfig = &tls.Config{
  100. Certificates: []tls.Certificate{},
  101. RootCAs: caCertPool,
  102. InsecureSkipVerify: false,
  103. }
  104. if clientCertPair != nil {
  105. tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair)
  106. }
  107. }
  108. }
  109. httpClient.Transport = &http.Transport{
  110. MaxIdleConns: 1024,
  111. MaxIdleConnsPerHost: 1024,
  112. TLSClientConfig: tlsConfig,
  113. }
  114. httpClient.Client = &http.Client{
  115. Transport: httpClient.Transport,
  116. }
  117. for _, opt := range opts {
  118. opt(&httpClient)
  119. }
  120. return &httpClient, nil
  121. }
  122. func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string {
  123. util.LoadSecurityConfiguration()
  124. return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName))
  125. }
  126. func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool {
  127. util.LoadSecurityConfiguration()
  128. return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName))
  129. }
  130. func checkIsHttpsClientEnabled(clientName ClientName) bool {
  131. return getBoolOptionFromSecurityConfiguration(clientName, "enabled")
  132. }
  133. func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) {
  134. if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" {
  135. fileContent, err := os.ReadFile(fileName)
  136. if err != nil {
  137. return nil, fileName, err
  138. }
  139. return fileContent, fileName, err
  140. }
  141. return nil, "", nil
  142. }
  143. func getClientCertPair(clientName ClientName) (*tls.Certificate, error) {
  144. certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert")
  145. keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key")
  146. if certFileName == "" && keyFileName == "" {
  147. return nil, nil
  148. }
  149. if certFileName != "" && keyFileName != "" {
  150. clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName)
  151. if err != nil {
  152. return nil, fmt.Errorf("error loading client certificate and key: %s", err)
  153. }
  154. return &clientCert, nil
  155. }
  156. return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName)
  157. }
  158. func getClientCaCert(clientName ClientName) ([]byte, string, error) {
  159. return getFileContentFromSecurityConfiguration(clientName, "ca")
  160. }
  161. func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
  162. certPool := x509.NewCertPool()
  163. if len(certContent) == 0 {
  164. return certPool, nil
  165. }
  166. ok := certPool.AppendCertsFromPEM(certContent)
  167. if !ok {
  168. return nil, fmt.Errorf("error processing certificate in %s", fileName)
  169. }
  170. return certPool, nil
  171. }