// iroh_quinn_proto/congestion/cubic.rs

use std::any::Any;
use std::cmp;
use std::sync::Arc;

use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
use crate::connection::RttEstimator;
use crate::{Duration, Instant};

/// CUBIC multiplicative-decrease factor `beta_cubic`, the value recommended in RFC 8312.
///
/// After a congestion event the slow start threshold (and with it the window) is derived from
/// `w_max` scaled by this factor.
const BETA_CUBIC: f64 = 0.7;

/// CUBIC scaling constant `C` of the cubic window-growth function, the value recommended in
/// RFC 8312.
const C: f64 = 0.4;

16/// CUBIC State Variables.
17///
18/// We need to keep those variables across the connection.
19/// k, w_max are described in the RFC.
20#[derive(Debug, Default, Clone)]
21pub(super) struct State {
22    /// Time period that the cubic function takes to increase the window size to W_max.
23    k: f64,
24
25    /// Congestion window size when the last congestion event occurred.
26    w_max: f64,
27
28    /// Congestion window increment stored during congestion avoidance.
29    cwnd_inc: u64,
30
31    /// Maximum number of bytes in flight that may be sent.
32    window: u64,
33
34    /// Slow start threshold in bytes.
35    ///
36    /// When the congestion window is below ssthresh, the mode is slow start
37    /// and the window grows by the number of bytes acknowledged.
38    ssthresh: u64,
39
40    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
41    /// after this time is acknowledged, QUIC exits recovery.
42    recovery_start_time: Option<Instant>,
43}
44
45/// CUBIC Functions.
46///
47/// Note that these calculations are based on a count of cwnd as bytes,
48/// not packets.
49/// Unit of t (duration) and RTT are based on seconds (f64).
50impl State {
51    // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2)
52    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
53        let w_max = self.w_max / max_datagram_size as f64;
54        (w_max * (1.0 - BETA_CUBIC) / C).cbrt()
55    }
56
57    // W_cubic(t) = C * (t - K)^3 + w_max (Eq. 1)
58    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
59        let w_max = self.w_max / max_datagram_size as f64;
60
61        (C * (t.as_secs_f64() - self.k).powi(3) + w_max) * max_datagram_size as f64
62    }
63
64    // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) *
65    // (t / RTT) (Eq. 4)
66    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
67        let w_max = self.w_max / max_datagram_size as f64;
68        (w_max * BETA_CUBIC
69            + 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC) * t.as_secs_f64() / rtt.as_secs_f64())
70            * max_datagram_size as f64
71    }
72}
73
74/// The RFC8312 congestion controller, as widely used for TCP
75#[derive(Debug, Clone)]
76pub struct Cubic {
77    config: Arc<CubicConfig>,
78    current_mtu: u64,
79    state: State,
80    /// Copy of the controller state to restore when a spurious congestion event is detected.
81    pre_congestion_state: Option<State>,
82}
83
84impl Cubic {
85    /// Construct a state using the given `config` and current time `now`
86    pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
87        Self {
88            state: State {
89                window: config.initial_window,
90                ssthresh: u64::MAX,
91                ..Default::default()
92            },
93            current_mtu: current_mtu as u64,
94            pre_congestion_state: None,
95            config,
96        }
97    }
98
99    fn minimum_window(&self) -> u64 {
100        2 * self.current_mtu
101    }
102}
103
104impl Controller for Cubic {
105    fn on_ack(
106        &mut self,
107        now: Instant,
108        sent: Instant,
109        bytes: u64,
110        app_limited: bool,
111        rtt: &RttEstimator,
112    ) {
113        if app_limited
114            || self
115                .state
116                .recovery_start_time
117                .map(|recovery_start_time| sent <= recovery_start_time)
118                .unwrap_or(false)
119        {
120            return;
121        }
122
123        if self.state.window < self.state.ssthresh {
124            // Slow start
125            self.state.window += bytes;
126        } else {
127            // Congestion avoidance.
128            let ca_start_time;
129
130            match self.state.recovery_start_time {
131                Some(t) => ca_start_time = t,
132                None => {
133                    // When we come here without congestion_event() triggered,
134                    // initialize congestion_recovery_start_time, w_max and k.
135                    ca_start_time = now;
136                    self.state.recovery_start_time = Some(now);
137
138                    self.state.w_max = self.state.window as f64;
139                    self.state.k = 0.0;
140                }
141            }
142
143            let t = now - ca_start_time;
144
145            // w_cubic(t + rtt)
146            let w_cubic = self.state.w_cubic(t + rtt.get(), self.current_mtu);
147
148            // w_est(t)
149            let w_est = self.state.w_est(t, rtt.get(), self.current_mtu);
150
151            let mut cubic_cwnd = self.state.window;
152
153            if w_cubic < w_est {
154                // TCP friendly region.
155                cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64);
156            } else if cubic_cwnd < w_cubic as u64 {
157                // Concave region or convex region use same increment.
158                let cubic_inc =
159                    (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64;
160
161                cubic_cwnd += cubic_inc as u64;
162            }
163
164            // Update the increment and increase cwnd by MSS.
165            self.state.cwnd_inc += cubic_cwnd - self.state.window;
166
167            // cwnd_inc can be more than 1 MSS in the late stage of max probing.
168            // however RFC9002 ยง7.3.3 (Congestion Avoidance) limits
169            // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
170            if self.state.cwnd_inc >= self.current_mtu {
171                self.state.window += self.current_mtu;
172                self.state.cwnd_inc = 0;
173            }
174        }
175    }
176
177    fn on_congestion_event(
178        &mut self,
179        now: Instant,
180        sent: Instant,
181        is_persistent_congestion: bool,
182        is_ecn: bool,
183        _lost_bytes: u64,
184    ) {
185        if self
186            .state
187            .recovery_start_time
188            .map(|recovery_start_time| sent <= recovery_start_time)
189            .unwrap_or(false)
190        {
191            return;
192        }
193
194        // Save state in case this event ends up being spurious
195        if !is_ecn {
196            self.pre_congestion_state = Some(self.state.clone());
197        }
198
199        self.state.recovery_start_time = Some(now);
200
201        // Fast convergence
202        if (self.state.window as f64) < self.state.w_max {
203            self.state.w_max = self.state.window as f64 * (1.0 + BETA_CUBIC) / 2.0;
204        } else {
205            self.state.w_max = self.state.window as f64;
206        }
207
208        self.state.ssthresh = cmp::max(
209            (self.state.w_max * BETA_CUBIC) as u64,
210            self.minimum_window(),
211        );
212        self.state.window = self.state.ssthresh;
213        self.state.k = self.state.cubic_k(self.current_mtu);
214
215        self.state.cwnd_inc = (self.state.cwnd_inc as f64 * BETA_CUBIC) as u64;
216
217        if is_persistent_congestion {
218            self.state.recovery_start_time = None;
219            self.state.w_max = self.state.window as f64;
220
221            // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC
222            self.state.ssthresh = cmp::max(
223                (self.state.window as f64 * BETA_CUBIC) as u64,
224                self.minimum_window(),
225            );
226
227            self.state.cwnd_inc = 0;
228
229            self.state.window = self.minimum_window();
230        }
231    }
232
233    fn on_spurious_congestion_event(&mut self) {
234        if let Some(prior_state) = self.pre_congestion_state.take() {
235            if self.state.window < prior_state.window {
236                self.state = prior_state;
237            }
238        }
239    }
240
241    fn on_mtu_update(&mut self, new_mtu: u16) {
242        self.current_mtu = new_mtu as u64;
243        self.state.window = self.state.window.max(self.minimum_window());
244    }
245
246    fn window(&self) -> u64 {
247        self.state.window
248    }
249
250    fn metrics(&self) -> super::ControllerMetrics {
251        super::ControllerMetrics {
252            congestion_window: self.window(),
253            ssthresh: Some(self.state.ssthresh),
254            pacing_rate: None,
255        }
256    }
257
258    fn clone_box(&self) -> Box<dyn Controller> {
259        Box::new(self.clone())
260    }
261
262    fn initial_window(&self) -> u64 {
263        self.config.initial_window
264    }
265
266    fn into_any(self: Box<Self>) -> Box<dyn Any> {
267        self
268    }
269}
270
/// Tunable parameters for the `Cubic` congestion controller.
#[derive(Debug, Clone)]
pub struct CubicConfig {
    /// Congestion window, in bytes, that a newly built controller starts with.
    initial_window: u64,
}

277impl CubicConfig {
278    /// Default limit on the amount of outstanding data in bytes.
279    ///
280    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
281    pub fn initial_window(&mut self, value: u64) -> &mut Self {
282        self.initial_window = value;
283        self
284    }
285}
286
287impl Default for CubicConfig {
288    fn default() -> Self {
289        Self {
290            initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE),
291        }
292    }
293}
294
295impl ControllerFactory for CubicConfig {
296    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
297        Box::new(Cubic::new(self, now, current_mtu))
298    }
299}