1use std::collections::HashMap;
11
12use wasm_dbms_api::prelude::{
13 DEFAULT_ALIGNMENT, DataSize, Encode, IdentityPerms, MSize, MemoryError, MemoryResult,
14 PageOffset, PermGrant, PermRevoke, TableFingerprint, TablePerms,
15};
16
17use super::traits::AccessControl;
18use crate::{MemoryAccess, MemoryManager, MemoryProvider};
19
/// Version tag written as the first byte of every serialized ACL.
const LAYOUT_VERSION: u8 = 2;
/// Identity-level permission bits packed into each entry's flags byte.
const FLAG_ADMIN: u8 = 1 << 0;
const FLAG_MANAGE_ACL: u8 = 1 << 1;
const FLAG_MIGRATE: u8 = 1 << 2;
24
25#[derive(Clone, Debug, Default, PartialEq, Eq)]
27pub struct AccessControlList {
28 entries: HashMap<Vec<u8>, IdentityPerms>,
29}
30
31impl AccessControlList {
32 fn save<M>(&self, mm: &mut MemoryManager<M>) -> MemoryResult<()>
33 where
34 M: MemoryProvider,
35 {
36 mm.write_at(mm.acl_page(), 0, self)
37 }
38
39 fn manage_acl_count(&self) -> usize {
40 self.entries.values().filter(|p| p.manage_acl).count()
41 }
42
43 fn flags_byte(p: &IdentityPerms) -> u8 {
44 let mut b = 0u8;
45 if p.admin {
46 b |= FLAG_ADMIN;
47 }
48 if p.manage_acl {
49 b |= FLAG_MANAGE_ACL;
50 }
51 if p.migrate {
52 b |= FLAG_MIGRATE;
53 }
54 b
55 }
56
57 fn perms_from_flags(flags: u8) -> (bool, bool, bool) {
58 (
59 flags & FLAG_ADMIN != 0,
60 flags & FLAG_MANAGE_ACL != 0,
61 flags & FLAG_MIGRATE != 0,
62 )
63 }
64}
65
66impl AccessControl for AccessControlList {
67 type Id = Vec<u8>;
68
69 fn load<M>(mm: &mut MemoryManager<M>) -> MemoryResult<Self>
70 where
71 M: MemoryProvider,
72 {
73 mm.read_at(mm.acl_page(), 0)
74 }
75
76 fn granted(&self, id: &Self::Id, table: TableFingerprint, perm: TablePerms) -> bool {
77 match self.entries.get(id) {
78 Some(p) => p.grants_table(table, perm),
79 None => false,
80 }
81 }
82
83 fn granted_admin(&self, id: &Self::Id) -> bool {
84 self.entries.get(id).is_some_and(|p| p.admin)
85 }
86
87 fn granted_manage_acl(&self, id: &Self::Id) -> bool {
88 self.entries.get(id).is_some_and(|p| p.manage_acl)
89 }
90
91 fn granted_migrate(&self, id: &Self::Id) -> bool {
92 self.entries.get(id).is_some_and(|p| p.migrate)
93 }
94
95 fn grant<M>(
96 &mut self,
97 id: Self::Id,
98 grant: PermGrant,
99 mm: &mut MemoryManager<M>,
100 ) -> MemoryResult<()>
101 where
102 M: MemoryProvider,
103 {
104 let entry = self.entries.entry(id).or_default();
105 entry.apply_grant(grant);
106 self.save(mm)
107 }
108
109 fn revoke<M>(
110 &mut self,
111 id: &Self::Id,
112 revoke: PermRevoke,
113 mm: &mut MemoryManager<M>,
114 ) -> MemoryResult<()>
115 where
116 M: MemoryProvider,
117 {
118 if !self.entries.contains_key(id) {
119 return Ok(());
120 }
121 if matches!(revoke, PermRevoke::ManageAcl)
122 && self.entries.get(id).is_some_and(|p| p.manage_acl)
123 && self.manage_acl_count() == 1
124 {
125 return Err(MemoryError::ConstraintViolation(
126 "at least one identity must retain manage_acl".to_string(),
127 ));
128 }
129 let entry = self.entries.get_mut(id).expect("checked above");
130 entry.apply_revoke(revoke);
131 if entry.is_empty() {
132 self.entries.remove(id);
133 }
134 self.save(mm)
135 }
136
137 fn remove_identity<M>(&mut self, id: &Self::Id, mm: &mut MemoryManager<M>) -> MemoryResult<()>
138 where
139 M: MemoryProvider,
140 {
141 let Some(entry) = self.entries.get(id) else {
142 return Ok(());
143 };
144 if entry.manage_acl && self.manage_acl_count() == 1 {
145 return Err(MemoryError::ConstraintViolation(
146 "at least one identity must retain manage_acl".to_string(),
147 ));
148 }
149 self.entries.remove(id);
150 self.save(mm)
151 }
152
153 fn perms(&self, id: &Self::Id) -> IdentityPerms {
154 self.entries.get(id).cloned().unwrap_or_default()
155 }
156
157 fn identities(&self) -> Vec<(Self::Id, IdentityPerms)> {
158 self.entries
159 .iter()
160 .map(|(k, v)| (k.clone(), v.clone()))
161 .collect()
162 }
163}
164
165impl Encode for AccessControlList {
166 const SIZE: DataSize = DataSize::Dynamic;
167 const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;
168
169 fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
170 let mut bytes = Vec::with_capacity(self.size() as usize);
171 bytes.push(LAYOUT_VERSION);
172 let count = self.entries.len() as u32;
173 bytes.extend_from_slice(&count.to_le_bytes());
174 let mut sorted: Vec<_> = self.entries.iter().collect();
176 sorted.sort_by_key(|(a, _)| *a);
177 for (id, perms) in sorted {
178 bytes.push(id.len() as u8);
179 bytes.extend_from_slice(id);
180 bytes.push(Self::flags_byte(perms));
181 bytes.push(perms.all_tables.bits());
182 let pt_count = perms.per_table.len() as u16;
183 bytes.extend_from_slice(&pt_count.to_le_bytes());
184 let mut pt_sorted: Vec<_> = perms.per_table.iter().collect();
185 pt_sorted.sort_by_key(|(t, _)| *t);
186 for (table, p) in pt_sorted {
187 bytes.extend_from_slice(&table.to_le_bytes());
188 bytes.push(p.bits());
189 }
190 }
191 std::borrow::Cow::Owned(bytes)
192 }
193
194 fn decode(data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
195 where
196 Self: Sized,
197 {
198 let mut offset = 0;
199 let version = *data.get(offset).ok_or(MemoryError::AclLayoutUnsupported)?;
200 offset += 1;
201 if version != LAYOUT_VERSION {
202 return Err(MemoryError::AclLayoutUnsupported);
203 }
204 let count_bytes = data
205 .get(offset..offset + 4)
206 .ok_or(MemoryError::AclLayoutUnsupported)?;
207 offset += 4;
208 let count = u32::from_le_bytes(count_bytes.try_into()?) as usize;
209 let mut entries = HashMap::with_capacity(count);
210 for _ in 0..count {
211 let id_len = *data.get(offset).ok_or(MemoryError::AclLayoutUnsupported)? as usize;
212 offset += 1;
213 let id = data
214 .get(offset..offset + id_len)
215 .ok_or(MemoryError::AclLayoutUnsupported)?
216 .to_vec();
217 offset += id_len;
218 let flags = *data.get(offset).ok_or(MemoryError::AclLayoutUnsupported)?;
219 offset += 1;
220 let (admin, manage_acl, migrate) = Self::perms_from_flags(flags);
221 let all_tables = TablePerms::from_bits_truncate(
222 *data.get(offset).ok_or(MemoryError::AclLayoutUnsupported)?,
223 );
224 offset += 1;
225 let pt_count_bytes = data
226 .get(offset..offset + 2)
227 .ok_or(MemoryError::AclLayoutUnsupported)?;
228 offset += 2;
229 let pt_count = u16::from_le_bytes(pt_count_bytes.try_into()?) as usize;
230 let mut per_table = Vec::with_capacity(pt_count);
231 for _ in 0..pt_count {
232 let table_bytes = data
233 .get(offset..offset + 8)
234 .ok_or(MemoryError::AclLayoutUnsupported)?;
235 offset += 8;
236 let table = TableFingerprint::from_le_bytes(table_bytes.try_into()?);
237 let p_byte = *data.get(offset).ok_or(MemoryError::AclLayoutUnsupported)?;
238 offset += 1;
239 per_table.push((table, TablePerms::from_bits_truncate(p_byte)));
240 }
241 entries.insert(
242 id,
243 IdentityPerms {
244 admin,
245 manage_acl,
246 migrate,
247 all_tables,
248 per_table,
249 },
250 );
251 }
252 Ok(AccessControlList { entries })
253 }
254
255 fn size(&self) -> MSize {
256 let per_entry: MSize = self
257 .entries
258 .iter()
259 .map(|(id, p)| 1 + id.len() as MSize + 1 + 1 + 2 + p.per_table.len() as MSize * (8 + 1))
260 .sum();
261 1 + 4 + per_entry
262 }
263}
264
#[cfg(test)]
mod tests {
    use wasm_dbms_api::prelude::fingerprint_for_name;

    use super::*;
    use crate::HeapMemoryProvider;

    /// Fresh heap-backed memory manager for each test.
    fn make_mm() -> MemoryManager<HeapMemoryProvider> {
        MemoryManager::init(HeapMemoryProvider::default())
    }

    /// Shorthand for a table fingerprint derived from its name.
    fn fp(name: &str) -> TableFingerprint {
        fingerprint_for_name(name)
    }

    #[test]
    fn test_v2_round_trip_empty() {
        let acl = AccessControlList::default();
        let bytes = acl.encode();
        let decoded = AccessControlList::decode(bytes).unwrap();
        assert_eq!(acl, decoded);
    }

    #[test]
    fn test_v2_round_trip_mixed() {
        let mut acl = AccessControlList::default();
        let alice = IdentityPerms {
            admin: true,
            manage_acl: true,
            ..Default::default()
        };
        let mut bob = IdentityPerms::default();
        bob.all_tables = TablePerms::READ;
        bob.apply_grant(PermGrant::Table(
            fp("users"),
            TablePerms::INSERT | TablePerms::UPDATE,
        ));
        bob.apply_grant(PermGrant::Table(fp("posts"), TablePerms::READ));
        acl.entries.insert(b"alice".to_vec(), alice);
        acl.entries.insert(b"bob".to_vec(), bob);
        let bytes = acl.encode();
        let decoded = AccessControlList::decode(bytes).unwrap();
        assert_eq!(acl, decoded);
    }

    #[test]
    fn test_decode_rejects_wrong_version() {
        let bad = vec![0x99, 0, 0, 0, 0];
        let err = AccessControlList::decode(std::borrow::Cow::Owned(bad)).unwrap_err();
        assert!(matches!(err, MemoryError::AclLayoutUnsupported));
    }

    #[test]
    fn test_decode_rejects_empty_input() {
        let err = AccessControlList::decode(std::borrow::Cow::Owned(Vec::new())).unwrap_err();
        assert!(matches!(err, MemoryError::AclLayoutUnsupported));
    }

    #[test]
    fn test_decode_rejects_truncated_entry() {
        let mut acl = AccessControlList::default();
        acl.entries.insert(
            b"alice".to_vec(),
            IdentityPerms {
                admin: true,
                ..Default::default()
            },
        );
        let mut bytes = acl.encode().into_owned();
        bytes.pop();
        let err = AccessControlList::decode(std::borrow::Cow::Owned(bytes)).unwrap_err();
        assert!(matches!(err, MemoryError::AclLayoutUnsupported));
    }

    #[test]
    fn test_size_matches_encoded_length() {
        let mut acl = AccessControlList::default();
        let mut bob = IdentityPerms::default();
        bob.apply_grant(PermGrant::Table(fp("users"), TablePerms::READ));
        acl.entries.insert(
            b"alice".to_vec(),
            IdentityPerms {
                admin: true,
                ..Default::default()
            },
        );
        acl.entries.insert(b"bob".to_vec(), bob);
        assert_eq!(acl.size() as usize, acl.encode().len());
    }

    #[test]
    fn test_grant_persists_through_save_and_load() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        acl.grant(b"alice".to_vec(), PermGrant::Admin, &mut mm)
            .unwrap();
        let loaded = AccessControlList::load(&mut mm).unwrap();
        assert!(loaded.granted_admin(&b"alice".to_vec()));
    }

    #[test]
    fn test_granted_truth_table() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        acl.grant(
            b"alice".to_vec(),
            PermGrant::Table(fp("users"), TablePerms::READ | TablePerms::INSERT),
            &mut mm,
        )
        .unwrap();
        let id = b"alice".to_vec();
        assert!(acl.granted(&id, fp("users"), TablePerms::READ));
        assert!(acl.granted(&id, fp("users"), TablePerms::INSERT));
        assert!(!acl.granted(&id, fp("users"), TablePerms::DELETE));
        assert!(!acl.granted(&id, fp("posts"), TablePerms::READ));
    }

    #[test]
    fn test_revoke_partial_keeps_remaining_bits() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let id = b"alice".to_vec();
        acl.grant(
            id.clone(),
            PermGrant::Table(
                fp("users"),
                TablePerms::READ | TablePerms::INSERT | TablePerms::DELETE,
            ),
            &mut mm,
        )
        .unwrap();
        acl.revoke(
            &id,
            PermRevoke::Table(fp("users"), TablePerms::INSERT | TablePerms::DELETE),
            &mut mm,
        )
        .unwrap();
        assert!(acl.granted(&id, fp("users"), TablePerms::READ));
        assert!(!acl.granted(&id, fp("users"), TablePerms::INSERT));
    }

    #[test]
    fn test_last_manage_acl_revoke_rejected() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let id = b"alice".to_vec();
        acl.grant(id.clone(), PermGrant::ManageAcl, &mut mm)
            .unwrap();
        let err = acl.revoke(&id, PermRevoke::ManageAcl, &mut mm).unwrap_err();
        assert!(matches!(err, MemoryError::ConstraintViolation(_)));
    }

    #[test]
    fn test_manage_acl_revoke_allowed_when_another_holder_remains() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let alice = b"alice".to_vec();
        let bob = b"bob".to_vec();
        acl.grant(alice.clone(), PermGrant::ManageAcl, &mut mm)
            .unwrap();
        acl.grant(bob.clone(), PermGrant::ManageAcl, &mut mm)
            .unwrap();
        acl.revoke(&alice, PermRevoke::ManageAcl, &mut mm)
            .unwrap();
        assert!(!acl.granted_manage_acl(&alice));
        assert!(acl.granted_manage_acl(&bob));
        let err = acl.revoke(&bob, PermRevoke::ManageAcl, &mut mm).unwrap_err();
        assert!(matches!(err, MemoryError::ConstraintViolation(_)));
    }

    #[test]
    fn test_remove_last_manage_acl_identity_rejected() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let id = b"alice".to_vec();
        acl.grant(id.clone(), PermGrant::ManageAcl, &mut mm)
            .unwrap();
        let err = acl.remove_identity(&id, &mut mm).unwrap_err();
        assert!(matches!(err, MemoryError::ConstraintViolation(_)));
    }

    #[test]
    fn test_remove_identity_without_manage_acl_succeeds() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        acl.grant(b"alice".to_vec(), PermGrant::ManageAcl, &mut mm)
            .unwrap();
        acl.grant(b"bob".to_vec(), PermGrant::Admin, &mut mm)
            .unwrap();
        acl.remove_identity(&b"bob".to_vec(), &mut mm).unwrap();
        assert!(acl.identities().iter().all(|(id, _)| id != b"bob"));
    }

    #[test]
    fn test_admin_does_not_imply_manage_acl_or_migrate() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let id = b"alice".to_vec();
        acl.grant(id.clone(), PermGrant::Admin, &mut mm).unwrap();
        assert!(acl.granted_admin(&id));
        assert!(!acl.granted_manage_acl(&id));
        assert!(!acl.granted_migrate(&id));
    }

    #[test]
    fn test_revoke_unknown_identity_is_noop() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        acl.revoke(&b"ghost".to_vec(), PermRevoke::Admin, &mut mm)
            .unwrap();
        assert!(acl.identities().is_empty());
    }

    #[test]
    fn test_grant_with_idempotency_does_not_duplicate() {
        let mut mm = make_mm();
        let mut acl = AccessControlList::default();
        let id = b"alice".to_vec();
        acl.grant(id.clone(), PermGrant::Admin, &mut mm).unwrap();
        acl.grant(id.clone(), PermGrant::Admin, &mut mm).unwrap();
        assert_eq!(acl.identities().len(), 1);
    }
}