//! Utility functions for interacting with any runtime. #[cfg(test)] use crate::Runner; use crate::{Metrics, Spawner}; #[cfg(test)] use futures::stream::{FuturesUnordered, StreamExt}; use rayon::{ThreadPool as RThreadPool, ThreadPoolBuildError, ThreadPoolBuilder}; use std::{ any::Any, future::Future, pin::Pin, sync::Arc, task::{Context, Poll}, }; pub mod buffer; pub mod signal; mod handle; pub use handle::Handle; /// Yield control back to the runtime. pub async fn reschedule() { struct Reschedule { yielded: bool, } impl Future for Reschedule { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { if self.yielded { Poll::Ready(()) } else { self.yielded = true; cx.waker().wake_by_ref(); Poll::Pending } } } Reschedule { yielded: false }.await } fn extract_panic_message(err: &(dyn Any + Send)) -> String { if let Some(s) = err.downcast_ref::<&str>() { s.to_string() } else if let Some(s) = err.downcast_ref::() { s.clone() } else { format!("{err:?}") } } /// A clone-able wrapper around a [rayon]-compatible thread pool. pub type ThreadPool = Arc; /// Creates a clone-able [rayon]-compatible thread pool with [Spawner::spawn_blocking]. /// /// # Arguments /// - `context`: The runtime context implementing the [Spawner] trait. /// - `concurrency`: The number of tasks to execute concurrently in the pool. /// /// # Returns /// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built. pub fn create_pool( context: S, concurrency: usize, ) -> Result { let pool = ThreadPoolBuilder::new() .num_threads(concurrency) .spawn_handler(move |thread| { // Tasks spawned in a thread pool are expected to run longer than any single // task and thus should be provisioned as a dedicated thread. context .with_label("rayon-thread") .spawn_blocking(true, move |_| thread.run()); Ok(()) }) .build()?; Ok(Arc::new(pool)) } /// Async reader–writer lock. /// /// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition /// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies). /// /// Usage: /// ```rust /// use commonware_runtime::{Spawner, Runner, deterministic, RwLock}; /// /// let executor = deterministic::Runner::default(); /// executor.start(|context| async move { /// // Create a new RwLock /// let lock = RwLock::new(2); /// /// // many concurrent readers /// let r1 = lock.read().await; /// let r2 = lock.read().await; /// assert_eq!(*r1 + *r2, 4); /// /// // exclusive writer /// drop((r1, r2)); /// let mut w = lock.write().await; /// *w += 1; /// }); /// ``` pub struct RwLock(async_lock::RwLock); /// Shared guard returned by [RwLock::read]. pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>; /// Exclusive guard returned by [RwLock::write]. pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>; impl RwLock { /// Create a new lock. #[inline] pub const fn new(value: T) -> Self { Self(async_lock::RwLock::new(value)) } /// Acquire a shared read guard. #[inline] pub async fn read(&self) -> RwLockReadGuard<'_, T> { self.0.read().await } /// Acquire an exclusive write guard. #[inline] pub async fn write(&self) -> RwLockWriteGuard<'_, T> { self.0.write().await } /// Try to get a read guard without waiting. #[inline] pub fn try_read(&self) -> Option> { self.0.try_read() } /// Try to get a write guard without waiting. #[inline] pub fn try_write(&self) -> Option> { self.0.try_write() } /// Get mutable access without locking (requires `&mut self`). #[inline] pub fn get_mut(&mut self) -> &mut T { self.0.get_mut() } /// Consume the lock, returning the inner value. #[inline] pub fn into_inner(self) -> T { self.0.into_inner() } } #[cfg(test)] async fn task(i: usize) -> usize { for _ in 0..5 { reschedule().await; } i } #[cfg(test)] pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec) { runner.start(|context| async move { // Randomly schedule tasks let mut handles = FuturesUnordered::new(); for i in 0..=tasks - 1 { handles.push(context.clone().spawn(move |_| task(i))); } // Collect output order let mut outputs = Vec::new(); while let Some(result) = handles.next().await { outputs.push(result.unwrap()); } assert_eq!(outputs.len(), tasks); (context.auditor().state(), outputs) }) } #[cfg(test)] mod tests { use super::*; use crate::{deterministic, tokio, Metrics}; use commonware_macros::test_traced; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; #[test_traced] fn test_create_pool() { let executor = tokio::Runner::default(); executor.start(|context| async move { // Create a thread pool with 4 threads let pool = create_pool(context.with_label("pool"), 4).unwrap(); // Create a vector of numbers let v: Vec<_> = (0..10000).collect(); // Use the thread pool to sum the numbers pool.install(|| { assert_eq!(v.par_iter().sum::(), 10000 * 9999 / 2); }); }); } #[test_traced] fn test_rwlock() { let executor = deterministic::Runner::default(); executor.start(|_| async move { // Create a new RwLock let lock = RwLock::new(100); // many concurrent readers let r1 = lock.read().await; let r2 = lock.read().await; assert_eq!(*r1 + *r2, 200); // exclusive writer drop((r1, r2)); // all readers must go away let mut w = lock.write().await; *w += 1; // Check the value assert_eq!(*w, 101); }); } }