diff --git a/acmed/src/acme_proto/structs/directory.rs b/acmed/src/acme_proto/structs/directory.rs index 3b82d12..42e5198 100644 --- a/acmed/src/acme_proto/structs/directory.rs +++ b/acmed/src/acme_proto/structs/directory.rs @@ -2,7 +2,7 @@ use acme_common::error::Error; use serde::Deserialize; use std::str::FromStr; -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct DirectoryMeta { pub terms_of_service: Option, @@ -11,7 +11,7 @@ pub struct DirectoryMeta { pub external_account_required: Option, } -#[derive(Deserialize)] +#[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Directory { pub meta: Option, diff --git a/acmed/src/config.rs b/acmed/src/config.rs index 53715e1..fa6bd4e 100644 --- a/acmed/src/config.rs +++ b/acmed/src/config.rs @@ -184,10 +184,13 @@ pub struct Endpoint { } impl Endpoint { - fn to_generic(&self, _cnf: &Config) -> Result { - // TODO: include rate limits using `cnf.get_rate_limit()` - let ep = crate::endpoint::Endpoint::new(&self.name, &self.url, self.tos_agreed); - Ok(ep) + fn to_generic(&self, cnf: &Config) -> Result { + let mut limits = vec![]; + for rl_name in self.rate_limits.iter() { + let (nb, timeframe) = cnf.get_rate_limit(&rl_name)?; + limits.push((nb, timeframe)); + } + crate::endpoint::Endpoint::new(&self.name, &self.url, self.tos_agreed, &limits) } } diff --git a/acmed/src/endpoint.rs b/acmed/src/endpoint.rs index 3d1fd2e..c5d97fd 100644 --- a/acmed/src/endpoint.rs +++ b/acmed/src/endpoint.rs @@ -1,20 +1,33 @@ use crate::acme_proto::structs::Directory; +use acme_common::error::Error; +use nom::bytes::complete::take_while_m_n; +use nom::character::complete::digit1; +use nom::combinator::map_res; +use nom::multi::fold_many1; +use nom::IResult; +use std::cmp; +use std::thread; +use std::time::{Duration, Instant}; +#[derive(Debug)] pub struct Endpoint { pub name: String, pub url: String, pub tos_agreed: bool, - pub dir: Directory, pub nonce: Option, - // TODO: rate limits + pub rl: RateLimit, + pub dir: Directory, } impl Endpoint { - pub fn new(name: &str, url: &str, tos_agreed: bool) -> Self { - Self { + pub fn new(name: &str, url: &str, tos_agreed: bool, limits: &[(usize, String)]) -> Result { + let rl = RateLimit::new(limits)?; + Ok(Self { name: name.to_string(), url: url.to_string(), tos_agreed, + nonce: None, + rl, dir: Directory { meta: None, new_nonce: String::new(), @@ -24,7 +37,124 @@ impl Endpoint { revoke_cert: String::new(), key_change: String::new(), }, - nonce: None, + }) + } +} + +#[derive(Clone, Debug)] +pub struct RateLimit { + limits: Vec<(usize, Duration)>, + query_log: Vec, +} + +impl RateLimit { + pub fn new(raw_limits: &[(usize, String)]) -> Result { + let mut limits = vec![]; + for (nb, raw_duration) in raw_limits.iter() { + let parsed_duration = parse_duration(raw_duration)?; + limits.push((*nb, parsed_duration)); + } + limits.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + limits.reverse(); + Ok(Self { + limits, + query_log: vec![], + }) + } + + pub fn block_until_allowed(&mut self) { + if self.limits.is_empty() { + return; + } + let sleep_duration = self.get_sleep_duration(); + loop { + self.prune_log(); + if self.request_allowed() { + self.query_log.push(Instant::now()); + return; + } + // TODO: find a better sleep duration + thread::sleep(sleep_duration); + } + } + + fn get_sleep_duration(&self) -> Duration { + let (nb_req, min_duration) = match self.limits.last() { + Some((n, d)) => (*n as u64, *d), + None => { + return Duration::from_millis(0); + } + }; + let nb_mili = match min_duration.as_secs() { + 0 | 1 => crate::MIN_RATE_LIMIT_SLEEP_MILISEC, + n => { + let a = n * 200 / nb_req; + let a = cmp::min(a, crate::MAX_RATE_LIMIT_SLEEP_MILISEC); + cmp::max(a, crate::MIN_RATE_LIMIT_SLEEP_MILISEC) + } + }; + Duration::from_millis(nb_mili) + } + + fn request_allowed(&self) -> bool { + for (max_allowed, duration) in self.limits.iter() { + let max_date = Instant::now() - *duration; + let nb_req = self.query_log.iter().filter(move |x| **x > max_date).count(); + if nb_req >= *max_allowed { + return false; + } + } + true + } + + fn prune_log(&mut self) { + if let Some((_, max_limit)) = self.limits.first() { + let prune_date = Instant::now() - *max_limit; + self.query_log.retain(move |&d| d > prune_date); } } } + +fn is_duration_chr(c: char) -> bool { + c == 's' || c == 'm' || c == 'h' || c == 'd' || c == 'w' +} + +fn get_multiplicator(input: &str) -> IResult<&str, u64> { + let (input, nb) = take_while_m_n(1, 1, is_duration_chr)(input)?; + let mult = match nb.chars().nth(0) { + Some('s') => 1, + Some('m') => 60, + Some('h') => 3_600, + Some('d') => 86_400, + Some('w') => 604_800, + _ => 0, + }; + Ok((input, mult)) +} + +fn get_duration_part(input: &str) -> IResult<&str, Duration> { + let (input, nb) = map_res(digit1, |s: &str| s.parse::())(input)?; + let (input, mult) = get_multiplicator(input)?; + Ok((input, Duration::from_secs(nb * mult))) +} + +fn get_duration(input: &str) -> IResult<&str, Duration> { + fold_many1( + get_duration_part, + Duration::new(0, 0), + |mut acc: Duration, item| { + acc += item; + acc + }, + )(input) +} + +fn parse_duration(input: &str) -> Result { + match get_duration(input) { + Ok((r, d)) => match r.len() { + 0 => Ok(d), + _ => Err(format!("{}: invalid duration", input).into()), + }, + Err(_) => Err(format!("{}: invalid duration", input).into()), + } +} diff --git a/acmed/src/http.rs b/acmed/src/http.rs index 1986320..7bdcbff 100644 --- a/acmed/src/http.rs +++ b/acmed/src/http.rs @@ -51,8 +51,8 @@ fn check_status(response: &Response) -> Result<(), Error> { Ok(()) } -fn rate_limit(endpoint: &Endpoint) { - // TODO: Implement +fn rate_limit(endpoint: &mut Endpoint) { + endpoint.rl.block_until_allowed(); } pub fn header_to_string(header_value: &HeaderValue) -> Result {