1use ic_cdk::stable;
15use std::{error, fmt, io};
16
17#[derive(Debug)]
19pub enum StorageError {
20 OutOfMemory,
22 OutOfBounds,
24 Candid(candid::Error),
26}
27
28impl From<candid::Error> for StorageError {
29 fn from(err: candid::Error) -> StorageError {
30 StorageError::Candid(err)
31 }
32}
33
34impl From<StorageError> for io::Error {
35 fn from(err: StorageError) -> io::Error {
36 match err {
37 StorageError::Candid(err) => io::Error::new(io::ErrorKind::Other, err),
38 err => io::Error::new(io::ErrorKind::OutOfMemory, err),
39 }
40 }
41}
42
43impl fmt::Display for StorageError {
44 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45 match self {
46 Self::OutOfMemory => f.write_str("Out of memory"),
47 Self::OutOfBounds => f.write_str("Read exceeds allocated memory"),
48 Self::Candid(err) => write!(f, "{}", err),
49 }
50 }
51}
52
53impl error::Error for StorageError {}
54
/// Byte position inside stable memory.
pub type Offset = u64;
57
58pub struct StableStorage {
62 pub offset: Offset,
64 capacity: u64,
66}
67
68impl Default for StableStorage {
72 fn default() -> Self {
73 let mut storage = Self {
74 offset: 0,
75 capacity: stable::stable_size(),
76 };
77 if storage.capacity > 0 {
78 let cap = storage.capacity << 16;
79 let mut bytes = [0; 8];
80 stable::stable_read(cap - 8, &mut bytes);
81 storage.offset = u64::from_le_bytes(bytes);
82 }
83 storage
84 }
85}
86
87impl StableStorage {
88 fn grow(&mut self, added_pages: u64) -> Result<(), StorageError> {
90 let old_page_count =
91 stable::stable_grow(added_pages).map_err(|_| StorageError::OutOfMemory)?;
92 self.capacity = old_page_count + added_pages;
93 Ok(())
94 }
95
96 pub fn new() -> Self {
98 Default::default()
99 }
100
101 pub fn finalize(mut self) -> Result<(), io::Error> {
104 let mut cap = self.capacity << 16;
105 if self.offset + 8 > cap {
106 self.grow(1)?;
107 cap = self.capacity << 16;
108 }
109 let bytes = self.offset.to_le_bytes();
110 io::Write::write(&mut self, &bytes)?;
111 stable::stable_write(cap - 8, &bytes);
112 Ok(())
113 }
114}
115
116impl io::Write for StableStorage {
117 fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
118 if self.offset + buf.len() as u64 > (self.capacity << 16) {
119 self.grow((buf.len() >> 16) as u64 + 1)?
120 }
121
122 stable::stable_write(self.offset, buf);
123 self.offset += buf.len() as u64;
124 Ok(buf.len())
125 }
126
127 fn flush(&mut self) -> Result<(), io::Error> {
128 Ok(())
129 }
130}
131
132pub trait StorageStack {
136 fn new_with(&self, offset: Offset) -> Self;
138
139 fn offset(&self) -> Offset;
141
142 fn push<T>(&mut self, t: T) -> Result<(), io::Error>
144 where
145 T: candid::utils::ArgumentEncoder;
146
147 fn pop<T>(&mut self) -> Result<T, io::Error>
151 where
152 T: for<'de> candid::utils::ArgumentDecoder<'de>;
153
154 fn seek_prev(&mut self) -> Result<(), io::Error>;
157}
158
159impl StorageStack for StableStorage {
160 fn new_with(&self, offset: Offset) -> Self {
161 Self {
162 offset,
163 capacity: self.capacity,
164 }
165 }
166
167 fn offset(&self) -> Offset {
168 self.offset
169 }
170
171 fn push<T>(&mut self, t: T) -> Result<(), io::Error>
172 where
173 T: candid::utils::ArgumentEncoder,
174 {
175 let prev_offset = self.offset;
176 candid::write_args(self, t).map_err(StorageError::from)?;
177 let bytes = prev_offset.to_le_bytes();
178 io::Write::write(self, &bytes)?;
179 Ok(())
180 }
181
182 fn pop<T>(&mut self) -> Result<T, io::Error>
183 where
184 T: for<'de> candid::utils::ArgumentDecoder<'de>,
185 {
186 let end = self.offset - 8;
187 self.seek_prev()?;
188 let size = (end - self.offset) as usize;
189 let mut bytes = vec![0; size];
190 stable::stable_read(self.offset, &mut bytes);
191 let mut de = candid::de::IDLDeserialize::new(&bytes).map_err(StorageError::Candid)?;
192 let res = candid::utils::ArgumentDecoder::decode(&mut de).map_err(StorageError::Candid)?;
193 Ok(res)
194 }
195
196 fn seek_prev(&mut self) -> Result<(), io::Error> {
197 if self.offset < 8 {
198 return Err(StorageError::OutOfBounds.into());
199 }
200 let mut bytes = [0; 8];
201 let end = self.offset - 8;
202 stable::stable_read(end, &mut bytes);
203 let start = u64::from_le_bytes(bytes);
204 if start > end {
205 return Err(StorageError::OutOfBounds.into());
206 }
207 self.offset = start;
208 Ok(())
209 }
210}
211
212impl io::Read for StableStorage {
213 fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
214 let cap = self.capacity << 16;
215 let read_buf = if buf.len() as u64 + self.offset > cap {
216 if self.offset < cap {
217 &mut buf[..(cap - self.offset) as usize]
218 } else {
219 return Err(StorageError::OutOfBounds.into());
220 }
221 } else {
222 buf
223 };
224 stable::stable_read(self.offset, read_buf);
225 self.offset += read_buf.len() as u64;
226 Ok(read_buf.len())
227 }
228}
229
230pub mod test {
231 use super::*;
232 use candid::encode_args;
233 use std::cell::RefCell;
234 use std::io;
235 use std::rc::Rc;
236
237 #[derive(Clone, Default)]
239 pub struct Stack {
240 stack: Rc<RefCell<Vec<Vec<u8>>>>,
241 offset: Offset,
242 index: usize,
243 }
244
245 impl StorageStack for Stack {
246 fn new_with(&self, offset: Offset) -> Stack {
247 let mut s = 0;
248 let mut index = 0;
249 while s < offset {
250 s += self.stack.as_ref().borrow()[index].len() as Offset;
251 index += 1;
252 }
253 Stack {
254 stack: Rc::clone(&self.stack),
255 offset,
256 index,
257 }
258 }
259
260 fn offset(&self) -> Offset {
261 self.offset
262 }
263
264 fn push<T>(&mut self, t: T) -> Result<(), io::Error>
266 where
267 T: candid::utils::ArgumentEncoder,
268 {
269 let bytes: Vec<u8> = encode_args(t).unwrap();
270 self.offset += bytes.len() as Offset;
271 let mut stack = self.stack.borrow_mut();
272 if stack.len() > self.index {
273 stack[self.index] = bytes;
274 } else {
275 stack.push(bytes)
276 }
277 self.index += 1;
278 Ok(())
279 }
280
281 fn pop<T>(&mut self) -> Result<T, io::Error>
285 where
286 T: for<'de> candid::utils::ArgumentDecoder<'de>,
287 {
288 self.seek_prev()?;
289 let bytes = self.stack.borrow()[self.index].clone();
290 let mut de = candid::de::IDLDeserialize::new(&bytes).unwrap();
291 Ok(candid::utils::ArgumentDecoder::decode(&mut de).unwrap())
292 }
293
294 fn seek_prev(&mut self) -> Result<(), io::Error> {
295 assert!(self.index > 0);
296 let bytes = self.stack.borrow()[self.index - 1].clone();
297 self.index -= 1;
298 assert!(self.offset >= bytes.len() as Offset);
299 self.offset -= bytes.len() as Offset;
300 Ok(())
301 }
302 }
303}