|
|
|
@ -1,13 +1,19 @@ |
|
|
|
use std::collections::HashSet;
|
|
|
|
use std::fmt;
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
use rustls::client::danger::HandshakeSignatureValid;
|
|
|
|
use rustls::crypto::aws_lc_rs;
|
|
|
|
use rustls::crypto::CryptoProvider;
|
|
|
|
use rustls::pki_types::UnixTime;
|
|
|
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
|
|
|
use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
|
|
|
|
use rustls::server::WebPkiClientVerifier;
|
|
|
|
use rustls::{
|
|
|
|
CipherSuite, RootCertStore, ServerConfig, SupportedCipherSuite, SupportedProtocolVersion,
|
|
|
|
CipherSuite, DigitallySignedStruct, DistinguishedName, RootCertStore, ServerConfig,
|
|
|
|
SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
|
|
|
|
};
|
|
|
|
use x509_parser::prelude::{FromDer, X509Certificate};
|
|
|
|
|
|
|
|
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
|
|
|
pub struct TlsPolicy {
|
|
|
|
@ -16,6 +22,12 @@ pub struct TlsPolicy { |
|
|
|
pub cipher_suites: String,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
|
|
|
pub struct GrpcClientAuthPolicy {
|
|
|
|
pub allowed_common_names: Vec<String>,
|
|
|
|
pub allowed_wildcard_domain: String,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
|
|
pub struct TlsPolicyError(String);
|
|
|
|
|
|
|
|
@ -36,11 +48,107 @@ enum GoTlsVersion { |
|
|
|
Tls13,
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
|
struct CommonNameVerifier {
|
|
|
|
inner: Arc<dyn ClientCertVerifier>,
|
|
|
|
allowed_common_names: HashSet<String>,
|
|
|
|
allowed_wildcard_domain: String,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl ClientCertVerifier for CommonNameVerifier {
|
|
|
|
fn offer_client_auth(&self) -> bool {
|
|
|
|
self.inner.offer_client_auth()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn client_auth_mandatory(&self) -> bool {
|
|
|
|
self.inner.client_auth_mandatory()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn root_hint_subjects(&self) -> &[DistinguishedName] {
|
|
|
|
self.inner.root_hint_subjects()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn verify_client_cert(
|
|
|
|
&self,
|
|
|
|
end_entity: &CertificateDer<'_>,
|
|
|
|
intermediates: &[CertificateDer<'_>],
|
|
|
|
now: UnixTime,
|
|
|
|
) -> Result<ClientCertVerified, rustls::Error> {
|
|
|
|
self.inner
|
|
|
|
.verify_client_cert(end_entity, intermediates, now)?;
|
|
|
|
let common_name = parse_common_name(end_entity).map_err(|e| {
|
|
|
|
rustls::Error::General(format!(
|
|
|
|
"parse client certificate common name failed: {}",
|
|
|
|
e
|
|
|
|
))
|
|
|
|
})?;
|
|
|
|
if common_name_is_allowed(
|
|
|
|
&common_name,
|
|
|
|
&self.allowed_common_names,
|
|
|
|
&self.allowed_wildcard_domain,
|
|
|
|
) {
|
|
|
|
return Ok(ClientCertVerified::assertion());
|
|
|
|
}
|
|
|
|
Err(rustls::Error::General(format!(
|
|
|
|
"Authenticate: invalid subject client common name: {}",
|
|
|
|
common_name
|
|
|
|
)))
|
|
|
|
}
|
|
|
|
|
|
|
|
fn verify_tls12_signature(
|
|
|
|
&self,
|
|
|
|
message: &[u8],
|
|
|
|
cert: &CertificateDer<'_>,
|
|
|
|
dss: &DigitallySignedStruct,
|
|
|
|
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
|
|
|
self.inner.verify_tls12_signature(message, cert, dss)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn verify_tls13_signature(
|
|
|
|
&self,
|
|
|
|
message: &[u8],
|
|
|
|
cert: &CertificateDer<'_>,
|
|
|
|
dss: &DigitallySignedStruct,
|
|
|
|
) -> Result<HandshakeSignatureValid, rustls::Error> {
|
|
|
|
self.inner.verify_tls13_signature(message, cert, dss)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
|
|
|
|
self.inner.supported_verify_schemes()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn build_rustls_server_config(
|
|
|
|
cert_path: &str,
|
|
|
|
key_path: &str,
|
|
|
|
ca_path: &str,
|
|
|
|
policy: &TlsPolicy,
|
|
|
|
) -> Result<ServerConfig, TlsPolicyError> {
|
|
|
|
build_rustls_server_config_with_client_auth(cert_path, key_path, ca_path, policy, None)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub fn build_rustls_server_config_with_grpc_client_auth(
|
|
|
|
cert_path: &str,
|
|
|
|
key_path: &str,
|
|
|
|
ca_path: &str,
|
|
|
|
policy: &TlsPolicy,
|
|
|
|
client_auth_policy: &GrpcClientAuthPolicy,
|
|
|
|
) -> Result<ServerConfig, TlsPolicyError> {
|
|
|
|
build_rustls_server_config_with_client_auth(
|
|
|
|
cert_path,
|
|
|
|
key_path,
|
|
|
|
ca_path,
|
|
|
|
policy,
|
|
|
|
Some(client_auth_policy),
|
|
|
|
)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn build_rustls_server_config_with_client_auth(
|
|
|
|
cert_path: &str,
|
|
|
|
key_path: &str,
|
|
|
|
ca_path: &str,
|
|
|
|
policy: &TlsPolicy,
|
|
|
|
client_auth_policy: Option<&GrpcClientAuthPolicy>,
|
|
|
|
) -> Result<ServerConfig, TlsPolicyError> {
|
|
|
|
let cert_chain = read_cert_chain(cert_path)?;
|
|
|
|
let private_key = read_private_key(key_path)?;
|
|
|
|
@ -55,9 +163,27 @@ pub fn build_rustls_server_config( |
|
|
|
builder.with_no_client_auth()
|
|
|
|
} else {
|
|
|
|
let roots = read_root_store(ca_path)?;
|
|
|
|
let verifier = WebPkiClientVerifier::builder_with_provider(Arc::new(roots), provider)
|
|
|
|
.build()
|
|
|
|
.map_err(|e| TlsPolicyError(format!("build client verifier failed: {}", e)))?;
|
|
|
|
let verifier =
|
|
|
|
WebPkiClientVerifier::builder_with_provider(Arc::new(roots), provider.clone())
|
|
|
|
.build()
|
|
|
|
.map_err(|e| TlsPolicyError(format!("build client verifier failed: {}", e)))?;
|
|
|
|
let verifier: Arc<dyn ClientCertVerifier> = if let Some(client_auth_policy) =
|
|
|
|
client_auth_policy.filter(|policy| {
|
|
|
|
!policy.allowed_common_names.is_empty()
|
|
|
|
|| !policy.allowed_wildcard_domain.is_empty()
|
|
|
|
}) {
|
|
|
|
Arc::new(CommonNameVerifier {
|
|
|
|
inner: verifier,
|
|
|
|
allowed_common_names: client_auth_policy
|
|
|
|
.allowed_common_names
|
|
|
|
.iter()
|
|
|
|
.cloned()
|
|
|
|
.collect(),
|
|
|
|
allowed_wildcard_domain: client_auth_policy.allowed_wildcard_domain.clone(),
|
|
|
|
})
|
|
|
|
} else {
|
|
|
|
verifier
|
|
|
|
};
|
|
|
|
builder.with_client_cert_verifier(verifier)
|
|
|
|
};
|
|
|
|
|
|
|
|
@ -209,6 +335,30 @@ fn parse_cipher_suite_name(value: &str) -> Result<CipherSuite, TlsPolicyError> { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn parse_common_name(cert: &CertificateDer<'_>) -> Result<String, TlsPolicyError> {
|
|
|
|
let (_, certificate) = X509Certificate::from_der(cert.as_ref())
|
|
|
|
.map_err(|e| TlsPolicyError(format!("parse X.509 certificate failed: {}", e)))?;
|
|
|
|
let common_name = certificate
|
|
|
|
.subject()
|
|
|
|
.iter_common_name()
|
|
|
|
.next()
|
|
|
|
.and_then(|common_name| common_name.as_str().ok())
|
|
|
|
.map(str::to_string);
|
|
|
|
match common_name {
|
|
|
|
Some(common_name) => Ok(common_name),
|
|
|
|
None => Ok(String::new()),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fn common_name_is_allowed(
|
|
|
|
common_name: &str,
|
|
|
|
allowed_common_names: &HashSet<String>,
|
|
|
|
allowed_wildcard_domain: &str,
|
|
|
|
) -> bool {
|
|
|
|
(!allowed_wildcard_domain.is_empty() && common_name.ends_with(allowed_wildcard_domain))
|
|
|
|
|| allowed_common_names.contains(common_name)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn go_tls_version_for_supported(version: &SupportedProtocolVersion) -> GoTlsVersion {
|
|
|
|
match version.version {
|
|
|
|
rustls::ProtocolVersion::TLSv1_2 => GoTlsVersion::Tls12,
|
|
|
|
@ -219,8 +369,9 @@ fn go_tls_version_for_supported(version: &SupportedProtocolVersion) -> GoTlsVers |
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
mod tests {
|
|
|
|
use super::{build_supported_versions, parse_cipher_suites, TlsPolicy};
|
|
|
|
use super::{build_supported_versions, common_name_is_allowed, parse_cipher_suites, TlsPolicy};
|
|
|
|
use rustls::crypto::aws_lc_rs;
|
|
|
|
use std::collections::HashSet;
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_build_supported_versions_defaults_to_tls12_and_tls13() {
|
|
|
|
@ -262,4 +413,25 @@ mod tests { |
|
|
|
.unwrap();
|
|
|
|
assert_eq!(cipher_suites.len(), 2);
|
|
|
|
}
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
fn test_common_name_is_allowed_matches_exact_and_wildcard() {
|
|
|
|
let allowed_common_names =
|
|
|
|
HashSet::from([String::from("volume-a.internal"), String::from("worker-7")]);
|
|
|
|
assert!(common_name_is_allowed(
|
|
|
|
"volume-a.internal",
|
|
|
|
&allowed_common_names,
|
|
|
|
"",
|
|
|
|
));
|
|
|
|
assert!(common_name_is_allowed(
|
|
|
|
"node.prod.example.com",
|
|
|
|
&allowed_common_names,
|
|
|
|
".example.com",
|
|
|
|
));
|
|
|
|
assert!(!common_name_is_allowed(
|
|
|
|
"node.prod.other.net",
|
|
|
|
&allowed_common_names,
|
|
|
|
".example.com",
|
|
|
|
));
|
|
|
|
}
|
|
|
|
}
|