whale/
revision.rs

1//! Revision and Durability types for the Whale incremental computation system.
2//!
3//! This module provides the foundational types for tracking revisions
4//! at different durability levels, following the Lean4 formal specification.
5
6use std::sync::atomic::{AtomicU64, Ordering};
7
/// Value of one per-durability-level revision counter.
///
/// A level's counter is advanced by [`AtomicRevision::increment`] each time a
/// change is recorded at that level or at any more stable level.
pub type RevisionCounter = u64;
10
/// Durability level: 0 = most volatile, N-1 = most stable.
///
/// A node's durability determines which revision counter(s) are used to track
/// its validity. Lower durability levels change more frequently.
///
/// The wrapped level is always `< N`. Outside this module the field is
/// private, so a value can only be obtained from [`Durability::new`] (which
/// checks the bound), [`Durability::volatile`], [`Durability::stable`], or
/// deserialization when the `serde` feature is enabled (which performs the same
/// check). Code in this module relies on that bound when it uses the level to
/// index `N`-element counter arrays.
///
/// # Invariant
/// A node's durability must not exceed the minimum durability of its dependencies.
/// This ensures that a node doesn't promise to be more stable than its sources.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct Durability<const N: usize>(usize);

impl<const N: usize> Durability<N> {
    /// Compile-time guard evaluated by `volatile()` and `stable()`.
    ///
    /// With `N == 0` there is no valid level at all (and `N - 1` would
    /// underflow), so instantiating either constructor for `N == 0` fails to
    /// compile instead of producing an out-of-range value.
    const HAS_LEVELS: () = assert!(N > 0, "Durability<N> requires at least one level (N > 0)");

    /// Create a new durability level.
    ///
    /// Returns `None` if `level >= N`.
    #[inline]
    pub const fn new(level: usize) -> Option<Self> {
        if level < N {
            Some(Self(level))
        } else {
            None
        }
    }

    /// Get the numeric value of this durability level.
    #[inline]
    pub const fn value(&self) -> usize {
        self.0
    }

    /// Most volatile durability level (0).
    ///
    /// Fails to compile when instantiated with `N == 0`.
    #[inline]
    pub const fn volatile() -> Self {
        let () = Self::HAS_LEVELS;
        Self(0)
    }

    /// Most stable durability level (N-1).
    ///
    /// Fails to compile when instantiated with `N == 0`.
    #[inline]
    pub const fn stable() -> Self {
        let () = Self::HAS_LEVELS;
        Self(N - 1)
    }

    /// Get the minimum (more volatile) of two durability levels.
    ///
    /// Per the invariant above, this is the most a node depending on both may
    /// claim. Declared `const` so it can be used in constant expressions.
    #[inline]
    pub const fn min(self, other: Self) -> Self {
        if self.0 <= other.0 {
            self
        } else {
            other
        }
    }
}

// Deserialization is written by hand rather than derived so the stored level is
// validated against `N`. A derived impl would accept any `usize`, yielding a
// value that panics when later used to index a counter array.
#[cfg(feature = "serde")]
impl<'de, const N: usize> serde::Deserialize<'de> for Durability<N> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Same shape `derive(Deserialize)` produces for this tuple struct:
        // a newtype struct named "Durability" wrapping a `usize`.
        #[derive(serde::Deserialize)]
        #[serde(rename = "Durability")]
        struct Raw(usize);

        let Raw(level) = Raw::deserialize(deserializer)?;
        Self::new(level).ok_or_else(|| {
            serde::de::Error::custom(format!(
                "durability level {} out of range (expected < {})",
                level, N
            ))
        })
    }
}
64
65/// Revision snapshot: array of counters indexed by durability level.
66///
67/// `revision[d]` is the counter for durability level `d`.
68/// When a node at durability D changes, we increment `revision[0..=D]`.
69#[derive(Debug, Clone, PartialEq, Eq)]
70pub struct Revision<const N: usize> {
71    counters: [RevisionCounter; N],
72}
73
74// Manual serde implementation to handle const generic arrays
75#[cfg(feature = "serde")]
76impl<const N: usize> serde::Serialize for Revision<N> {
77    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
78    where
79        S: serde::Serializer,
80    {
81        self.counters.as_slice().serialize(serializer)
82    }
83}
84
85#[cfg(feature = "serde")]
86impl<'de, const N: usize> serde::Deserialize<'de> for Revision<N> {
87    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
88    where
89        D: serde::Deserializer<'de>,
90    {
91        let vec: Vec<RevisionCounter> = Vec::deserialize(deserializer)?;
92        if vec.len() != N {
93            return Err(serde::de::Error::custom(format!(
94                "expected {} elements, got {}",
95                N,
96                vec.len()
97            )));
98        }
99        let mut counters = [0; N];
100        counters.copy_from_slice(&vec);
101        Ok(Self { counters })
102    }
103}
104
105impl<const N: usize> Default for Revision<N> {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl<const N: usize> Revision<N> {
112    /// Create a new revision with all counters at 0.
113    #[inline]
114    pub const fn new() -> Self {
115        Self { counters: [0; N] }
116    }
117
118    /// Get the revision counter for durability level `d`.
119    #[inline]
120    pub fn get(&self, d: Durability<N>) -> RevisionCounter {
121        self.counters[d.0]
122    }
123
124    /// Set the revision counter for durability level `d`.
125    #[inline]
126    pub fn set(&mut self, d: Durability<N>, value: RevisionCounter) {
127        self.counters[d.0] = value;
128    }
129
130    /// Get mutable reference to the underlying counters array.
131    #[inline]
132    pub fn counters_mut(&mut self) -> &mut [RevisionCounter; N] {
133        &mut self.counters
134    }
135}
136
137/// Atomic revision counters for concurrent access.
138///
139/// This structure holds per-durability-level revision counters that can be
140/// atomically incremented. Per the specification, when incrementing at level D,
141/// we must also increment all levels 0..=D.
142pub struct AtomicRevision<const N: usize> {
143    counters: [AtomicU64; N],
144}
145
146impl<const N: usize> Default for AtomicRevision<N> {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152impl<const N: usize> AtomicRevision<N> {
153    /// Create new atomic revision with all counters at 0.
154    pub fn new() -> Self {
155        Self {
156            counters: std::array::from_fn(|_| AtomicU64::new(0)),
157        }
158    }
159
160    /// Get current revision counter for durability level `d`.
161    #[inline]
162    pub fn get(&self, d: Durability<N>) -> RevisionCounter {
163        self.counters[d.0].load(Ordering::Acquire)
164    }
165
166    /// Increment revision at durability level `d` and all lower levels (0..=d).
167    ///
168    /// Per specification: "for i in 0..=D: revision[i].fetch_add(1, AcqRel)"
169    ///
170    /// Returns the new revision counter at level `d`.
171    pub fn increment(&self, d: Durability<N>) -> RevisionCounter {
172        let mut new_rev = 0;
173        for i in 0..=d.0 {
174            new_rev = self.counters[i].fetch_add(1, Ordering::AcqRel) + 1;
175        }
176        new_rev
177    }
178
179    /// Take a snapshot of the current revision state.
180    pub fn snapshot(&self) -> Revision<N> {
181        let mut counters = [0; N];
182        for (dst, src) in counters.iter_mut().zip(self.counters.iter()) {
183            *dst = src.load(Ordering::Acquire);
184        }
185        Revision { counters }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Checked durability level for the 3-level system used throughout these tests.
    fn dur(level: usize) -> Durability<3> {
        Durability::new(level).expect("test levels are always < 3")
    }

    /// Current counters of `atomic`, ordered from level 0 to level 2.
    fn levels(atomic: &AtomicRevision<3>) -> [RevisionCounter; 3] {
        [atomic.get(dur(0)), atomic.get(dur(1)), atomic.get(dur(2))]
    }

    #[test]
    fn test_durability_creation() {
        assert_eq!(Durability::<3>::new(0).map(|d| d.value()), Some(0));
        assert_eq!(Durability::<3>::new(2).map(|d| d.value()), Some(2));

        // The upper bound is exclusive: level N and beyond are rejected.
        assert_eq!(Durability::<3>::new(3), None);
        assert_eq!(Durability::<3>::new(usize::MAX), None);
    }

    #[test]
    fn test_durability_volatile_stable() {
        assert_eq!(Durability::<3>::volatile().value(), 0);
        assert_eq!(Durability::<3>::stable().value(), 2);

        // With a single level, the most volatile and most stable levels coincide.
        assert_eq!(Durability::<1>::volatile(), Durability::<1>::stable());
    }

    #[test]
    fn test_durability_min() {
        let (d0, d1, d2) = (dur(0), dur(1), dur(2));

        assert_eq!(d0.min(d1), d0);
        assert_eq!(d1.min(d0), d0);
        assert_eq!(d1.min(d2), d1);
        assert_eq!(d2.min(d2), d2);
    }

    #[test]
    fn test_revision_operations() {
        let mut rev: Revision<3> = Revision::new();
        assert_eq!(rev, Revision::default());
        assert_eq!(rev.get(dur(1)), 0);

        rev.set(dur(1), 42);
        assert_eq!(rev.get(dur(1)), 42);
        // Setting one level must not touch the others.
        assert_eq!(rev.get(dur(0)), 0);
        assert_eq!(rev.get(dur(2)), 0);

        rev.counters_mut()[2] = 7;
        assert_eq!(rev.get(dur(2)), 7);
    }

    #[test]
    fn test_atomic_revision_increment() {
        let atomic: AtomicRevision<3> = AtomicRevision::new();

        // Increment at level 0 - only level 0 should increase.
        assert_eq!(atomic.increment(dur(0)), 1);
        assert_eq!(levels(&atomic), [1, 0, 0]);

        // Increment at level 2 - levels 0, 1, 2 should all increase; the
        // return value is the new counter at level 2.
        assert_eq!(atomic.increment(dur(2)), 1);
        assert_eq!(levels(&atomic), [2, 1, 1]);

        // Increment at level 1 returns level 1's new counter, not level 0's.
        assert_eq!(atomic.increment(dur(1)), 2);
        assert_eq!(levels(&atomic), [3, 2, 1]);
    }

    #[test]
    fn test_atomic_revision_snapshot() {
        let atomic: AtomicRevision<3> = AtomicRevision::default();
        atomic.increment(dur(1));

        let snapshot = atomic.snapshot();
        assert_eq!(snapshot.get(dur(0)), 1);
        assert_eq!(snapshot.get(dur(1)), 1);
        assert_eq!(snapshot.get(dur(2)), 0);

        // A snapshot is a copy: later increments do not change it.
        atomic.increment(dur(2));
        assert_eq!(snapshot.get(dur(2)), 0);
        assert_eq!(atomic.snapshot().get(dur(2)), 1);
    }

    #[test]
    fn test_atomic_revision_concurrent_increments() {
        let atomic: AtomicRevision<3> = AtomicRevision::new();

        std::thread::scope(|scope| {
            for _ in 0..4 {
                scope.spawn(|| {
                    for _ in 0..100 {
                        atomic.increment(dur(2));
                    }
                });
            }
        });

        // No increment is lost: every level saw all 4 * 100 bumps.
        assert_eq!(levels(&atomic), [400, 400, 400]);
    }
}