spnr_lib/
storage.rs

1//! Utilities to help manage canister stable memory.
2//!
//! [ic-cdk] provides only a low-level interface to canister stable memory, with no memory manager or allocator.
4//!
5//! To help with easy access:
6//!
7//! * [StableStorage] supports sequential read & write of bytes by implementing [io::Read] and [io::Write] traits.
8//!
9//! * [StorageStack] provides a stack interface allowing arbitrary values to be pushed onto and popped off the stable memory.
10//!   [StableStorage] implements this trait.
11//!   Being a trait it allows alternative implementations, for example in testing code.
12//!
13//! [ic-cdk]: https://docs.rs/ic-cdk/latest
14use ic_cdk::stable;
15use std::{error, fmt, io};
16
17/// Possible errors when dealing with stable memory.
18#[derive(Debug)]
19pub enum StorageError {
20    /// No more stable memory could be allocated.
21    OutOfMemory,
22    /// Attempted to read more stable memory than had been allocated.
23    OutOfBounds,
24    /// Candid encoding error.
25    Candid(candid::Error),
26}
27
28impl From<candid::Error> for StorageError {
29    fn from(err: candid::Error) -> StorageError {
30        StorageError::Candid(err)
31    }
32}
33
34impl From<StorageError> for io::Error {
35    fn from(err: StorageError) -> io::Error {
36        match err {
37            StorageError::Candid(err) => io::Error::new(io::ErrorKind::Other, err),
38            err => io::Error::new(io::ErrorKind::OutOfMemory, err),
39        }
40    }
41}
42
43impl fmt::Display for StorageError {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        match self {
46            Self::OutOfMemory => f.write_str("Out of memory"),
47            Self::OutOfBounds => f.write_str("Read exceeds allocated memory"),
48            Self::Candid(err) => write!(f, "{}", err),
49        }
50    }
51}
52
53impl error::Error for StorageError {}
54
/// A byte offset (address) into canister stable memory.
pub type Offset = u64;

/// Reader/Writer of the canister stable memory.
///
/// It keeps track of the current read/write offset. When a write goes past the end of the
/// allocated pages, it attempts to grow stable memory.
#[derive(Debug)]
pub struct StableStorage {
    /// Byte offset of the next read or write.
    pub offset: Offset,
    /// Size of stable memory in 64 KiB pages, as last observed or grown by this instance.
    capacity: u64,
}
67
68/// The default instance of `StableStorage` starts at offset 0 if the current stable memory capacity is 0 (which means it is never used).
69///
70/// Otherwise it reads the offset value from the last 8 bytes (in little endian) of the stable memory.
71impl Default for StableStorage {
72    fn default() -> Self {
73        let mut storage = Self {
74            offset: 0,
75            capacity: stable::stable_size(),
76        };
77        if storage.capacity > 0 {
78            let cap = storage.capacity << 16;
79            let mut bytes = [0; 8];
80            stable::stable_read(cap - 8, &mut bytes);
81            storage.offset = u64::from_le_bytes(bytes);
82        }
83        storage
84    }
85}
86
87impl StableStorage {
88    /// Attempt to grow the memory by adding new pages.
89    fn grow(&mut self, added_pages: u64) -> Result<(), StorageError> {
90        let old_page_count =
91            stable::stable_grow(added_pages).map_err(|_| StorageError::OutOfMemory)?;
92        self.capacity = old_page_count + added_pages;
93        Ok(())
94    }
95
96    /// Create a new instance of [StableStorage].
97    pub fn new() -> Self {
98        Default::default()
99    }
100
101    /// Write current offset value to the last 8 bytes (in little-endian) of the stable memory.
102    /// This is an important step if you plan to later resume by reconstructing `StableStorage` from the stable memory.
103    pub fn finalize(mut self) -> Result<(), io::Error> {
104        let mut cap = self.capacity << 16;
105        if self.offset + 8 > cap {
106            self.grow(1)?;
107            cap = self.capacity << 16;
108        }
109        let bytes = self.offset.to_le_bytes();
110        io::Write::write(&mut self, &bytes)?;
111        stable::stable_write(cap - 8, &bytes);
112        Ok(())
113    }
114}
115
116impl io::Write for StableStorage {
117    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
118        if self.offset + buf.len() as u64 > (self.capacity << 16) {
119            self.grow((buf.len() >> 16) as u64 + 1)?
120        }
121
122        stable::stable_write(self.offset, buf);
123        self.offset += buf.len() as u64;
124        Ok(buf.len())
125    }
126
127    fn flush(&mut self) -> Result<(), io::Error> {
128        Ok(())
129    }
130}
131
132/// Stack interface for stable memory that supports push and pop of arbitrary values that implement the [Candid] interface.
133///
134/// [Candid]: https://docs.rs/candid/latest
135pub trait StorageStack {
136    /// Return a new [StorageStack] object with the given offset.
137    fn new_with(&self, offset: Offset) -> Self;
138
139    /// Return the current read/write offset.
140    fn offset(&self) -> Offset;
141
142    /// Push a value to the end of the stack.
143    fn push<T>(&mut self, t: T) -> Result<(), io::Error>
144    where
145        T: candid::utils::ArgumentEncoder;
146
147    /// Pop a value from the end of the stack.
148    /// In case of `OutOfBounds` error, offset is not changed.
149    /// In case of Candid decoding error, offset may be changed.
150    fn pop<T>(&mut self) -> Result<T, io::Error>
151    where
152        T: for<'de> candid::utils::ArgumentDecoder<'de>;
153
154    /// Seek to the start of previous value by changing the offset.
155    /// This is similar to `pop` but without reading the actual value.
156    fn seek_prev(&mut self) -> Result<(), io::Error>;
157}
158
159impl StorageStack for StableStorage {
160    fn new_with(&self, offset: Offset) -> Self {
161        Self {
162            offset,
163            capacity: self.capacity,
164        }
165    }
166
167    fn offset(&self) -> Offset {
168        self.offset
169    }
170
171    fn push<T>(&mut self, t: T) -> Result<(), io::Error>
172    where
173        T: candid::utils::ArgumentEncoder,
174    {
175        let prev_offset = self.offset;
176        candid::write_args(self, t).map_err(StorageError::from)?;
177        let bytes = prev_offset.to_le_bytes();
178        io::Write::write(self, &bytes)?;
179        Ok(())
180    }
181
182    fn pop<T>(&mut self) -> Result<T, io::Error>
183    where
184        T: for<'de> candid::utils::ArgumentDecoder<'de>,
185    {
186        let end = self.offset - 8;
187        self.seek_prev()?;
188        let size = (end - self.offset) as usize;
189        let mut bytes = vec![0; size];
190        stable::stable_read(self.offset, &mut bytes);
191        let mut de = candid::de::IDLDeserialize::new(&bytes).map_err(StorageError::Candid)?;
192        let res = candid::utils::ArgumentDecoder::decode(&mut de).map_err(StorageError::Candid)?;
193        Ok(res)
194    }
195
196    fn seek_prev(&mut self) -> Result<(), io::Error> {
197        if self.offset < 8 {
198            return Err(StorageError::OutOfBounds.into());
199        }
200        let mut bytes = [0; 8];
201        let end = self.offset - 8;
202        stable::stable_read(end, &mut bytes);
203        let start = u64::from_le_bytes(bytes);
204        if start > end {
205            return Err(StorageError::OutOfBounds.into());
206        }
207        self.offset = start;
208        Ok(())
209    }
210}
211
212impl io::Read for StableStorage {
213    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
214        let cap = self.capacity << 16;
215        let read_buf = if buf.len() as u64 + self.offset > cap {
216            if self.offset < cap {
217                &mut buf[..(cap - self.offset) as usize]
218            } else {
219                return Err(StorageError::OutOfBounds.into());
220            }
221        } else {
222            buf
223        };
224        stable::stable_read(self.offset, read_buf);
225        self.offset += read_buf.len() as u64;
226        Ok(read_buf.len())
227    }
228}
229
230pub mod test {
231    use super::*;
232    use candid::encode_args;
233    use std::cell::RefCell;
234    use std::io;
235    use std::rc::Rc;
236
237    /// A vector-based implementation of [StorageStack], used for testing purpose.
238    #[derive(Clone, Default)]
239    pub struct Stack {
240        stack: Rc<RefCell<Vec<Vec<u8>>>>,
241        offset: Offset,
242        index: usize,
243    }
244
245    impl StorageStack for Stack {
246        fn new_with(&self, offset: Offset) -> Stack {
247            let mut s = 0;
248            let mut index = 0;
249            while s < offset {
250                s += self.stack.as_ref().borrow()[index].len() as Offset;
251                index += 1;
252            }
253            Stack {
254                stack: Rc::clone(&self.stack),
255                offset,
256                index,
257            }
258        }
259
260        fn offset(&self) -> Offset {
261            self.offset
262        }
263
264        /// Save a value to the end of stable memory.
265        fn push<T>(&mut self, t: T) -> Result<(), io::Error>
266        where
267            T: candid::utils::ArgumentEncoder,
268        {
269            let bytes: Vec<u8> = encode_args(t).unwrap();
270            self.offset += bytes.len() as Offset;
271            let mut stack = self.stack.borrow_mut();
272            if stack.len() > self.index {
273                stack[self.index] = bytes;
274            } else {
275                stack.push(bytes)
276            }
277            self.index += 1;
278            Ok(())
279        }
280
281        /// Pop a value from the end of stable memory.
282        /// In case of `OutOfBounds` error, offset is not changed.
283        /// In case of Candid decoding error, offset will be changed anyway.
284        fn pop<T>(&mut self) -> Result<T, io::Error>
285        where
286            T: for<'de> candid::utils::ArgumentDecoder<'de>,
287        {
288            self.seek_prev()?;
289            let bytes = self.stack.borrow()[self.index].clone();
290            let mut de = candid::de::IDLDeserialize::new(&bytes).unwrap();
291            Ok(candid::utils::ArgumentDecoder::decode(&mut de).unwrap())
292        }
293
294        fn seek_prev(&mut self) -> Result<(), io::Error> {
295            assert!(self.index > 0);
296            let bytes = self.stack.borrow()[self.index - 1].clone();
297            self.index -= 1;
298            assert!(self.offset >= bytes.len() as Offset);
299            self.offset -= bytes.len() as Offset;
300            Ok(())
301        }
302    }
303}