use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use rand::{Rng, SeedableRng};

use crate::congestion::ControllerMetrics;
use crate::congestion::bbr::bw_estimation::BandwidthEstimation;
use crate::congestion::bbr::min_max::MinMax;
use crate::connection::RttEstimator;
use crate::{Duration, Instant};

use super::{BASE_DATAGRAM_SIZE, Controller, ControllerFactory};

mod bw_estimation;
mod min_max;

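/// An implementation of the BBR congestion control algorithm.
///
/// BBR models the path by tracking the maximum observed delivery rate and the
/// minimum observed RTT, and cycles through the Startup, Drain, ProbeBw and
/// ProbeRtt modes defined below to keep the amount of data in flight near the
/// estimated bandwidth-delay product.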
#[derive(Debug, Clone)]
pub struct Bbr {
    config: Arc<BbrConfig>,
    current_mtu: u64,
    max_bandwidth: BandwidthEstimation,
    acked_bytes: u64,
    mode: Mode,
    loss_state: LossState,
    recovery_state: RecoveryState,
    recovery_window: u64,
    is_at_full_bandwidth: bool,
    pacing_gain: f32,
    high_gain: f32,
    drain_gain: f32,
    cwnd_gain: f32,
    high_cwnd_gain: f32,
    last_cycle_start: Option<Instant>,
    current_cycle_offset: u8,
    init_cwnd: u64,
    min_cwnd: u64,
    prev_in_flight_count: u64,
    exit_probe_rtt_at: Option<Instant>,
    probe_rtt_last_started_at: Option<Instant>,
    min_rtt: Duration,
    exiting_quiescence: bool,
    pacing_rate: u64,
    max_acked_packet_number: u64,
    max_sent_packet_number: u64,
    end_recovery_at_packet_number: u64,
    cwnd: u64,
    current_round_trip_end_packet_number: u64,
    round_count: u64,
    bw_at_last_round: u64,
    round_wo_bw_gain: u64,
    ack_aggregation: AckAggregationState,
    random_number_generator: rand::rngs::StdRng,
}

impl Bbr {
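    /// Construct a new BBR controller state from `config` and the current datagram size.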
    pub fn new(config: Arc<BbrConfig>, current_mtu: u16) -> Self {
        let initial_window = config.initial_window;
        Self {
            config,
            current_mtu: current_mtu as u64,
            max_bandwidth: BandwidthEstimation::default(),
            acked_bytes: 0,
            mode: Mode::Startup,
            loss_state: Default::default(),
            recovery_state: RecoveryState::NotInRecovery,
            recovery_window: 0,
            is_at_full_bandwidth: false,
            pacing_gain: K_DEFAULT_HIGH_GAIN,
            high_gain: K_DEFAULT_HIGH_GAIN,
            drain_gain: 1.0 / K_DEFAULT_HIGH_GAIN,
            cwnd_gain: K_DEFAULT_HIGH_GAIN,
            high_cwnd_gain: K_DEFAULT_HIGH_GAIN,
            last_cycle_start: None,
            current_cycle_offset: 0,
            init_cwnd: initial_window,
            min_cwnd: calculate_min_window(current_mtu as u64),
            prev_in_flight_count: 0,
            exit_probe_rtt_at: None,
            probe_rtt_last_started_at: None,
            min_rtt: Default::default(),
            exiting_quiescence: false,
            pacing_rate: 0,
            max_acked_packet_number: 0,
            max_sent_packet_number: 0,
            end_recovery_at_packet_number: 0,
            cwnd: initial_window,
            current_round_trip_end_packet_number: 0,
            round_count: 0,
            bw_at_last_round: 0,
            round_wo_bw_gain: 0,
            ack_aggregation: AckAggregationState::default(),
            random_number_generator: rand::rngs::StdRng::from_os_rng(),
        }
    }

    fn enter_startup_mode(&mut self) {
        self.mode = Mode::Startup;
        self.pacing_gain = self.high_gain;
        self.cwnd_gain = self.high_cwnd_gain;
    }

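    /// Switch to ProbeBw and pick a random starting point in the pacing-gain
    /// cycle. Index 1 is excluded so that the decreased-gain phase (0.75)
    /// always follows the increased-gain phase (1.25) rather than starting a
    /// cycle with it.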
    fn enter_probe_bandwidth_mode(&mut self, now: Instant) {
        self.mode = Mode::ProbeBw;
        self.cwnd_gain = K_DERIVED_HIGH_CWNDGAIN;
        self.last_cycle_start = Some(now);
        let mut rand_index = self
            .random_number_generator
            .random_range(0..K_PACING_GAIN.len() as u8 - 1);
        if rand_index >= 1 {
            rand_index += 1;
        }
        self.current_cycle_offset = rand_index;
        self.pacing_gain = K_PACING_GAIN[rand_index as usize];
    }

    fn update_recovery_state(&mut self, is_round_start: bool) {
        // Losses in this round push back the point at which recovery may end.
        if self.loss_state.has_losses() {
            self.end_recovery_at_packet_number = self.max_sent_packet_number;
        }
        match self.recovery_state {
            // Enter conservation on the first loss.
            RecoveryState::NotInRecovery if self.loss_state.has_losses() => {
                self.recovery_state = RecoveryState::Conservation;
                // A zero recovery window is re-seeded in calculate_recovery_window().
                self.recovery_window = 0;
                // Conservation should last a whole round: extend the current
                // round as if it had started right now.
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
            }
            RecoveryState::Growth | RecoveryState::Conservation => {
                if self.recovery_state == RecoveryState::Conservation && is_round_start {
                    self.recovery_state = RecoveryState::Growth;
                }
                // Exit recovery once a loss-free round has been acknowledged.
                if !self.loss_state.has_losses()
                    && self.max_acked_packet_number > self.end_recovery_at_packet_number
                {
                    self.recovery_state = RecoveryState::NotInRecovery;
                }
            }
            _ => {}
        }
    }

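    /// In ProbeBw, advance through the pacing-gain cycle roughly once per
    /// min_rtt. The high-gain phase is held until the extra in-flight data it
    /// targets has actually been reached (unless losses appear), and the
    /// low-gain phase is cut short once in-flight drops back to the estimated
    /// BDP.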
    fn update_gain_cycle_phase(&mut self, now: Instant, in_flight: u64) {
        let mut should_advance_gain_cycling = self
            .last_cycle_start
            .map(|last_cycle_start| now.duration_since(last_cycle_start) > self.min_rtt)
            .unwrap_or(false);
        if self.pacing_gain > 1.0
            && !self.loss_state.has_losses()
            && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain)
        {
            should_advance_gain_cycling = false;
        }

        if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) {
            should_advance_gain_cycling = true;
        }

        if should_advance_gain_cycling {
            self.current_cycle_offset = (self.current_cycle_offset + 1) % K_PACING_GAIN.len() as u8;
            self.last_cycle_start = Some(now);
            if DRAIN_TO_TARGET
                && self.pacing_gain < 1.0
                && (K_PACING_GAIN[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON
                && in_flight > self.get_target_cwnd(1.0)
            {
                return;
            }
            self.pacing_gain = K_PACING_GAIN[self.current_cycle_offset as usize];
        }
    }

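    /// Leave Startup for Drain once full bandwidth has been reached, and
    /// leave Drain for ProbeBw once in-flight has drained to the estimated
    /// BDP.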
    fn maybe_exit_startup_or_drain(&mut self, now: Instant, in_flight: u64) {
        if self.mode == Mode::Startup && self.is_at_full_bandwidth {
            self.mode = Mode::Drain;
            self.pacing_gain = self.drain_gain;
            self.cwnd_gain = self.high_cwnd_gain;
        }
        if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) {
            self.enter_probe_bandwidth_mode(now);
        }
    }

    fn is_min_rtt_expired(&self, now: Instant, app_limited: bool) -> bool {
        !app_limited
            && self
                .probe_rtt_last_started_at
                .map(|last| now.saturating_duration_since(last) > Duration::from_secs(10))
                .unwrap_or(true)
    }

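    /// Enter ProbeRtt when the min_rtt sample has not been refreshed for ten
    /// seconds, hold the reduced window for 200ms once in-flight has drained
    /// to the ProbeRtt cwnd, then return to Startup or ProbeBw depending on
    /// whether full bandwidth was already reached.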
    fn maybe_enter_or_exit_probe_rtt(
        &mut self,
        now: Instant,
        is_round_start: bool,
        bytes_in_flight: u64,
        app_limited: bool,
    ) {
        let min_rtt_expired = self.is_min_rtt_expired(now, app_limited);
        if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt {
            self.mode = Mode::ProbeRtt;
            self.pacing_gain = 1.0;
            self.exit_probe_rtt_at = None;
            self.probe_rtt_last_started_at = Some(now);
        }

        if self.mode == Mode::ProbeRtt {
            match self.exit_probe_rtt_at {
                None => {
                    if bytes_in_flight < self.get_probe_rtt_cwnd() + self.current_mtu {
                        const K_PROBE_RTT_TIME: Duration = Duration::from_millis(200);
                        self.exit_probe_rtt_at = Some(now + K_PROBE_RTT_TIME);
                    }
                }
                Some(exit_time) if is_round_start && now >= exit_time => {
                    if !self.is_at_full_bandwidth {
                        self.enter_startup_mode();
                    } else {
                        self.enter_probe_bandwidth_mode(now);
                    }
                }
                Some(_) => {}
            }
        }

        self.exiting_quiescence = false;
    }

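    /// Congestion window implied by the current path model: the
    /// bandwidth-delay product (max bandwidth estimate times min_rtt) scaled
    /// by `gain`, clamped below by the minimum window and falling back to the
    /// initial window while there is no bandwidth estimate yet.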
    fn get_target_cwnd(&self, gain: f32) -> u64 {
        let bw = self.max_bandwidth.get_estimate();
        let bdp = self.min_rtt.as_micros() as u64 * bw;
        let bdpf = bdp as f64;
        let cwnd = ((gain as f64 * bdpf) / 1_000_000f64) as u64;
        if cwnd == 0 {
            return self.init_cwnd;
        }
        cwnd.max(self.min_cwnd)
    }

    fn get_probe_rtt_cwnd(&self) -> u64 {
        const K_MODERATE_PROBE_RTT_MULTIPLIER: f32 = 0.75;
        if PROBE_RTT_BASED_ON_BDP {
            return self.get_target_cwnd(K_MODERATE_PROBE_RTT_MULTIPLIER);
        }
        self.min_cwnd
    }

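    /// Derive the pacing rate from the bandwidth estimate and the current
    /// pacing gain. Before full bandwidth is reached the rate is only ever
    /// raised, and it is seeded from the initial window over min_rtt while
    /// there is no estimate to pace against.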
    fn calculate_pacing_rate(&mut self) {
        let bw = self.max_bandwidth.get_estimate();
        if bw == 0 {
            return;
        }
        let target_rate = (bw as f64 * self.pacing_gain as f64) as u64;
        if self.is_at_full_bandwidth {
            self.pacing_rate = target_rate;
            return;
        }

        if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 {
            self.pacing_rate =
                BandwidthEstimation::bw_from_delta(self.init_cwnd, self.min_rtt).unwrap();
            return;
        }

        if self.pacing_rate < target_rate {
            self.pacing_rate = target_rate;
        }
    }

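    /// Grow the congestion window towards the gain-scaled BDP target, leaving
    /// headroom for ack aggregation. Outside ProbeRtt the window grows by the
    /// bytes acked; once full bandwidth has been reached it is capped at the
    /// target.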
    fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) {
        if self.mode == Mode::ProbeRtt {
            return;
        }
        let mut target_window = self.get_target_cwnd(self.cwnd_gain);
        if self.is_at_full_bandwidth {
            // Add the maximum recently observed ack aggregation to the target.
            target_window += self.ack_aggregation.max_ack_height.get();
        } else {
            // Allow extra room to compensate for ack aggregation before full
            // bandwidth is reached.
            target_window += excess_acked;
        }
        if self.is_at_full_bandwidth {
            self.cwnd = target_window.min(self.cwnd + bytes_acked);
        } else if self.cwnd < target_window || self.acked_bytes < self.init_cwnd {
            // Keep growing while below the target or before a full initial
            // window's worth of bytes has been acknowledged.
            self.cwnd += bytes_acked;
        }

        // Enforce the minimum congestion window.
        if self.cwnd < self.min_cwnd {
            self.cwnd = self.min_cwnd;
        }
    }

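    /// Maintain a separate recovery window while in recovery: it shrinks by
    /// the bytes lost, grows by the bytes acked once in the Growth phase, and
    /// never drops below in-flight plus the latest ack or the minimum window.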
    fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) {
        if !self.recovery_state.in_recovery() {
            return;
        }
        if self.recovery_window == 0 {
            self.recovery_window = self.min_cwnd.max(in_flight + bytes_acked);
            return;
        }

        if self.recovery_window >= bytes_lost {
            self.recovery_window -= bytes_lost;
        } else {
            self.recovery_window = self.current_mtu;
        }
        if self.recovery_state == RecoveryState::Growth {
            self.recovery_window += bytes_acked;
        }

        self.recovery_window = self
            .recovery_window
            .max(in_flight + bytes_acked)
            .max(self.min_cwnd);
    }

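    /// Declare full bandwidth reached (ending Startup) once the bandwidth
    /// estimate has failed to grow by at least 25% for three consecutive
    /// rounds, or as soon as the connection enters recovery. Rounds measured
    /// while application-limited are ignored.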
    fn check_if_full_bw_reached(&mut self, app_limited: bool) {
        if app_limited {
            return;
        }
        let target = (self.bw_at_last_round as f64 * K_STARTUP_GROWTH_TARGET as f64) as u64;
        let bw = self.max_bandwidth.get_estimate();
        if bw >= target {
            self.bw_at_last_round = bw;
            self.round_wo_bw_gain = 0;
            self.ack_aggregation.max_ack_height.reset();
            return;
        }

        self.round_wo_bw_gain += 1;
        if self.round_wo_bw_gain >= K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP as u64
            || self.recovery_state.in_recovery()
        {
            self.is_at_full_bandwidth = true;
        }
    }
}

impl Controller for Bbr {
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {
        self.max_sent_packet_number = last_packet_number;
        self.max_bandwidth.on_sent(now, bytes);
    }

    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        self.max_bandwidth
            .on_ack(now, sent, bytes, self.round_count, app_limited);
        self.acked_bytes += bytes;
        if self.is_min_rtt_expired(now, app_limited) || self.min_rtt > rtt.min() {
            self.min_rtt = rtt.min();
        }
    }

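    // Called once per batch of acknowledgements: update the bandwidth and ack
    // aggregation estimates, detect round boundaries, advance the BBR state
    // machine, and recompute the pacing rate, congestion window and recovery
    // window.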
    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
        let bytes_acked = self.max_bandwidth.bytes_acked_this_window();
        let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes(
            bytes_acked,
            now,
            self.round_count,
            self.max_bandwidth.get_estimate(),
        );
        self.max_bandwidth.end_acks(self.round_count, app_limited);
        if let Some(largest_acked_packet) = largest_packet_num_acked {
            self.max_acked_packet_number = largest_acked_packet;
        }

        let mut is_round_start = false;
        if bytes_acked > 0 {
            is_round_start =
                self.max_acked_packet_number > self.current_round_trip_end_packet_number;
            if is_round_start {
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
                self.round_count += 1;
            }
        }

        self.update_recovery_state(is_round_start);

        if self.mode == Mode::ProbeBw {
            self.update_gain_cycle_phase(now, in_flight);
        }

        if is_round_start && !self.is_at_full_bandwidth {
            self.check_if_full_bw_reached(app_limited);
        }

        self.maybe_exit_startup_or_drain(now, in_flight);

        self.maybe_enter_or_exit_probe_rtt(now, is_round_start, in_flight, app_limited);

        self.calculate_pacing_rate();
        self.calculate_cwnd(bytes_acked, excess_acked);
        self.calculate_recovery_window(bytes_acked, self.loss_state.lost_bytes, in_flight);

        self.prev_in_flight_count = in_flight;
        self.loss_state.reset();
    }

    fn on_congestion_event(
        &mut self,
        _now: Instant,
        _sent: Instant,
        _is_persistent_congestion: bool,
        _is_ecn: bool,
        lost_bytes: u64,
    ) {
        self.loss_state.lost_bytes += lost_bytes;
    }

    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.min_cwnd = calculate_min_window(self.current_mtu);
        self.init_cwnd = self.config.initial_window.max(self.min_cwnd);
        self.cwnd = self.cwnd.max(self.min_cwnd);
    }

    fn window(&self) -> u64 {
        if self.mode == Mode::ProbeRtt {
            return self.get_probe_rtt_cwnd();
        } else if self.recovery_state.in_recovery() && self.mode != Mode::Startup {
            return self.cwnd.min(self.recovery_window);
        }
        self.cwnd
    }

    fn metrics(&self) -> ControllerMetrics {
        ControllerMetrics {
            congestion_window: self.window(),
            ssthresh: None,
            pacing_rate: Some(self.pacing_rate * 8),
        }
    }

    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }

    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

/// Configuration for the [`Bbr`] congestion controller.
#[derive(Debug, Clone)]
pub struct BbrConfig {
    initial_window: u64,
}

impl BbrConfig {
    /// Set the initial congestion window, in bytes.
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        self.initial_window = value;
        self
    }
}

impl Default for BbrConfig {
    fn default() -> Self {
        Self {
            initial_window: K_MAX_INITIAL_CONGESTION_WINDOW * BASE_DATAGRAM_SIZE,
        }
    }
}

impl ControllerFactory for BbrConfig {
    fn build(self: Arc<Self>, _now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        Box::new(Bbr::new(self, current_mtu))
    }
}

#[derive(Debug, Default, Copy, Clone)]
struct AckAggregationState {
    max_ack_height: MinMax,
    aggregation_epoch_start_time: Option<Instant>,
    aggregation_epoch_bytes: u64,
}

impl AckAggregationState {
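    /// Track how far acknowledged bytes run ahead of what the current
    /// bandwidth estimate predicts for this aggregation epoch; the running
    /// maximum of that excess feeds the congestion window headroom, and the
    /// excess for this call is returned.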
    fn update_ack_aggregation_bytes(
        &mut self,
        newly_acked_bytes: u64,
        now: Instant,
        round: u64,
        max_bandwidth: u64,
    ) -> u64 {
        let expected_bytes_acked = max_bandwidth
            * now
                .saturating_duration_since(self.aggregation_epoch_start_time.unwrap_or(now))
                .as_micros() as u64
            / 1_000_000;

        if self.aggregation_epoch_bytes <= expected_bytes_acked {
            // Acks are not running ahead of the estimate: start a new epoch.
            self.aggregation_epoch_bytes = newly_acked_bytes;
            self.aggregation_epoch_start_time = Some(now);
            return 0;
        }

        self.aggregation_epoch_bytes += newly_acked_bytes;
        let diff = self.aggregation_epoch_bytes - expected_bytes_acked;
        self.max_ack_height.update_max(round, diff);
        diff
    }
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Mode {
    // Probe for bandwidth with a high pacing gain.
    Startup,
    // Drain the queue built up during startup by pacing below the estimated
    // bandwidth.
    Drain,
    // Steady state: cycle the pacing gain around 1.0 to probe for more
    // bandwidth while keeping the queue small.
    ProbeBw,
    // Temporarily reduce the window to re-measure the minimum RTT.
    ProbeRtt,
}

// How recovery currently limits the amount of data in flight.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum RecoveryState {
    // Not in recovery; no extra limit applies.
    NotInRecovery,
    // Hold the recovery window roughly at the level of data in flight when
    // the loss was detected.
    Conservation,
    // Grow the recovery window by the amount of data acknowledged.
    Growth,
}

impl RecoveryState {
    pub(super) fn in_recovery(&self) -> bool {
        !matches!(self, Self::NotInRecovery)
    }
}

#[derive(Debug, Clone, Default)]
struct LossState {
    lost_bytes: u64,
}

impl LossState {
    pub(super) fn reset(&mut self) {
        self.lost_bytes = 0;
    }

    pub(super) fn has_losses(&self) -> bool {
        self.lost_bytes != 0
    }
}

fn calculate_min_window(current_mtu: u64) -> u64 {
    4 * current_mtu
}

// Gain used for both pacing and the congestion window during Startup,
// approximately 2/ln(2).
const K_DEFAULT_HIGH_GAIN: f32 = 2.885;
// Congestion window gain used once ProbeBw is entered.
const K_DERIVED_HIGH_CWNDGAIN: f32 = 2.0;
// The cycle of pacing gains used during ProbeBw.
const K_PACING_GAIN: [f32; 8] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];

// Bandwidth growth (25%) expected per round while in Startup, and the number
// of rounds without such growth after which Startup is exited.
const K_STARTUP_GROWTH_TARGET: f32 = 1.25;
const K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP: u8 = 3;

// Default initial congestion window, in datagrams of BASE_DATAGRAM_SIZE.
const K_MAX_INITIAL_CONGESTION_WINDOW: u64 = 200;

// Size the ProbeRtt window as a fraction of the BDP instead of the minimum window.
const PROBE_RTT_BASED_ON_BDP: bool = true;
// Hold the below-1.0 ProbeBw gain until in-flight has drained to the target.
const DRAIN_TO_TARGET: bool = true;