Skip to main content

wasm_dbms_memory/
schema_registry.rs

1// Rust guideline compliant 2026-02-28
2
3use std::collections::HashMap;
4
5use wasm_dbms_api::prelude::{
6    DEFAULT_ALIGNMENT, DataSize, Encode, MSize, MemoryResult, Page, PageOffset, TableFingerprint,
7    TableSchema,
8};
9
10use crate::{MemoryAccess, MemoryManager, MemoryProvider};
11
12/// Data regarding the table registry page.
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14pub struct TableRegistryPage {
15    pub pages_list_page: Page,
16    pub free_segments_page: Page,
17}
18
19/// The schema registry takes care of storing and retrieving table schemas from memory.
20#[derive(Debug, Default, Clone, PartialEq, Eq)]
21pub struct SchemaRegistry {
22    tables: HashMap<TableFingerprint, TableRegistryPage>,
23}
24
25impl SchemaRegistry {
26    /// Load the schema registry from memory.
27    pub fn load(mm: &MemoryManager<impl MemoryProvider>) -> MemoryResult<Self> {
28        let page = mm.schema_page();
29        let registry: Self = mm.read_at(page, 0)?;
30        Ok(registry)
31    }
32
33    /// Registers a table and allocates it registry page.
34    ///
35    /// The [`TableSchema`] type parameter is used to get the [`TableSchema::fingerprint`] of the table schema.
36    pub fn register_table<TS>(
37        &mut self,
38        mm: &mut MemoryManager<impl MemoryProvider>,
39    ) -> MemoryResult<TableRegistryPage>
40    where
41        TS: TableSchema,
42    {
43        // check if already registered
44        let fingerprint = TS::fingerprint();
45        if let Some(pages) = self.tables.get(&fingerprint) {
46            return Ok(*pages);
47        }
48
49        // allocate table registry page
50        let pages_list_page = mm.allocate_page()?;
51        let free_segments_page = mm.allocate_page()?;
52
53        // insert into tables map
54        let pages = TableRegistryPage {
55            pages_list_page,
56            free_segments_page,
57        };
58        self.tables.insert(fingerprint, pages);
59
60        // get schema page
61        let page = mm.schema_page();
62        // write self to schema page
63        mm.write_at(page, 0, self)?;
64
65        Ok(pages)
66    }
67
68    /// Save the schema registry to memory.
69    pub fn save(&self, mm: &mut MemoryManager<impl MemoryProvider>) -> MemoryResult<()> {
70        let page = mm.schema_page();
71        mm.write_at(page, 0, self)
72    }
73
74    /// Returns the table registry page for a given table schema.
75    pub fn table_registry_page<TS>(&self) -> Option<TableRegistryPage>
76    where
77        TS: TableSchema,
78    {
79        self.tables.get(&TS::fingerprint()).copied()
80    }
81}
82
83impl Encode for SchemaRegistry {
84    const SIZE: DataSize = DataSize::Dynamic;
85
86    const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;
87
88    fn encode(&'_ self) -> std::borrow::Cow<'_, [u8]> {
89        // prepare buffer; size is 8 bytes for len + (8 + (4 * 2)) bytes for each entry
90        let mut buffer = Vec::with_capacity(self.size() as usize);
91        // write 8 bytes len of map
92        buffer.extend_from_slice(&(self.tables.len() as u64).to_le_bytes());
93        // write each entry
94        for (fingerprint, page) in &self.tables {
95            buffer.extend_from_slice(&fingerprint.to_le_bytes());
96            buffer.extend_from_slice(&page.pages_list_page.to_le_bytes());
97            buffer.extend_from_slice(&page.free_segments_page.to_le_bytes());
98        }
99        std::borrow::Cow::Owned(buffer)
100    }
101
102    fn decode(data: std::borrow::Cow<[u8]>) -> MemoryResult<Self>
103    where
104        Self: Sized,
105    {
106        let mut offset = 0;
107        // read len
108        let len = u64::from_le_bytes(
109            data[offset..offset + 8]
110                .try_into()
111                .expect("failed to read length"),
112        ) as usize;
113        offset += 8;
114        let mut tables = HashMap::with_capacity(len);
115        // read each entry
116        for _ in 0..len {
117            let fingerprint = u64::from_le_bytes(data[offset..offset + 8].try_into()?);
118            offset += 8;
119            let pages_list_page = Page::from_le_bytes(data[offset..offset + 4].try_into()?);
120            offset += 4;
121            let deleted_records_page = Page::from_le_bytes(data[offset..offset + 4].try_into()?);
122            offset += 4;
123            tables.insert(
124                fingerprint,
125                TableRegistryPage {
126                    pages_list_page,
127                    free_segments_page: deleted_records_page,
128                },
129            );
130        }
131        Ok(Self { tables })
132    }
133
134    fn size(&self) -> MSize {
135        // 8 bytes for len + (8 + (4 * 2)) bytes for each entry
136        8 + (self.tables.len() as MSize * (4 * 2 + 8))
137    }
138}
139
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use candid::CandidType;
    use serde::{Deserialize, Serialize};
    use wasm_dbms_api::prelude::{
        ColumnDef, DbmsResult, Filter, InsertRecord, NoForeignFetcher, Sanitize, TableColumns,
        TableRecord, UpdateRecord, Validate, Value,
    };

    use super::*;
    use crate::HeapMemoryProvider;

    /// Builds a memory manager backed by a fresh heap memory provider.
    fn make_mm() -> MemoryManager<HeapMemoryProvider> {
        MemoryManager::init(HeapMemoryProvider::default())
    }

    #[test]
    fn test_should_encode_and_decode_schema_registry() {
        let mut mm = make_mm();

        // load
        let mut registry = SchemaRegistry::load(&mm).expect("failed to load init schema registry");

        // register table
        let registry_page = registry
            .register_table::<User>(&mut mm)
            .expect("failed to register table");

        // get table registry page
        let fetched_page = registry
            .table_registry_page::<User>()
            .expect("failed to get table registry page");
        assert_eq!(registry_page, fetched_page);

        // encode / decode round trip
        let encoded = registry.encode();
        let decoded = SchemaRegistry::decode(encoded).expect("failed to decode");
        assert_eq!(registry, decoded);

        // register a second table
        let another_registry_page = registry
            .register_table::<AnotherTable>(&mut mm)
            .expect("failed to register another table");
        let another_fetched_page = registry
            .table_registry_page::<AnotherTable>()
            .expect("failed to get another table registry page");
        assert_eq!(another_registry_page, another_fetched_page);

        // re-load from memory: both registrations must have been persisted
        let reloaded = SchemaRegistry::load(&mm).expect("failed to reload schema registry");
        assert_eq!(registry, reloaded);
        assert_eq!(reloaded.tables.len(), 2);
        assert_eq!(
            reloaded
                .table_registry_page::<User>()
                .expect("failed to get first table registry page after reload"),
            registry_page
        );
        assert_eq!(
            reloaded
                .table_registry_page::<AnotherTable>()
                .expect("failed to get second table registry page after reload"),
            another_registry_page
        );
    }

    #[test]
    fn test_should_not_register_same_table_twice() {
        let mut mm = make_mm();
        let mut registry = SchemaRegistry::default();

        let first_page = registry
            .register_table::<User>(&mut mm)
            .expect("failed to register table first time");
        let second_page = registry
            .register_table::<User>(&mut mm)
            .expect("failed to register table second time");

        assert_eq!(first_page, second_page);
        assert_eq!(registry.tables.len(), 1);
    }

    #[test]
    fn test_table_registry_page_returns_none_for_unregistered_table() {
        let registry = SchemaRegistry::default();
        assert!(registry.table_registry_page::<User>().is_none());
    }

    #[test]
    fn test_empty_registry_encode_decode() {
        let registry = SchemaRegistry::default();
        let encoded = registry.encode();
        let decoded = SchemaRegistry::decode(encoded).expect("failed to decode empty registry");
        assert_eq!(registry, decoded);
        assert_eq!(decoded.tables.len(), 0);
    }

    #[test]
    fn test_load_fresh_memory_returns_empty_registry() {
        let mm = make_mm();
        let registry = SchemaRegistry::load(&mm).expect("failed to load from fresh memory");
        assert_eq!(registry.tables.len(), 0);
    }

    #[test]
    fn test_save_and_reload() {
        let mut mm = make_mm();
        let mut registry = SchemaRegistry::default();
        registry
            .register_table::<User>(&mut mm)
            .expect("failed to register");
        registry
            .register_table::<AnotherTable>(&mut mm)
            .expect("failed to register another");
        // `register_table` already persists; an explicit `save` must round-trip as well.
        registry.save(&mut mm).expect("failed to save");

        let reloaded = SchemaRegistry::load(&mm).expect("failed to reload");
        assert_eq!(reloaded.tables.len(), 2);
        assert_eq!(registry, reloaded);
    }

    #[test]
    fn test_schema_registry_size() {
        let mut mm = make_mm();
        let mut registry = SchemaRegistry::default();
        // Empty size: 8 bytes for length
        assert_eq!(registry.size(), 8);
        registry
            .register_table::<User>(&mut mm)
            .expect("failed to register");
        // One entry: 8 + (8 + 4 + 4) = 24
        assert_eq!(registry.size(), 24);
    }

    #[test]
    fn test_encoded_length_matches_size() {
        let mut mm = make_mm();
        let mut registry = SchemaRegistry::default();
        assert_eq!(registry.encode().len() as MSize, registry.size());

        registry
            .register_table::<User>(&mut mm)
            .expect("failed to register");
        registry
            .register_table::<AnotherTable>(&mut mm)
            .expect("failed to register another");
        assert_eq!(registry.encode().len() as MSize, registry.size());
    }

    #[test]
    fn test_decode_ignores_trailing_bytes() {
        let mut mm = make_mm();
        let mut registry = SchemaRegistry::default();
        registry
            .register_table::<User>(&mut mm)
            .expect("failed to register");

        // Simulate reading from a larger region that contains unrelated bytes after the data.
        let mut bytes = registry.encode().into_owned();
        bytes.extend_from_slice(&[0xAB; 32]);

        let decoded =
            SchemaRegistry::decode(Cow::Owned(bytes)).expect("failed to decode with trailing bytes");
        assert_eq!(registry, decoded);
    }

    // -- AnotherTable mock for tests --

    #[derive(Clone, CandidType)]
    struct AnotherTable;

    impl Encode for AnotherTable {
        const SIZE: DataSize = DataSize::Dynamic;

        const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;

        fn encode(&'_ self) -> Cow<'_, [u8]> {
            Cow::Owned(vec![])
        }

        fn decode(_data: Cow<[u8]>) -> MemoryResult<Self>
        where
            Self: Sized,
        {
            Ok(AnotherTable)
        }

        fn size(&self) -> MSize {
            0
        }
    }

    #[derive(Clone, CandidType, Deserialize)]
    struct AnotherTableRecord;

    impl TableRecord for AnotherTableRecord {
        type Schema = AnotherTable;

        fn from_values(_values: TableColumns) -> Self {
            AnotherTableRecord
        }

        fn to_values(&self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }
    }

    #[derive(Clone, CandidType, Serialize)]
    struct AnotherTableInsert;

    impl InsertRecord for AnotherTableInsert {
        type Record = AnotherTableRecord;
        type Schema = AnotherTable;

        fn from_values(_values: &[(ColumnDef, Value)]) -> DbmsResult<Self> {
            Ok(AnotherTableInsert)
        }

        fn into_values(self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn into_record(self) -> Self::Schema {
            AnotherTable
        }
    }

    #[derive(Clone, CandidType, Serialize)]
    struct AnotherTableUpdate;

    impl UpdateRecord for AnotherTableUpdate {
        type Record = AnotherTableRecord;
        type Schema = AnotherTable;

        fn from_values(_values: &[(ColumnDef, Value)], _where_clause: Option<Filter>) -> Self {
            AnotherTableUpdate
        }

        fn update_values(&self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn where_clause(&self) -> Option<Filter> {
            None
        }
    }

    impl TableSchema for AnotherTable {
        type Record = AnotherTableRecord;
        type Insert = AnotherTableInsert;
        type Update = AnotherTableUpdate;
        type ForeignFetcher = NoForeignFetcher;

        fn table_name() -> &'static str {
            "another_table"
        }

        fn columns() -> &'static [ColumnDef] {
            &[]
        }

        fn primary_key() -> &'static str {
            ""
        }

        fn to_values(self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn sanitizer(_column_name: &'static str) -> Option<Box<dyn Sanitize>> {
            None
        }

        fn validator(_column_name: &'static str) -> Option<Box<dyn Validate>> {
            None
        }
    }

    // -- User mock for tests --

    #[derive(Clone, CandidType)]
    struct User;

    impl Encode for User {
        const SIZE: DataSize = DataSize::Dynamic;
        const ALIGNMENT: PageOffset = DEFAULT_ALIGNMENT;

        fn encode(&'_ self) -> Cow<'_, [u8]> {
            Cow::Owned(vec![])
        }

        fn decode(_data: Cow<[u8]>) -> MemoryResult<Self>
        where
            Self: Sized,
        {
            Ok(User)
        }

        fn size(&self) -> MSize {
            0
        }
    }

    #[derive(Clone, CandidType, Deserialize)]
    struct UserRecord;

    impl TableRecord for UserRecord {
        type Schema = User;

        fn from_values(_values: TableColumns) -> Self {
            UserRecord
        }

        fn to_values(&self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }
    }

    #[derive(Clone, CandidType, Serialize)]
    struct UserInsert;

    impl InsertRecord for UserInsert {
        type Record = UserRecord;
        type Schema = User;

        fn from_values(_values: &[(ColumnDef, Value)]) -> DbmsResult<Self> {
            Ok(UserInsert)
        }

        fn into_values(self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn into_record(self) -> Self::Schema {
            User
        }
    }

    #[derive(Clone, CandidType, Serialize)]
    struct UserUpdate;

    impl UpdateRecord for UserUpdate {
        type Record = UserRecord;
        type Schema = User;

        fn from_values(_values: &[(ColumnDef, Value)], _where_clause: Option<Filter>) -> Self {
            UserUpdate
        }

        fn update_values(&self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn where_clause(&self) -> Option<Filter> {
            None
        }
    }

    impl TableSchema for User {
        type Record = UserRecord;
        type Insert = UserInsert;
        type Update = UserUpdate;
        type ForeignFetcher = NoForeignFetcher;

        fn table_name() -> &'static str {
            "users"
        }

        fn columns() -> &'static [ColumnDef] {
            &[]
        }

        fn primary_key() -> &'static str {
            "id"
        }

        fn to_values(self) -> Vec<(ColumnDef, Value)> {
            vec![]
        }

        fn sanitizer(_column_name: &'static str) -> Option<Box<dyn Sanitize>> {
            None
        }

        fn validator(_column_name: &'static str) -> Option<Box<dyn Validate>> {
            None
        }
    }
}