You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							201 lines
						
					
					
						
							5.4 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							201 lines
						
					
					
						
							5.4 KiB
						
					
					
				| package client | |
| 
 | |
| import ( | |
| 	"crypto/tls" | |
| 	"crypto/x509" | |
| 	"fmt" | |
| 	util "github.com/seaweedfs/seaweedfs/weed/util" | |
| 	"github.com/spf13/viper" | |
| 	"io" | |
| 	"net/http" | |
| 	"net/url" | |
| 	"os" | |
| 	"strings" | |
| 	"sync" | |
| ) | |
| 
 | |
// loadSecurityConfigOnce is a package-level sync.Once.
// NOTE(review): it is not referenced anywhere in this file; presumably it is
// used by another file in package client, or it is left over — confirm before
// removing.
var loadSecurityConfigOnce sync.Once
| 
 | |
// HTTPClient wraps an *http.Client and forces the URLs it requests to use the
// scheme ("http" or "https") selected for this client by the security
// configuration.
//
// Field order is significant for unkeyed composite literals; keep it stable.
type HTTPClient struct {
	// Client performs the requests. NewHttpClient builds it on top of
	// Transport.
	Client *http.Client

	// Transport is the transport NewHttpClient installs in Client; it is
	// exposed so callers can inspect or tune it.
	Transport *http.Transport

	// expectHttpsScheme is true when HTTPS is enabled for this client, in
	// which case every request is rewritten to use "https".
	expectHttpsScheme bool
}
| 
 | |
| func (httpClient *HTTPClient) Do(req *http.Request) (*http.Response, error) { | |
| 	req.URL.Scheme = httpClient.GetHttpScheme() | |
| 	return httpClient.Client.Do(req) | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) Get(url string) (resp *http.Response, err error) { | |
| 	url, err = httpClient.NormalizeHttpScheme(url) | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 	return httpClient.Client.Get(url) | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) Post(url, contentType string, body io.Reader) (resp *http.Response, err error) { | |
| 	url, err = httpClient.NormalizeHttpScheme(url) | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 	return httpClient.Client.Post(url, contentType, body) | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) PostForm(url string, data url.Values) (resp *http.Response, err error) { | |
| 	url, err = httpClient.NormalizeHttpScheme(url) | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 	return httpClient.Client.PostForm(url, data) | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) Head(url string) (resp *http.Response, err error) { | |
| 	url, err = httpClient.NormalizeHttpScheme(url) | |
| 	if err != nil { | |
| 		return nil, err | |
| 	} | |
| 	return httpClient.Client.Head(url) | |
| } | |
| func (httpClient *HTTPClient) CloseIdleConnections() { | |
| 	httpClient.Client.CloseIdleConnections() | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) GetClientTransport() *http.Transport { | |
| 	return httpClient.Transport | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) GetHttpScheme() string { | |
| 	if httpClient.expectHttpsScheme { | |
| 		return "https" | |
| 	} | |
| 	return "http" | |
| } | |
| 
 | |
| func (httpClient *HTTPClient) NormalizeHttpScheme(rawURL string) (string, error) { | |
| 	expectedScheme := httpClient.GetHttpScheme() | |
| 
 | |
| 	if !(strings.HasPrefix(rawURL, "http://") || strings.HasPrefix(rawURL, "https://")) { | |
| 		return expectedScheme + "://" + rawURL, nil | |
| 	} | |
| 
 | |
| 	parsedURL, err := url.Parse(rawURL) | |
| 	if err != nil { | |
| 		return "", err | |
| 	} | |
| 
 | |
| 	if expectedScheme != parsedURL.Scheme { | |
| 		parsedURL.Scheme = expectedScheme | |
| 	} | |
| 	return parsedURL.String(), nil | |
| } | |
| 
 | |
| func NewHttpClient(clientName ClientName, opts ...HttpClientOpt) (*HTTPClient, error) { | |
| 	httpClient := HTTPClient{} | |
| 	httpClient.expectHttpsScheme = checkIsHttpsClientEnabled(clientName) | |
| 	var tlsConfig *tls.Config = nil | |
| 
 | |
| 	if httpClient.expectHttpsScheme { | |
| 		clientCertPair, err := getClientCertPair(clientName) | |
| 		if err != nil { | |
| 			return nil, err | |
| 		} | |
| 
 | |
| 		clientCaCert, clientCaCertName, err := getClientCaCert(clientName) | |
| 		if err != nil { | |
| 			return nil, err | |
| 		} | |
| 
 | |
| 		if clientCertPair != nil || len(clientCaCert) != 0 { | |
| 			caCertPool, err := createHTTPClientCertPool(clientCaCert, clientCaCertName) | |
| 			if err != nil { | |
| 				return nil, err | |
| 			} | |
| 
 | |
| 			tlsConfig = &tls.Config{ | |
| 				Certificates:       []tls.Certificate{}, | |
| 				RootCAs:            caCertPool, | |
| 				InsecureSkipVerify: false, | |
| 			} | |
| 
 | |
| 			if clientCertPair != nil { | |
| 				tlsConfig.Certificates = append(tlsConfig.Certificates, *clientCertPair) | |
| 			} | |
| 		} | |
| 	} | |
| 
 | |
| 	httpClient.Transport = &http.Transport{ | |
| 		MaxIdleConns:        1024, | |
| 		MaxIdleConnsPerHost: 1024, | |
| 		TLSClientConfig:     tlsConfig, | |
| 	} | |
| 	httpClient.Client = &http.Client{ | |
| 		Transport: httpClient.Transport, | |
| 	} | |
| 
 | |
| 	for _, opt := range opts { | |
| 		opt(&httpClient) | |
| 	} | |
| 	return &httpClient, nil | |
| } | |
| 
 | |
| func getStringOptionFromSecurityConfiguration(clientName ClientName, stringOptionName string) string { | |
| 	util.LoadSecurityConfiguration() | |
| 	return viper.GetString(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), stringOptionName)) | |
| } | |
| 
 | |
| func getBoolOptionFromSecurityConfiguration(clientName ClientName, boolOptionName string) bool { | |
| 	util.LoadSecurityConfiguration() | |
| 	return viper.GetBool(fmt.Sprintf("https.%s.%s", clientName.LowerCaseString(), boolOptionName)) | |
| } | |
| 
 | |
| func checkIsHttpsClientEnabled(clientName ClientName) bool { | |
| 	return getBoolOptionFromSecurityConfiguration(clientName, "enabled") | |
| } | |
| 
 | |
| func getFileContentFromSecurityConfiguration(clientName ClientName, fileType string) ([]byte, string, error) { | |
| 	if fileName := getStringOptionFromSecurityConfiguration(clientName, fileType); fileName != "" { | |
| 		fileContent, err := os.ReadFile(fileName) | |
| 		if err != nil { | |
| 			return nil, fileName, err | |
| 		} | |
| 		return fileContent, fileName, err | |
| 	} | |
| 	return nil, "", nil | |
| } | |
| 
 | |
| func getClientCertPair(clientName ClientName) (*tls.Certificate, error) { | |
| 	certFileName := getStringOptionFromSecurityConfiguration(clientName, "cert") | |
| 	keyFileName := getStringOptionFromSecurityConfiguration(clientName, "key") | |
| 	if certFileName == "" && keyFileName == "" { | |
| 		return nil, nil | |
| 	} | |
| 	if certFileName != "" && keyFileName != "" { | |
| 		clientCert, err := tls.LoadX509KeyPair(certFileName, keyFileName) | |
| 		if err != nil { | |
| 			return nil, fmt.Errorf("error loading client certificate and key: %s", err) | |
| 		} | |
| 		return &clientCert, nil | |
| 	} | |
| 	return nil, fmt.Errorf("error loading key pair: key `%s` and certificate `%s`", keyFileName, certFileName) | |
| } | |
| 
 | |
| func getClientCaCert(clientName ClientName) ([]byte, string, error) { | |
| 	return getFileContentFromSecurityConfiguration(clientName, "ca") | |
| } | |
| 
 | |
// createHTTPClientCertPool builds a certificate pool from PEM-encoded
// certContent.
//
// Empty content yields an empty, non-nil pool. If certContent is non-empty
// but x509.CertPool.AppendCertsFromPEM parses no certificate from it, an
// error naming fileName is returned together with a nil pool.
func createHTTPClientCertPool(certContent []byte, fileName string) (*x509.CertPool, error) {
	pool := x509.NewCertPool()
	if len(certContent) > 0 && !pool.AppendCertsFromPEM(certContent) {
		return nil, fmt.Errorf("error processing certificate in %s", fileName)
	}
	return pool, nil
}
 |