use crate::encrypted::Error;
use commonware_codec::{
    varint::{Decoder, UInt, MAX_U32_VARINT_SIZE},
    Encode, EncodeSize, Write,
};
use commonware_runtime::{Buf, IoBuf, IoBufMut, IoBufs, Sink, Stream};

/// Validates the frame size and assembles the frame via the caller's closure.
///
/// The `assemble` closure receives the varint prefix and must combine it with
/// the payload. This allows callers to choose between:
/// - Chunked: prepend the prefix as a separate buffer
/// - Contiguous: write the prefix directly into a pre-allocated buffer
///
/// # Errors
///
/// Returns [`Error::SendTooLarge`] if `payload_len` exceeds `max_message_size`
/// (in which case `assemble` is never called); otherwise returns whatever
/// `assemble` returns.
pub(crate) fn build_frame<T>(
    payload_len: usize,
    max_message_size: u32,
    assemble: impl FnOnce(UInt<u32>) -> Result<T, Error>,
) -> Result<T, Error> {
    if payload_len > max_message_size as usize {
        return Err(Error::SendTooLarge(payload_len));
    }
    // Lossless: the check above bounds `payload_len` by a `u32` value.
    let prefix = UInt(payload_len as u32);
    assemble(prefix)
}

/// Returns the total size (varint prefix plus payload) of a length-prefixed frame.
///
/// # Errors
///
/// Returns [`Error::SendTooLarge`] if `payload_len` exceeds `max_message_size`.
pub(crate) fn framed_len(payload_len: usize, max_message_size: u32) -> Result<usize, Error> {
    build_frame(payload_len, max_message_size, |prefix| {
        Ok(prefix.encode_size() + payload_len)
    })
}

/// Appends one length-prefixed frame to a contiguous output buffer.
///
/// The varint prefix is written first. `append_payload` is then called with
/// the buffer and the offset at which the payload begins, and must append
/// exactly `payload_len` bytes. The offset is useful when callers need to
/// operate on the payload bytes after copying them.
///
/// Returns the number of bytes appended (prefix plus payload).
///
/// # Errors
///
/// Returns [`Error::SendTooLarge`] if `payload_len` exceeds `max_message_size`
/// (nothing is written to `frame` in that case), or propagates any error
/// returned by `append_payload`. On a callback error, the prefix and whatever
/// the callback wrote remain in `frame`.
///
/// # Panics
///
/// Panics if `append_payload` does not leave exactly `payload_len` bytes after
/// the prefix.
pub(crate) fn append_frame(
    frame: &mut IoBufMut,
    payload_len: usize,
    max_message_size: u32,
    append_payload: impl FnOnce(&mut IoBufMut, usize) -> Result<(), Error>,
) -> Result<usize, Error> {
    build_frame(payload_len, max_message_size, |prefix| {
        let start = frame.len();
        prefix.write(frame);
        let payload_offset = frame.len();
        append_payload(frame, payload_offset)?;
        // Compare absolute lengths instead of subtracting: a callback that
        // shrinks the buffer must trip this assertion, not underflow.
        assert_eq!(
            frame.len(),
            payload_offset + payload_len,
            "append_payload must append exactly payload_len bytes"
        );
        Ok(frame.len() - start)
    })
}

/// Sends data to the sink with a varint length prefix.
///
/// The varint length prefix is prepended to the buffer(s), which results in a
/// chunked `IoBufs`.
/// /// Returns an error if the message is too large or the sink is closed. pub async fn send_frame( sink: &mut S, bufs: impl Into + Send, max_message_size: u32, ) -> Result<(), Error> { let mut bufs = bufs.into(); let frame = build_frame(bufs.len(), max_message_size, |prefix| { bufs.prepend(IoBuf::from(prefix.encode())); Ok(bufs) })?; sink.send(frame).await.map_err(Error::SendFailed) } /// Receives data from the stream with a varint length prefix. /// Returns an error if the message is too large, the varint is invalid, or the /// stream is closed. pub async fn recv_frame(stream: &mut T, max_message_size: u32) -> Result { let (len, skip) = recv_length(stream).await?; if len > max_message_size as usize { return Err(Error::RecvTooLarge(len)); } stream .recv(skip + len) .await .map(|mut bufs| { bufs.advance(skip); bufs }) .map_err(Error::RecvFailed) } /// Receives and decodes the varint length prefix from the stream. /// Returns (payload_len, bytes_to_skip) where bytes_to_skip is: /// - varint_len if decoded from peek buffer (bytes not yet consumed) /// - 0 if decoded via recv (bytes already consumed) async fn recv_length(stream: &mut T) -> Result<(usize, usize), Error> { let mut decoder = Decoder::::new(); // Fast path: decode from peek buffer without blocking let peeked = { let peeked = stream.peek(MAX_U32_VARINT_SIZE); for (i, byte) in peeked.iter().enumerate() { match decoder.feed(*byte) { Ok(Some(len)) => return Ok((len as usize, i + 1)), Ok(None) => continue, Err(_) => return Err(Error::InvalidVarint), } } peeked.len() }; // Slow path: fetch bytes one at a time (skipping already-decoded peek bytes) let mut buf = stream.recv(peeked + 1).await.map_err(Error::RecvFailed)?; buf.advance(peeked); loop { match decoder.feed(buf.get_u8()) { Ok(Some(len)) => return Ok((len as usize, 0)), Ok(None) => {} Err(_) => return Err(Error::InvalidVarint), } buf = stream.recv(1).await.map_err(Error::RecvFailed)?; } } #[cfg(test)] mod tests { use super::*; use 
commonware_runtime::{deterministic, mocks, BufMut, IoBufMut, Runner}; use rand::Rng; const MAX_MESSAGE_SIZE: u32 = 1024; #[test] fn test_send_recv_at_max_message_size() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|mut context| async move { let mut buf = [0u8; MAX_MESSAGE_SIZE as usize]; context.fill(&mut buf); let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.len(), buf.len()); assert_eq!(data.coalesce(), buf); }); } #[test] fn test_send_recv_multiple() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|mut context| async move { let mut buf1 = [0u8; MAX_MESSAGE_SIZE as usize]; let mut buf2 = [0u8; (MAX_MESSAGE_SIZE as usize) / 2]; context.fill(&mut buf1); context.fill(&mut buf2); // Send two messages of different sizes let result = send_frame(&mut sink, buf1.to_vec(), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); let result = send_frame(&mut sink, buf2.to_vec(), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); // Read both messages in order let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.len(), buf1.len()); assert_eq!(data.coalesce(), buf1); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.len(), buf2.len()); assert_eq!(data.coalesce(), buf2); }); } #[test] fn test_send_frame() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|mut context| async move { let mut buf = [0u8; MAX_MESSAGE_SIZE as usize]; context.fill(&mut buf); let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE).await; assert!(result.is_ok()); // Do the reading manually without using recv_frame // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] (2 bytes) let read 
= stream.recv(2).await.unwrap(); assert_eq!(read.coalesce(), &[0x80, 0x08]); // 1024 as varint let read = stream.recv(MAX_MESSAGE_SIZE as usize).await.unwrap(); assert_eq!(read.coalesce(), buf); }); } #[test] fn test_build_frame_closure_error() { let result: Result = build_frame(10, MAX_MESSAGE_SIZE, |_prefix| { Err(Error::HandshakeError( commonware_cryptography::handshake::Error::EncryptionFailed, )) }); assert!(matches!(&result, Err(Error::HandshakeError(_)))); } #[test] fn test_build_frame_too_large() { let result: Result = build_frame( MAX_MESSAGE_SIZE as usize + 1, MAX_MESSAGE_SIZE, |_prefix| unreachable!(), ); assert!( matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize + 1) ); } #[test] fn test_send_frame_too_large() { let (mut sink, _) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|mut context| async move { let mut buf = [0u8; MAX_MESSAGE_SIZE as usize]; context.fill(&mut buf); let result = send_frame(&mut sink, buf.to_vec(), MAX_MESSAGE_SIZE - 1).await; assert!( matches!(&result, Err(Error::SendTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize) ); }); } #[test] fn test_read_frame() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|mut context| async move { // Do the writing manually without using send_frame let mut msg = [0u8; MAX_MESSAGE_SIZE as usize]; context.fill(&mut msg); // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] let mut buf = IoBufMut::with_capacity(2 + msg.len()); buf.put_u8(0x80); buf.put_u8(0x08); buf.put_slice(&msg); sink.send(buf.freeze()).await.unwrap(); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.len(), MAX_MESSAGE_SIZE as usize); assert_eq!(data.coalesce(), msg); }); } #[test] fn test_read_frame_too_large() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|_| async move { // 
Manually insert a frame that gives MAX_MESSAGE_SIZE as the size // 1024 (MAX_MESSAGE_SIZE) encodes as varint: [0x80, 0x08] let mut buf = IoBufMut::with_capacity(2); buf.put_u8(0x80); buf.put_u8(0x08); sink.send(buf.freeze()).await.unwrap(); let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE - 1).await; assert!( matches!(&result, Err(Error::RecvTooLarge(n)) if *n == MAX_MESSAGE_SIZE as usize) ); }); } #[test] fn test_recv_frame_incomplete_varint() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|_| async move { // Send incomplete varint (continuation bit set but no following byte) let mut buf = IoBufMut::with_capacity(1); buf.put_u8(0x80); // Continuation bit set, expects more bytes sink.send(buf.freeze()).await.unwrap(); drop(sink); // Close the sink to simulate a closed stream // Expect an error because varint is incomplete let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await; assert!(matches!(&result, Err(Error::RecvFailed(_)))); }); } #[test] fn test_recv_frame_invalid_varint_overflow() { let (mut sink, mut stream) = mocks::Channel::init(); let executor = deterministic::Runner::default(); executor.start(|_| async move { // Send a varint that overflows u32 (more than 5 bytes with continuation bits) let mut buf = IoBufMut::with_capacity(6); buf.put_u8(0xFF); // 7 bits + continue buf.put_u8(0xFF); // 7 bits + continue buf.put_u8(0xFF); // 7 bits + continue buf.put_u8(0xFF); // 7 bits + continue buf.put_u8(0xFF); // 5th byte with overflow bits set + continue buf.put_u8(0x01); // 6th byte sink.send(buf.freeze()).await.unwrap(); // Expect an error because varint overflows u32 let result = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await; assert!(matches!(&result, Err(Error::InvalidVarint))); }); } #[test] fn test_recv_frame_peek_paths() { let executor = deterministic::Runner::default(); executor.start(|mut context| async move { // 300 encodes as [0xAC, 0x02] (2-byte varint) let mut 
payload = vec![0u8; 300]; context.fill(&mut payload[..]); // Fast path: peek returns complete varint let (mut sink, mut stream) = mocks::Channel::init(); send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE) .await .unwrap(); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.coalesce(), &payload[..]); // Slow path: peek returns empty let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(0); send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE) .await .unwrap(); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.coalesce(), &payload[..]); // Slow path: peek returns partial varint let (mut sink, mut stream) = mocks::Channel::init_with_read_buffer_size(1); send_frame(&mut sink, payload.clone(), MAX_MESSAGE_SIZE) .await .unwrap(); let data = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await.unwrap(); assert_eq!(data.coalesce(), &payload[..]); }); } }