use super::location::Location; use bytes::{Buf, BufMut}; use commonware_codec::{varint::UInt, ReadExt}; use core::{ fmt, ops::{Add, AddAssign, Deref, Sub, SubAssign}, }; /// Maximum valid [Position] value: the largest node count (size) an MMR can hold. /// /// An MMR with `2^62` leaves has `2^63 - 1` nodes, so `MAX_POSITION = 2^63 - 1`. /// /// Node indices are 0-based, so valid indices satisfy `pos < MAX_POSITION`. Node counts /// and MMR sizes satisfy `pos <= MAX_POSITION`. pub const MAX_POSITION: Position = Position::new(0x7FFFFFFFFFFFFFFF); // (1 << 63) - 1 /// A [Position] is an index into an MMR's nodes. /// This is in contrast to a [Location], which is an index into an MMR's _leaves_. #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)] pub struct Position(u64); #[cfg(feature = "arbitrary")] impl arbitrary::Arbitrary<'_> for Position { fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result { let value = u.int_in_range(0..=MAX_POSITION.0)?; Ok(Self(value)) } } impl Position { /// Return a new [Position] from a raw `u64`. #[inline] pub const fn new(pos: u64) -> Self { Self(pos) } /// Return the underlying `u64` value. #[inline] pub const fn as_u64(self) -> u64 { self.0 } /// Returns `true` iff this value is within the valid range (`<= MAX_POSITION`). /// This covers both node indices (`< MAX_POSITION`) and node counts (`<= MAX_POSITION`). #[inline] pub const fn is_valid(self) -> bool { self.0 <= MAX_POSITION.0 } /// Return `self + rhs` returning `None` on overflow or if result exceeds [MAX_POSITION]. #[inline] pub const fn checked_add(self, rhs: u64) -> Option { match self.0.checked_add(rhs) { Some(value) => { if value <= MAX_POSITION.0 { Some(Self(value)) } else { None } } None => None, } } /// Return `self - rhs` returning `None` on underflow. #[inline] pub const fn checked_sub(self, rhs: u64) -> Option { match self.0.checked_sub(rhs) { Some(value) => Some(Self(value)), None => None, } } /// Return `self + rhs` saturating at [MAX_POSITION]. #[inline] pub const fn saturating_add(self, rhs: u64) -> Self { let result = self.0.saturating_add(rhs); if result > MAX_POSITION.0 { MAX_POSITION } else { Self(result) } } /// Return `self - rhs` saturating at zero. #[inline] pub const fn saturating_sub(self, rhs: u64) -> Self { Self(self.0.saturating_sub(rhs)) } /// Returns whether this is a valid MMR size. /// /// The implementation verifies that (1) the size won't result in overflow and (2) peaks in the /// MMR of the given size have strictly decreasing height, which is a necessary condition for /// MMR validity. #[inline] pub const fn is_mmr_size(self) -> bool { if self.0 == 0 { return true; } let leading_zeros = self.0.leading_zeros(); if leading_zeros == 0 { // size overflow return false; } let start = u64::MAX >> leading_zeros; let mut two_h = 1 << start.trailing_ones(); let mut node_pos = start.checked_sub(1).expect("start > 0 because size != 0"); while two_h > 1 { if node_pos < self.0 { if two_h == 2 { // If this peak is a leaf yet there are more nodes remaining, then this MMR is // invalid. return node_pos == self.0 - 1; } // move to the right sibling node_pos += two_h - 1; if node_pos < self.0 { // If the right sibling is in the MMR, then it is invalid. return false; } continue; } // descend to the left child two_h >>= 1; node_pos -= two_h; } true } } impl fmt::Display for Position { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Position({})", self.0) } } impl Deref for Position { type Target = u64; fn deref(&self) -> &Self::Target { &self.0 } } impl AsRef for Position { fn as_ref(&self) -> &u64 { &self.0 } } impl From for Position { #[inline] fn from(value: u64) -> Self { Self::new(value) } } impl From for Position { #[inline] fn from(value: usize) -> Self { Self::new(value as u64) } } impl From for u64 { #[inline] fn from(position: Position) -> Self { *position } } /// Try to convert a leaf [Location] to its node [Position]. /// /// Returns [super::Error::LocationOverflow] if `!loc.is_valid()`. /// /// # Examples /// /// ``` /// use commonware_storage::mmr::{Location, Position, MAX_LOCATION}; /// use core::convert::TryFrom; /// /// let loc = Location::new(5); /// let pos = Position::try_from(loc).unwrap(); /// assert_eq!(pos, Position::new(8)); /// /// // MAX_LOCATION converts successfully (it is the leaf count for 2^62 leaves) /// let pos = Position::try_from(MAX_LOCATION).unwrap(); /// assert!(pos.is_valid()); /// ``` impl TryFrom for Position { type Error = super::Error; #[inline] fn try_from(loc: Location) -> Result { if !loc.is_valid() { return Err(super::Error::LocationOverflow(loc)); } // This will never underflow since 2*n >= count_ones(n). let loc_val = *loc; Ok(Self( loc_val .checked_mul(2) .expect("should not overflow for valid leaf index") - loc_val.count_ones() as u64, )) } } /// Add two positions together. /// /// # Panics /// /// Panics if the result overflows. impl Add for Position { type Output = Self; #[inline] fn add(self, rhs: Self) -> Self::Output { Self(self.0 + rhs.0) } } /// Add a position and a `u64`. /// /// # Panics /// /// Panics if the result overflows. impl Add for Position { type Output = Self; #[inline] fn add(self, rhs: u64) -> Self::Output { Self(self.0 + rhs) } } /// Subtract two positions. /// /// # Panics /// /// Panics if the result underflows. impl Sub for Position { type Output = Self; #[inline] fn sub(self, rhs: Self) -> Self::Output { Self(self.0 - rhs.0) } } /// Subtract a `u64` from a position. /// /// # Panics /// /// Panics if the result underflows. impl Sub for Position { type Output = Self; #[inline] fn sub(self, rhs: u64) -> Self::Output { Self(self.0 - rhs) } } impl PartialEq for Position { #[inline] fn eq(&self, other: &u64) -> bool { self.0 == *other } } impl PartialOrd for Position { #[inline] fn partial_cmp(&self, other: &u64) -> Option { self.0.partial_cmp(other) } } // Allow u64 to be compared with Position too impl PartialEq for u64 { #[inline] fn eq(&self, other: &Position) -> bool { *self == other.0 } } impl PartialOrd for u64 { #[inline] fn partial_cmp(&self, other: &Position) -> Option { self.partial_cmp(&other.0) } } /// Add a `u64` to a position. /// /// # Panics /// /// Panics if the result overflows. impl AddAssign for Position { #[inline] fn add_assign(&mut self, rhs: u64) { self.0 += rhs; } } /// Subtract a `u64` from a position. /// /// # Panics /// /// Panics if the result underflows. impl SubAssign for Position { #[inline] fn sub_assign(&mut self, rhs: u64) { self.0 -= rhs; } } // Codec implementations using varint encoding for efficient storage impl commonware_codec::Write for Position { #[inline] fn write(&self, buf: &mut impl BufMut) { UInt(self.0).write(buf); } } impl commonware_codec::EncodeSize for Position { #[inline] fn encode_size(&self) -> usize { UInt(self.0).encode_size() } } impl commonware_codec::Read for Position { type Cfg = (); #[inline] fn read_cfg(buf: &mut impl Buf, _: &()) -> Result { let pos = Self(UInt::read(buf)?.into()); if pos.is_valid() { Ok(pos) } else { Err(commonware_codec::Error::Invalid( "Position", "value exceeds MAX_POSITION", )) } } } #[cfg(test)] mod tests { use super::{Location, Position}; use crate::mmr::{mem::Mmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION}; use commonware_cryptography::Sha256; // Test that the [Position::from] function returns the correct position for leaf locations. #[test] fn test_from_location() { const CASES: &[(Location, Position)] = &[ (Location::new(0), Position::new(0)), (Location::new(1), Position::new(1)), (Location::new(2), Position::new(3)), (Location::new(3), Position::new(4)), (Location::new(4), Position::new(7)), (Location::new(5), Position::new(8)), (Location::new(6), Position::new(10)), (Location::new(7), Position::new(11)), (Location::new(8), Position::new(15)), (Location::new(9), Position::new(16)), (Location::new(10), Position::new(18)), (Location::new(11), Position::new(19)), (Location::new(12), Position::new(22)), (Location::new(13), Position::new(23)), (Location::new(14), Position::new(25)), (Location::new(15), Position::new(26)), ]; for (loc, expected_pos) in CASES { let pos = Position::try_from(*loc).unwrap(); assert_eq!(pos, *expected_pos); } } #[test] fn test_checked_add() { let pos = Position::new(10); assert_eq!(pos.checked_add(5).unwrap(), 15); // Overflow returns None assert!(Position::new(u64::MAX).checked_add(1).is_none()); // Exceeding MAX_POSITION returns None assert!(MAX_POSITION.checked_add(1).is_none()); assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none()); // At MAX_POSITION is OK assert_eq!( Position::new(*MAX_POSITION - 10).checked_add(10).unwrap(), MAX_POSITION ); } #[test] fn test_checked_sub() { let pos = Position::new(10); assert_eq!(pos.checked_sub(5).unwrap(), 5); assert!(pos.checked_sub(11).is_none()); } #[test] fn test_saturating_add() { let pos = Position::new(10); assert_eq!(pos.saturating_add(5), 15); // Saturates at MAX_POSITION, not u64::MAX assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION); assert_eq!(MAX_POSITION.saturating_add(1), MAX_POSITION); assert_eq!(MAX_POSITION.saturating_add(1000), MAX_POSITION); assert_eq!( Position::new(*MAX_POSITION - 5).saturating_add(10), MAX_POSITION ); } #[test] fn test_saturating_sub() { let pos = Position::new(10); assert_eq!(pos.saturating_sub(5), 5); assert_eq!(Position::new(0).saturating_sub(1), 0); } #[test] fn test_display() { let position = Position::new(42); assert_eq!(position.to_string(), "Position(42)"); } #[test] fn test_add() { let pos1 = Position::new(10); let pos2 = Position::new(5); assert_eq!((pos1 + pos2), 15); } #[test] fn test_sub() { let pos1 = Position::new(10); let pos2 = Position::new(3); assert_eq!((pos1 - pos2), 7); } #[test] fn test_comparison_with_u64() { let pos = Position::new(42); // Test equality assert_eq!(pos, 42u64); assert_eq!(42u64, pos); assert_ne!(pos, 43u64); assert_ne!(43u64, pos); // Test ordering assert!(pos < 43u64); assert!(43u64 > pos); assert!(pos > 41u64); assert!(41u64 < pos); assert!(pos <= 42u64); assert!(42u64 >= pos); } #[test] fn test_assignment_with_u64() { let mut pos = Position::new(10); // Test add assignment pos += 5; assert_eq!(pos, 15u64); // Test sub assignment pos -= 3; assert_eq!(pos, 12u64); } #[test] fn test_max_position() { // MAX_POSITION = max MMR size = 2^63 - 1 (for 2^62 leaves). let max_leaves = 1u64 << 62; let max_size = 2 * max_leaves - 1; // 2^63 - 1 assert_eq!(*MAX_POSITION, max_size); assert_eq!(*MAX_POSITION, (1u64 << 63) - 1); assert_eq!(max_size.leading_zeros(), 1); // top bit clear // One more leaf would overflow: size = 2^63, top bit set. let overflow_size = 2 * (max_leaves + 1) - 1; assert_eq!(overflow_size.leading_zeros(), 0); // MAX_LOCATION converts to MAX_POSITION. let pos = Position::try_from(MAX_LOCATION).unwrap(); assert_eq!(pos, MAX_POSITION); assert!(pos.is_valid()); } #[test] fn test_is_mmr_size() { // Build an MMR one node at a time and check that the validity check is correct for all // sizes up to the current size. let mut size_to_check = Position::new(0); let mut hasher = Standard::::new(); let mut mmr = Mmr::new(&mut hasher); let digest = [1u8; 32]; for _i in 0..10000 { while size_to_check != mmr.size() { assert!( !size_to_check.is_mmr_size(), "size_to_check: {} {}", size_to_check, mmr.size() ); size_to_check += 1; } assert!(size_to_check.is_mmr_size()); let changeset = { let mut batch = mmr.new_batch(); batch.add(&mut hasher, &digest); batch.merkleize(&mut hasher).finalize() }; mmr.apply(changeset).unwrap(); size_to_check += 1; } // Test overflow boundaries. assert!(!Position::new(u64::MAX).is_mmr_size()); assert!(Position::new(u64::MAX >> 1).is_mmr_size()); // 2^63 - 1 = MAX_POSITION assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size()); assert!(MAX_POSITION.is_mmr_size()); // MAX_POSITION is the largest valid MMR size } #[test] fn test_read_cfg_valid_values() { use commonware_codec::{Encode, ReadExt}; // Test zero let pos = Position::new(0); let encoded = pos.encode(); let decoded = Position::read(&mut encoded.as_ref()).unwrap(); assert_eq!(decoded, pos); // Test middle value let pos = Position::new(12345); let encoded = pos.encode(); let decoded = Position::read(&mut encoded.as_ref()).unwrap(); assert_eq!(decoded, pos); // Test MAX_POSITION (boundary) let pos = MAX_POSITION; let encoded = pos.encode(); let decoded = Position::read(&mut encoded.as_ref()).unwrap(); assert_eq!(decoded, pos); } #[test] fn test_read_cfg_invalid_values() { use commonware_codec::{varint::UInt, Encode, ReadExt}; // Encode MAX_POSITION + 1 as a raw varint, then try to decode as Position let invalid_value = *MAX_POSITION + 1; let encoded = UInt(invalid_value).encode(); let result = Position::read(&mut encoded.as_ref()); assert!(result.is_err()); assert!(matches!( result, Err(commonware_codec::Error::Invalid("Position", _)) )); // Encode u64::MAX as a raw varint let encoded = UInt(u64::MAX).encode(); let result = Position::read(&mut encoded.as_ref()); assert!(result.is_err()); assert!(matches!( result, Err(commonware_codec::Error::Invalid("Position", _)) )); } }