1use wasm_dbms_api::prelude::{
5 DEFAULT_ALIGNMENT, DataSize, Encode, MSize, MemoryError, MemoryResult, PageOffset,
6};
7
8use crate::{MemoryAccess, MemoryManager, MemoryProvider};
9
10pub trait AccessControl: Default {
18 type Id;
20
21 fn load<M>(mm: &MemoryManager<M>) -> MemoryResult<Self>
23 where
24 M: MemoryProvider,
25 Self: Sized;
26
27 fn is_allowed(&self, identity: &Self::Id) -> bool;
29
30 fn allowed_identities(&self) -> Vec<Self::Id>;
32
33 fn add_identity<M>(
35 &mut self,
36 identity: Self::Id,
37 mm: &mut MemoryManager<M>,
38 ) -> MemoryResult<()>
39 where
40 M: MemoryProvider;
41
42 fn remove_identity<M>(
44 &mut self,
45 identity: &Self::Id,
46 mm: &mut MemoryManager<M>,
47 ) -> MemoryResult<()>
48 where
49 M: MemoryProvider;
50}
51
/// Access-control strategy that permits every caller.
///
/// It carries no state: the identity type is `()` and nothing is stored.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct NoAccessControl;
58
59impl AccessControl for NoAccessControl {
60 type Id = ();
61
62 fn load<M>(_mm: &MemoryManager<M>) -> MemoryResult<Self>
63 where
64 M: MemoryProvider,
65 {
66 Ok(Self)
67 }
68
69 fn is_allowed(&self, _identity: &Self::Id) -> bool {
70 true
71 }
72
73 fn allowed_identities(&self) -> Vec<Self::Id> {
74 vec![]
75 }
76
77 fn add_identity<M>(
78 &mut self,
79 _identity: Self::Id,
80 _mm: &mut MemoryManager<M>,
81 ) -> MemoryResult<()>
82 where
83 M: MemoryProvider,
84 {
85 Ok(())
86 }
87
88 fn remove_identity<M>(
89 &mut self,
90 _identity: &Self::Id,
91 _mm: &mut MemoryManager<M>,
92 ) -> MemoryResult<()>
93 where
94 M: MemoryProvider,
95 {
96 Ok(())
97 }
98}
99
/// Access-control strategy backed by an explicit list of allowed identities.
///
/// Each identity is an opaque byte string compared byte-for-byte. The list is
/// read from and written to the memory manager's ACL page.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AccessControlList {
    /// Identities currently permitted, in storage order.
    allowed: Vec<Vec<u8>>,
}
108
109impl AccessControlList {
110 fn save<M>(&self, mm: &mut MemoryManager<M>) -> MemoryResult<()>
112 where
113 M: MemoryProvider,
114 {
115 mm.write_at(mm.acl_page(), 0, self)
116 }
117}
118
119impl AccessControl for AccessControlList {
120 type Id = Vec<u8>;
121
122 fn load<M>(mm: &MemoryManager<M>) -> MemoryResult<Self>
123 where
124 M: MemoryProvider,
125 {
126 mm.read_at(mm.acl_page(), 0)
127 }
128
129 fn is_allowed(&self, identity: &Self::Id) -> bool {
130 self.allowed
131 .iter()
132 .any(|a| a.as_slice() == identity.as_slice())
133 }
134
135 fn allowed_identities(&self) -> Vec<Self::Id> {
136 self.allowed.clone()
137 }
138
139 fn add_identity<M>(&mut self, identity: Self::Id, mm: &mut MemoryManager<M>) -> MemoryResult<()>
140 where
141 M: MemoryProvider,
142 {
143 if !self.is_allowed(&identity) {
144 self.allowed.push(identity);
145 self.save(mm)?;
146 }
147
148 Ok(())
149 }
150
151 fn remove_identity<M>(
152 &mut self,
153 identity: &Self::Id,
154 mm: &mut MemoryManager<M>,
155 ) -> MemoryResult<()>
156 where
157 M: MemoryProvider,
158 {
159 if let Some(pos) = self
160 .allowed
161 .iter()
162 .position(|p| p.as_slice() == identity.as_slice())
163 {
164 if self.allowed.len() == 1 {
165 return Err(MemoryError::ConstraintViolation(
166 "ACL must contain at least one identity".to_string(),
167 ));
168 }
169 self.allowed.swap_remove(pos);
170 self.save(mm)?;
171 }
172
173 Ok(())
174 }
175}
176
177impl Encode for AccessControlList {
178 const SIZE: DataSize = DataSize::Dynamic;
179
180 const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;
181
182 fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
183 let mut bytes = Vec::with_capacity(self.size() as usize);
185 let len = self.allowed.len() as u32;
186 bytes.extend_from_slice(&len.to_le_bytes());
187 for identity in &self.allowed {
188 let identity_len = identity.len() as u8;
189 bytes.extend_from_slice(&identity_len.to_le_bytes());
190 bytes.extend_from_slice(identity);
191 }
192 std::borrow::Cow::Owned(bytes)
193 }
194
195 fn decode(data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
196 where
197 Self: Sized,
198 {
199 let mut offset = 0;
201 let len_bytes = &data[offset..offset + 4];
202 offset += 4;
203 let len = u32::from_le_bytes(len_bytes.try_into()?) as usize;
204
205 let mut allowed = Vec::with_capacity(len);
207 for _ in 0..len {
208 let identity_len_bytes = &data[offset..offset + 1];
209 offset += 1;
210 let identity_len = u8::from_le_bytes(identity_len_bytes.try_into()?) as usize;
211
212 let identity_bytes = data[offset..offset + identity_len].to_vec();
213 offset += identity_len;
214
215 allowed.push(identity_bytes);
216 }
217 Ok(AccessControlList { allowed })
218 }
219
220 fn size(&self) -> MSize {
221 4 + self
223 .allowed
224 .iter()
225 .map(|p| 1 + p.len() as MSize)
226 .sum::<MSize>()
227 }
228}
229
#[cfg(test)]
mod tests {

    use super::*;
    use crate::HeapMemoryProvider;

    /// Builds a fresh heap-backed memory manager for a test.
    fn make_mm() -> MemoryManager<HeapMemoryProvider> {
        MemoryManager::init(HeapMemoryProvider::default())
    }

    #[test]
    fn test_acl_encode_decode() {
        let acl = AccessControlList {
            allowed: vec![
                vec![0x04],
                vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01],
                vec![0xDE, 0xAD, 0xBE, 0xEF],
                vec![0x01, 0x02, 0x03, 0x04, 0x05],
            ],
        };

        let encoded = acl.encode();
        let decoded = AccessControlList::decode(encoded).unwrap();

        assert_eq!(acl, decoded);
    }

    #[test]
    fn test_acl_add_remove_identity() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        assert!(!acl.is_allowed(&identity));
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        let other = vec![0xDE, 0xAD, 0xBE, 0xEF];
        acl.add_identity(other.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity));
        assert!(acl.is_allowed(&other));
        assert_eq!(acl.allowed_identities().len(), 2);

        acl.remove_identity(&other, &mut mm).unwrap();
        // The removed identity is gone and the remaining one is untouched.
        assert!(!acl.is_allowed(&other));
        assert!(acl.is_allowed(&identity));
        assert_eq!(acl.allowed_identities(), vec![identity.clone()]);

        // The removal must be persisted as well.
        let loaded_acl = AccessControlList::load(&mm).unwrap();
        assert_eq!(loaded_acl, acl);
    }

    #[test]
    fn test_remove_last_identity_returns_error() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity));
        let result = acl.remove_identity(&identity, &mut mm);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MemoryError::ConstraintViolation(_)
        ));
        // A rejected removal must leave the list unchanged.
        assert!(acl.is_allowed(&identity));
        assert_eq!(acl.allowed_identities().len(), 1);
    }

    #[test]
    fn test_should_add_more_identities() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity1 = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        let identity2 = vec![0xDE, 0xAD, 0xBE, 0xEF];
        acl.add_identity(identity1.clone(), &mut mm).unwrap();
        acl.add_identity(identity2.clone(), &mut mm).unwrap();
        assert!(acl.is_allowed(&identity1));
        assert!(acl.is_allowed(&identity2));
        assert_eq!(
            acl.allowed_identities(),
            vec![identity1.clone(), identity2.clone()]
        );
    }

    #[test]
    fn test_add_identity_should_write_to_memory() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x00, 0x00, 0x00, 0x00, 0x01, 0x01];
        acl.add_identity(identity.clone(), &mut mm).unwrap();

        let loaded_acl = AccessControlList::load(&mm).unwrap();
        assert!(loaded_acl.is_allowed(&identity));
        // The persisted list must match the in-memory one exactly.
        assert_eq!(loaded_acl, acl);
    }

    #[test]
    fn test_no_access_control_allows_everything() {
        let acl = NoAccessControl;
        assert!(acl.is_allowed(&()));
        assert!(acl.allowed_identities().is_empty());
    }

    #[test]
    fn test_add_duplicate_identity_is_idempotent() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity = vec![0x01, 0x02, 0x03];
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        acl.add_identity(identity.clone(), &mut mm).unwrap();
        assert_eq!(acl.allowed_identities().len(), 1);
        assert!(acl.is_allowed(&identity));
    }

    #[test]
    fn test_remove_nonexistent_identity_is_noop() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let identity_a = vec![0x01, 0x02];
        let identity_b = vec![0x03, 0x04];
        acl.add_identity(identity_a.clone(), &mut mm).unwrap();
        acl.add_identity(identity_b.clone(), &mut mm).unwrap();
        let nonexistent = vec![0xFF, 0xFF];

        acl.remove_identity(&nonexistent, &mut mm).unwrap();
        assert_eq!(acl.allowed_identities(), vec![identity_a, identity_b]);
    }

    #[test]
    fn test_no_access_control_load() {
        let mm = make_mm();
        let acl = NoAccessControl::load(&mm).unwrap();
        assert!(acl.is_allowed(&()));
    }

    #[test]
    fn test_no_access_control_add_and_remove_identity() {
        let mut mm = make_mm();
        let mut acl = NoAccessControl;
        acl.add_identity((), &mut mm).unwrap();
        acl.remove_identity(&(), &mut mm).unwrap();
        assert!(acl.is_allowed(&()));
        assert!(acl.allowed_identities().is_empty());
    }

    #[test]
    fn test_empty_acl_encode_decode() {
        let acl = AccessControlList::default();
        let encoded = acl.encode();
        let decoded = AccessControlList::decode(encoded).unwrap();
        assert_eq!(acl, decoded);
        assert!(decoded.allowed_identities().is_empty());
    }

    #[test]
    fn test_acl_size() {
        let acl = AccessControlList {
            allowed: vec![vec![0x01, 0x02], vec![0x03]],
        };
        // 4-byte count + (1 + 2) + (1 + 1).
        assert_eq!(acl.size(), 4 + 3 + 2);
        assert_eq!(acl.encode().len() as MSize, acl.size());
    }

    #[test]
    fn test_empty_acl_size() {
        let acl = AccessControlList::default();
        assert_eq!(acl.size(), 4);
    }
}