noq_proto/congestion/
cubic.rs

1use std::any::Any;
2use std::cmp;
3use std::sync::Arc;
4
5use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};
6use crate::connection::RttEstimator;
7use crate::{Duration, Instant};
8
/// CUBIC constants.
///
/// These are the values recommended by RFC 8312.
///
/// `BETA_CUBIC` is the multiplicative decrease factor: on a congestion event the slow start
/// threshold (and with it the window) is set to `w_max * BETA_CUBIC`.
const BETA_CUBIC: f64 = 0.7;

/// Scaling constant of the cubic term in `W_cubic(t) = C * (t - K)^3 + w_max`.
const C: f64 = 0.4;
15
/// CUBIC state variables.
///
/// These values have to be kept for the whole lifetime of the connection.
/// `k` and `w_max` are described in the RFC.
#[derive(Debug, Default, Clone)]
pub(super) struct State {
    /// Time period that the cubic function takes to increase the window size to W_max.
    ///
    /// Expressed in seconds: it is subtracted from an elapsed time converted with
    /// `as_secs_f64` when evaluating the cubic function.
    k: f64,

    /// Congestion window size when the last congestion event occurred.
    ///
    /// Expressed in bytes. When fast convergence applies, the stored value is the window
    /// scaled by `(1 + BETA_CUBIC) / 2` rather than the window itself.
    w_max: f64,

    /// Congestion window increment stored during congestion avoidance.
    ///
    /// Accumulated in bytes; once it reaches one datagram size the window grows by one
    /// datagram and this counter is reset to zero.
    cwnd_inc: u64,

    /// Maximum number of bytes in flight that may be sent.
    window: u64,

    /// Slow start threshold in bytes.
    ///
    /// When the congestion window is below ssthresh, the mode is slow start
    /// and the window grows by the number of bytes acknowledged.
    ssthresh: u64,

    /// The time when QUIC first detects a loss, causing it to enter recovery. When a packet sent
    /// after this time is acknowledged, QUIC exits recovery.
    ///
    /// Also serves as the start of the congestion avoidance epoch from which the elapsed
    /// time `t` of the cubic function is measured.
    recovery_start_time: Option<Instant>,
}
44
45/// CUBIC Functions.
46///
47/// Note that these calculations are based on a count of cwnd as bytes,
48/// not packets.
49/// Unit of t (duration) and RTT are based on seconds (f64).
50impl State {
51    // K = cbrt(w_max * (1 - beta_cubic) / C) (Eq. 2)
52    fn cubic_k(&self, max_datagram_size: u64) -> f64 {
53        let w_max = self.w_max / max_datagram_size as f64;
54        (w_max * (1.0 - BETA_CUBIC) / C).cbrt()
55    }
56
57    // W_cubic(t) = C * (t - K)^3 + w_max (Eq. 1)
58    fn w_cubic(&self, t: Duration, max_datagram_size: u64) -> f64 {
59        let w_max = self.w_max / max_datagram_size as f64;
60
61        (C * (t.as_secs_f64() - self.k).powi(3) + w_max) * max_datagram_size as f64
62    }
63
64    // W_est(t) = w_max * beta_cubic + 3 * (1 - beta_cubic) / (1 + beta_cubic) *
65    // (t / RTT) (Eq. 4)
66    fn w_est(&self, t: Duration, rtt: Duration, max_datagram_size: u64) -> f64 {
67        let w_max = self.w_max / max_datagram_size as f64;
68        (w_max * BETA_CUBIC
69            + 3.0 * (1.0 - BETA_CUBIC) / (1.0 + BETA_CUBIC) * t.as_secs_f64() / rtt.as_secs_f64())
70            * max_datagram_size as f64
71    }
72}
73
/// The RFC8312 congestion controller, as widely used for TCP
#[derive(Debug, Clone)]
pub struct Cubic {
    /// Shared configuration; provides the initial window.
    config: Arc<CubicConfig>,
    /// Current MTU in bytes, used as the datagram size in all window arithmetic.
    current_mtu: u64,
    /// Live controller state (window, ssthresh and the CUBIC variables).
    state: State,
    /// Copy of the controller state to restore when a spurious congestion event is detected.
    pre_congestion_state: Option<State>,
}
83
84impl Cubic {
85    /// Construct a state using the given `config` and current time `now`
86    pub fn new(config: Arc<CubicConfig>, _now: Instant, current_mtu: u16) -> Self {
87        Self {
88            state: State {
89                window: config.initial_window,
90                ssthresh: u64::MAX,
91                ..Default::default()
92            },
93            current_mtu: current_mtu as u64,
94            pre_congestion_state: None,
95            config,
96        }
97    }
98
99    fn minimum_window(&self) -> u64 {
100        2 * self.current_mtu
101    }
102}
103
104impl Controller for Cubic {
105    fn on_ack(
106        &mut self,
107        now: Instant,
108        sent: Instant,
109        bytes: u64,
110        _pn: u64,
111        app_limited: bool,
112        rtt: &RttEstimator,
113    ) {
114        if app_limited
115            || self
116                .state
117                .recovery_start_time
118                .map(|recovery_start_time| sent <= recovery_start_time)
119                .unwrap_or(false)
120        {
121            return;
122        }
123
124        if self.state.window < self.state.ssthresh {
125            // Slow start
126            self.state.window += bytes;
127        } else {
128            // Congestion avoidance.
129            let ca_start_time;
130
131            match self.state.recovery_start_time {
132                Some(t) => ca_start_time = t,
133                None => {
134                    // When we come here without congestion_event() triggered,
135                    // initialize congestion_recovery_start_time, w_max and k.
136                    ca_start_time = now;
137                    self.state.recovery_start_time = Some(now);
138
139                    self.state.w_max = self.state.window as f64;
140                    self.state.k = 0.0;
141                }
142            }
143
144            let t = now - ca_start_time;
145
146            // w_cubic(t + rtt)
147            let w_cubic = self.state.w_cubic(t + rtt.get(), self.current_mtu);
148
149            // w_est(t)
150            let w_est = self.state.w_est(t, rtt.get(), self.current_mtu);
151
152            let mut cubic_cwnd = self.state.window;
153
154            if w_cubic < w_est {
155                // TCP friendly region.
156                cubic_cwnd = cmp::max(cubic_cwnd, w_est as u64);
157            } else if cubic_cwnd < w_cubic as u64 {
158                // Concave region or convex region use same increment.
159                let cubic_inc =
160                    (w_cubic - cubic_cwnd as f64) / cubic_cwnd as f64 * self.current_mtu as f64;
161
162                cubic_cwnd += cubic_inc as u64;
163            }
164
165            // Update the increment and increase cwnd by MSS.
166            self.state.cwnd_inc += cubic_cwnd - self.state.window;
167
168            // cwnd_inc can be more than 1 MSS in the late stage of max probing.
169            // however RFC9002 ยง7.3.3 (Congestion Avoidance) limits
170            // the increase of cwnd to 1 max_datagram_size per cwnd acknowledged.
171            if self.state.cwnd_inc >= self.current_mtu {
172                self.state.window += self.current_mtu;
173                self.state.cwnd_inc = 0;
174            }
175        }
176    }
177
178    fn on_congestion_event(
179        &mut self,
180        now: Instant,
181        sent: Instant,
182        is_persistent_congestion: bool,
183        is_ecn: bool,
184        _lost_bytes: u64,
185        _largest_lost_pn: u64,
186    ) {
187        if self
188            .state
189            .recovery_start_time
190            .map(|recovery_start_time| sent <= recovery_start_time)
191            .unwrap_or(false)
192        {
193            return;
194        }
195
196        // Save state in case this event ends up being spurious
197        if !is_ecn {
198            self.pre_congestion_state = Some(self.state.clone());
199        }
200
201        self.state.recovery_start_time = Some(now);
202
203        // Fast convergence
204        if (self.state.window as f64) < self.state.w_max {
205            self.state.w_max = self.state.window as f64 * (1.0 + BETA_CUBIC) / 2.0;
206        } else {
207            self.state.w_max = self.state.window as f64;
208        }
209
210        self.state.ssthresh = cmp::max(
211            (self.state.w_max * BETA_CUBIC) as u64,
212            self.minimum_window(),
213        );
214        self.state.window = self.state.ssthresh;
215        self.state.k = self.state.cubic_k(self.current_mtu);
216
217        self.state.cwnd_inc = (self.state.cwnd_inc as f64 * BETA_CUBIC) as u64;
218
219        if is_persistent_congestion {
220            self.state.recovery_start_time = None;
221            self.state.w_max = self.state.window as f64;
222
223            // 4.7 Timeout - reduce ssthresh based on BETA_CUBIC
224            self.state.ssthresh = cmp::max(
225                (self.state.window as f64 * BETA_CUBIC) as u64,
226                self.minimum_window(),
227            );
228
229            self.state.cwnd_inc = 0;
230
231            self.state.window = self.minimum_window();
232        }
233    }
234
235    fn on_spurious_congestion_event(&mut self) {
236        if let Some(prior_state) = self.pre_congestion_state.take()
237            && self.state.window < prior_state.window
238        {
239            self.state = prior_state;
240        }
241    }
242
243    fn on_mtu_update(&mut self, new_mtu: u16) {
244        self.current_mtu = new_mtu as u64;
245        self.state.window = self.state.window.max(self.minimum_window());
246    }
247
248    fn window(&self) -> u64 {
249        self.state.window
250    }
251
252    fn metrics(&self) -> super::ControllerMetrics {
253        super::ControllerMetrics {
254            congestion_window: self.window(),
255            ssthresh: Some(self.state.ssthresh),
256            pacing_rate: None,
257            send_quantum: None,
258        }
259    }
260
261    fn clone_box(&self) -> Box<dyn Controller> {
262        Box::new(self.clone())
263    }
264
265    fn initial_window(&self) -> u64 {
266        self.config.initial_window
267    }
268
269    fn into_any(self: Box<Self>) -> Box<dyn Any> {
270        self
271    }
272}
273
/// Configuration for the `Cubic` congestion controller
#[derive(Debug, Clone)]
pub struct CubicConfig {
    /// Congestion window, in bytes, that a newly constructed `Cubic` controller starts with.
    initial_window: u64,
}
279
impl CubicConfig {
    /// Default limit on the amount of outstanding data in bytes.
    ///
    /// This is the congestion window a new `Cubic` controller starts with. Returns `&mut Self`
    /// so that calls can be chained.
    ///
    /// Recommended value: `min(10 * max_datagram_size, max(2 * max_datagram_size, 14720))`
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }
}
289
290impl Default for CubicConfig {
291    fn default() -> Self {
292        Self {
293            initial_window: 14720.clamp(2 * BASE_DATAGRAM_SIZE, 10 * BASE_DATAGRAM_SIZE),
294        }
295    }
296}
297
298impl ControllerFactory for CubicConfig {
299    fn build(self: Arc<Self>, now: Instant, current_mtu: u16) -> Box<dyn Controller> {
300        Box::new(Cubic::new(self, now, current_mtu))
301    }
302}