//! Channel reservation helpers. use super::mpsc::{ self, error::{SendError, TrySendError}, OwnedPermit, }; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; // The reserve future only reports channel closure; the message value is stored separately. type ReserveResult = Result, SendError<()>>; // Tokio's `reserve_owned` future is not nameable, so box it instead of exposing a future parameter. type ReserveFuture = Pin> + Send>>; /// A reserved channel slot bundled with the value to send. #[must_use = "call send to deliver the reserved message"] pub struct Reserved { permit: OwnedPermit, value: T, } impl Reserved { /// Sends the buffered value through the reserved slot. pub fn send(self) -> mpsc::Sender { self.permit.send(self.value) } } /// A future that waits for a channel slot and keeps ownership of the value. #[must_use = "await the reservation to acquire a channel slot"] pub struct Reservation { future: ReserveFuture, value: Option, } impl Reservation { fn new(future: impl Future> + Send + 'static, value: T) -> Self { Self { future: Box::pin(future), value: Some(value), } } } impl Unpin for Reservation {} impl Future for Reservation { type Output = Result, SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let permit = match self.future.as_mut().poll(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(permit) => permit, }; let value = self .value .take() .expect("reservation polled after completion"); Poll::Ready(match permit { Ok(permit) => Ok(Reserved { permit, value }), Err(SendError(())) => Err(SendError(value)), }) } } /// Extension trait for bounded channel sends that can reserve capacity. pub trait ReservationExt { /// Attempts to send immediately, reserving the message when the channel is full. /// /// Returns: /// - `Ok(None)` when the value was sent immediately. /// - `Ok(Some(_))` when the channel was full. Await the reservation and call /// [`Reserved::send`] to deliver the value. /// - `Err(_)` when the receiver has been dropped. #[must_use = "await and send any reservation"] fn send_or_reserve(&self, value: T) -> Result>, SendError> where T: 'static; } impl ReservationExt for mpsc::Sender { fn send_or_reserve(&self, value: T) -> Result>, SendError> where T: 'static, { match self.try_send(value) { Ok(()) => Ok(None), Err(TrySendError::Full(value)) => { Ok(Some(Reservation::new(self.clone().reserve_owned(), value))) } Err(TrySendError::Closed(value)) => Err(SendError(value)), } } } #[cfg(test)] mod tests { use super::*; use commonware_macros::test_async; use std::collections::BTreeMap; #[test] fn test_send_or_reserve_sends_immediately() { let (sender, mut receiver) = mpsc::channel(1); assert!(sender.send_or_reserve(1).unwrap().is_none()); assert_eq!(receiver.try_recv(), Ok(1)); } #[test] fn test_send_or_reserve_closed_returns_value() { let (sender, receiver) = mpsc::channel(1); drop(receiver); match sender.send_or_reserve(1) { Ok(_) => panic!("send should fail"), Err(SendError(value)) => assert_eq!(value, 1), } } #[test_async] async fn test_send_or_reserve_waits_for_capacity() { let (sender, mut receiver) = mpsc::channel(1); sender.try_send(1).unwrap(); let reservation = sender .send_or_reserve(2) .unwrap() .expect("channel should be full"); assert_eq!(receiver.recv().await, Some(1)); reservation.await.unwrap().send(); assert_eq!(receiver.recv().await, Some(2)); } #[test_async] async fn test_send_or_reserve_returns_value_when_closed_while_waiting() { let (sender, receiver) = mpsc::channel(1); sender.try_send(1).unwrap(); let reservation = sender .send_or_reserve(2) .unwrap() .expect("channel should be full"); drop(receiver); match reservation.await { Ok(_) => panic!("reservation should fail"), Err(SendError(value)) => assert_eq!(value, 2), } } #[test_async] async fn test_send_or_reserve_reservations_can_be_stored() { let (sender, mut receiver) = mpsc::channel(1); sender.try_send(0).unwrap(); let mut reservations = Vec::new(); reservations.push( sender .send_or_reserve(1) .unwrap() .expect("channel should be full"), ); let mut reservation_map = BTreeMap::new(); reservation_map.insert( "next", sender .send_or_reserve(2) .unwrap() .expect("channel should be full"), ); assert_eq!(receiver.recv().await, Some(0)); reservations.pop().unwrap().await.unwrap().send(); assert_eq!(receiver.recv().await, Some(1)); reservation_map .remove("next") .unwrap() .await .unwrap() .send(); assert_eq!(receiver.recv().await, Some(2)); } }