You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							201 lines
						
					
					
						
							5.4 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							201 lines
						
					
					
						
							5.4 KiB
						
					
					
				
								package client
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"crypto/tls"
							 | 
						|
									"crypto/x509"
							 | 
						|
									"fmt"
							 | 
						|
									util "github.com/seaweedfs/seaweedfs/weed/util"
							 | 
						|
									"github.com/spf13/viper"
							 | 
						|
									"io"
							 | 
						|
									"net/http"
							 | 
						|
									"net/url"
							 | 
						|
									"os"
							 | 
						|
									"strings"
							 | 
						|
									"sync"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// loadSecurityConfigOnce is a package-level sync.Once.
								// NOTE(review): nothing in the visible part of this file calls its Do
								// method — confirm whether it is still needed.
								var loadSecurityConfigOnce sync.Once
							 | 
						|
								
							 | 
						|
								// HTTPClient holds an *http.Client, an *http.Transport and a flag that
								// selects which URL scheme its methods put on outgoing requests.
								type HTTPClient struct {
									// Client is the client every request method delegates to.
									Client *http.Client
									// Transport is the value returned by GetClientTransport.
									Transport *http.Transport
									// expectHttpsScheme makes GetHttpScheme report "https" when true
									// and "http" when false.
									expectHttpsScheme bool
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) {
							 | 
						|
									req.URL.Scheme = httpClient.GetHttpScheme()
							 | 
						|
									return httpClient.Client.Do(req)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) {
							 | 
						|
									url, err = httpClient.NormalizeHttpScheme(url)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
									return httpClient.Client.Get(url)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) {
							 | 
						|
									url, err = httpClient.NormalizeHttpScheme(url)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
									return httpClient.Client.Post(url, contentType, body)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) {
							 | 
						|
									url, err = httpClient.NormalizeHttpScheme(url)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
									return httpClient.Client.PostForm(url, data)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) {
							 | 
						|
									url, err = httpClient.NormalizeHttpScheme(url)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, err
							 | 
						|
									}
							 | 
						|
									return httpClient.Client.Head(url)
							 | 
						|
								}
							 | 
						|
								func (httpClient *HTTPClient) CloseIdleConnections() {
							 | 
						|
									httpClient.Client.CloseIdleConnections()
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) GetClientTransport() *http.Transport {
							 | 
						|
									return httpClient.Transport
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) GetHttpScheme() string {
							 | 
						|
									if httpClient.expectHttpsScheme {
							 | 
						|
										return "https"
							 | 
						|
									}
							 | 
						|
									return "http"
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) {
							 | 
						|
									expectedScheme := httpClient.GetHttpScheme()
							 | 
						|
								
							 | 
						|
									if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) {
							 | 
						|
										return expectedScheme + "://" + rawURL, nil
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									parsedURL, err := url.Parse(rawURL)
							 | 
						|
									if err != nil {
							 | 
						|
										return "", err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if expectedScheme != parsedURL.Scheme {
							 | 
						|
										parsedURL.Scheme = expectedScheme
							 | 
						|
									}
							 | 
						|
									return parsedURL.String(), nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) {
							 | 
						|
									httpClient := HTTPClient{}
							 | 
						|
									httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName)
							 | 
						|
									var tlsConfig *tls.Config = nil
							 | 
						|
								
							 | 
						|
									if httpClient.expectHttpsScheme {
							 | 
						|
										clientCertPair, err := getClientCertPair(clientName)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, err
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										clientCaCert, clientCaCertName, err := getClientCaCert(clientName)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, err
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										if clientCertPair != nil || len(clientCaCert) != 0 {
							 | 
						|
											caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName)
							 | 
						|
											if err != nil {
							 | 
						|
												return nil, err
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											tlsConfig = &tls.Config{
							 | 
						|
												Certificates:       []tls.Certificate{},
							 | 
						|
												RootCAs:            caCertPool,
							 | 
						|
												InsecureSkipVerify: false,
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											if clientCertPair != nil {
							 | 
						|
												tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair)
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									httpClient.Transport = &http.Transport{
							 | 
						|
										MaxIdleConns:        1024,
							 | 
						|
										MaxIdleConnsPerHost: 1024,
							 | 
						|
										TLSClientConfig:     tlsConfig,
							 | 
						|
									}
							 | 
						|
									httpClient.Client = &http.Client{
							 | 
						|
										Transport: httpClient.Transport,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, opt := range opts {
							 | 
						|
										opt(&httpClient)
							 | 
						|
									}
							 | 
						|
									return &httpClient, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string {
							 | 
						|
									util.LoadSecurityConfiguration()
							 | 
						|
									return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName))
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool {
							 | 
						|
									util.LoadSecurityConfiguration()
							 | 
						|
									return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName))
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func checkIsHttpsClientEnabled(clientName ClientName) bool {
							 | 
						|
									return getBoolOptionFromSecurityConfiguration(clientName, "enabled")
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) {
							 | 
						|
									if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" {
							 | 
						|
										fileContent, err := os.ReadFile(fileName)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fileName, err
							 | 
						|
										}
							 | 
						|
										return fileContent, fileName, err
							 | 
						|
									}
							 | 
						|
									return nil, "", nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func getClientCertPair(clientName ClientName) (*tls.Certificate, error) {
							 | 
						|
									certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert")
							 | 
						|
									keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key")
							 | 
						|
									if certFileName == "" && keyFileName == "" {
							 | 
						|
										return nil, nil
							 | 
						|
									}
							 | 
						|
									if certFileName != "" && keyFileName != "" {
							 | 
						|
										clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName)
							 | 
						|
										if err != nil {
							 | 
						|
											return nil, fmt.Errorf("error loading client certificate and key: %s", err)
							 | 
						|
										}
							 | 
						|
										return &clientCert, nil
							 | 
						|
									}
							 | 
						|
									return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								func getClientCaCert(clientName ClientName) ([]byte, string, error) {
							 | 
						|
									return getFileContentFromSecurityConfiguration(clientName, "ca")
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// createHTTPClientCertPool returns a new x509.CertPool holding the
								// PEM-encoded certificates found in certContent.
								//
								// Empty certContent yields an empty, non-nil pool. If certContent is not
								// empty but AppendCertsFromPEM reports failure, a nil pool and an error
								// mentioning fileName are returned.
								func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
									pool := x509.NewCertPool()
									if len(certContent) > 0 && !pool.AppendCertsFromPEM(certContent) {
										return nil, fmt.Errorf("error processing certificate in %s", fileName)
									}
									return pool, nil
								}
							 |