Skip to main content

wasm_dbms/integrity/
update.rs

1// Rust guideline compliant 2026-03-01
2// X-WHERE-CLAUSE, M-CANONICAL-DOCS
3
4//! Integrity validator for update operations.
5
6use wasm_dbms_api::prelude::{
7    ColumnDef, Database as _, DbmsError, DbmsResult, Filter, Query, QueryError, TableRecord,
8    TableSchema, Value,
9};
10use wasm_dbms_memory::prelude::{AccessControl, AccessControlList, MemoryProvider};
11
12use super::common;
13use crate::database::WasmDbmsDatabase;
14
15/// Integrity validator for update operations.
16///
17/// Unlike [`super::InsertIntegrityValidator`], this validator allows the
18/// primary key to remain unchanged during an update.
19pub struct UpdateIntegrityValidator<'a, T, M, A = AccessControlList>
20where
21    T: TableSchema,
22    M: MemoryProvider,
23    A: AccessControl,
24{
25    database: &'a WasmDbmsDatabase<'a, M, A>,
26    /// The current primary key value of the record being updated.
27    old_pk: Value,
28    _marker: std::marker::PhantomData<T>,
29}
30
31impl<'a, T, M, A> UpdateIntegrityValidator<'a, T, M, A>
32where
33    T: TableSchema,
34    M: MemoryProvider,
35    A: AccessControl,
36{
37    /// Creates a new update integrity validator.
38    pub fn new(dbms: &'a WasmDbmsDatabase<'a, M, A>, old_pk: Value) -> Self {
39        Self {
40            database: dbms,
41            old_pk,
42            _marker: std::marker::PhantomData,
43        }
44    }
45}
46
47impl<T, M, A> UpdateIntegrityValidator<'_, T, M, A>
48where
49    T: TableSchema,
50    M: MemoryProvider,
51    A: AccessControl,
52{
53    /// Verifies whether the given updated record values are valid.
54    pub fn validate(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
55        for (col, value) in record_values {
56            common::check_column_validate::<T>(col, value)?;
57        }
58        self.check_primary_key_conflict(record_values)?;
59        self.check_unique_constraints(record_values)?;
60        common::check_foreign_keys::<T>(self.database, record_values)?;
61        common::check_non_nullable_fields::<T>(record_values)?;
62
63        Ok(())
64    }
65
66    /// Checks for primary key conflicts with *other* records.
67    fn check_primary_key_conflict(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
68        let pk_name = T::primary_key();
69        let new_pk = record_values
70            .iter()
71            .find(|(col_def, _)| col_def.name == pk_name)
72            .map(|(_, value)| value.clone())
73            .ok_or(DbmsError::Query(QueryError::MissingNonNullableField(
74                pk_name.to_string(),
75            )))?;
76
77        let query = Query::builder()
78            .field(pk_name)
79            .and_where(Filter::Eq(pk_name.to_string(), new_pk.clone()))
80            .build();
81
82        let res = self.database.select::<T>(query)?;
83        match res.len() {
84            0 => Ok(()),
85            1 => {
86                if new_pk == self.old_pk {
87                    Ok(())
88                } else {
89                    Err(DbmsError::Query(QueryError::PrimaryKeyConflict))
90                }
91            }
92            _ => Err(DbmsError::Query(QueryError::PrimaryKeyConflict)),
93        }
94    }
95
96    /// Checks for unique constraint violations, excluding the record being updated.
97    ///
98    /// For each unique field, queries for existing records with the same value.
99    /// A match is only a conflict if it belongs to a different record (different primary key).
100    fn check_unique_constraints(&self, record_values: &[(ColumnDef, Value)]) -> DbmsResult<()> {
101        let pk_name = T::primary_key();
102
103        for (col_def, value) in record_values.iter().filter(|(col_def, _)| col_def.unique) {
104            let query = Query::builder()
105                .field(pk_name)
106                .and_where(Filter::Eq(col_def.name.to_string(), value.clone()))
107                .build();
108
109            let res = self.database.select::<T>(query)?;
110            for record in &res {
111                let record_pk = record
112                    .to_values()
113                    .into_iter()
114                    .find(|(c, _)| c.name == pk_name)
115                    .map(|(_, v)| v);
116
117                if record_pk.as_ref() != Some(&self.old_pk) {
118                    return Err(DbmsError::Query(QueryError::UniqueConstraintViolation {
119                        field: col_def.name.to_string(),
120                    }));
121                }
122            }
123        }
124
125        Ok(())
126    }
127}
128
#[cfg(test)]
mod tests {
    //! Update-path integrity tests, driven end to end through `Database::update`.

    use wasm_dbms_api::prelude::{
        Database as _, DbmsError, Filter, InsertRecord as _, QueryError, TableSchema as _, Text,
        Uint32, UpdateRecord as _, Value,
    };
    use wasm_dbms_macros::{DatabaseSchema, Table};
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use crate::prelude::{DbmsContext, WasmDbmsDatabase};

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "users"]
    pub struct User {
        #[primary_key]
        pub id: Uint32,
        pub name: Text,
    }

    #[derive(Debug, Table, Clone, PartialEq, Eq)]
    #[table = "contracts"]
    pub struct Contract {
        #[primary_key]
        pub id: Uint32,
        #[unique]
        pub code: Text,
        #[foreign_key(entity = "User", table = "users", column = "id")]
        pub user_id: Uint32,
    }

    #[derive(DatabaseSchema)]
    #[tables(User = "users", Contract = "contracts")]
    pub struct TestSchema;

    /// Creates a heap-backed context with the test schema registered.
    fn setup() -> DbmsContext<HeapMemoryProvider> {
        let ctx = DbmsContext::new(HeapMemoryProvider::default());
        TestSchema::register_tables(&ctx).unwrap();
        ctx
    }

    /// Inserts a user row; panics if the insert fails.
    fn insert_user(db: &WasmDbmsDatabase<'_, HeapMemoryProvider>, id: u32, name: &str) {
        let insert = UserInsertRequest::from_values(&[
            (User::columns()[0], Value::Uint32(Uint32(id))),
            (User::columns()[1], Value::Text(Text(name.to_string()))),
        ])
        .unwrap();
        db.insert::<User>(insert).unwrap();
    }

    /// Inserts a contract row; panics if the insert fails.
    fn insert_contract(
        db: &WasmDbmsDatabase<'_, HeapMemoryProvider>,
        id: u32,
        code: &str,
        user_id: u32,
    ) {
        let insert = ContractInsertRequest::from_values(&[
            (Contract::columns()[0], Value::Uint32(Uint32(id))),
            (Contract::columns()[1], Value::Text(Text(code.to_string()))),
            (Contract::columns()[2], Value::Uint32(Uint32(user_id))),
        ])
        .unwrap();
        db.insert::<Contract>(insert).unwrap();
    }

    #[test]
    fn test_update_unique_field_to_new_value_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        let patch = ContractUpdateRequest::from_values(
            &[(
                Contract::columns()[1],
                Value::Text(Text("CONTRACT-999".to_string())),
            )],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        assert!(db.update::<Contract>(patch).is_ok());
    }

    #[test]
    fn test_update_keeping_same_unique_value_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        // Update the record but keep the same unique code: the only match is
        // the record itself, which must not count as a violation.
        let patch = ContractUpdateRequest::from_values(
            &[(
                Contract::columns()[1],
                Value::Text(Text("CONTRACT-001".to_string())),
            )],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        assert!(db.update::<Contract>(patch).is_ok());
    }

    #[test]
    fn test_update_unique_field_to_existing_value_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);
        insert_contract(&db, 2, "CONTRACT-002", 1);

        // Try to update contract 2's code to match contract 1's code.
        let patch = ContractUpdateRequest::from_values(
            &[(
                Contract::columns()[1],
                Value::Text(Text("CONTRACT-001".to_string())),
            )],
            Some(Filter::eq("id", Value::Uint32(Uint32(2)))),
        );
        let result = db.update::<Contract>(patch);
        assert!(matches!(
            result,
            Err(DbmsError::Query(QueryError::UniqueConstraintViolation { ref field }))
                if field == "code"
        ));
    }

    #[test]
    fn test_update_non_unique_field_keeping_primary_key_succeeds() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_user(&db, 2, "bob");
        insert_contract(&db, 1, "CONTRACT-001", 1);

        // Primary key and unique code stay the same; only the owner changes.
        let patch = ContractUpdateRequest::from_values(
            &[(Contract::columns()[2], Value::Uint32(Uint32(2)))],
            Some(Filter::eq("id", Value::Uint32(Uint32(1)))),
        );
        assert!(db.update::<Contract>(patch).is_ok());
    }

    #[test]
    fn test_update_primary_key_to_existing_value_fails() {
        let ctx = setup();
        let db = WasmDbmsDatabase::oneshot(&ctx, TestSchema);
        insert_user(&db, 1, "alice");
        insert_contract(&db, 1, "CONTRACT-001", 1);
        insert_contract(&db, 2, "CONTRACT-002", 1);

        // Move contract 2 onto contract 1's primary key.
        let patch = ContractUpdateRequest::from_values(
            &[(Contract::columns()[0], Value::Uint32(Uint32(1)))],
            Some(Filter::eq("id", Value::Uint32(Uint32(2)))),
        );
        let result = db.update::<Contract>(patch);
        assert!(matches!(
            result,
            Err(DbmsError::Query(QueryError::PrimaryKeyConflict))
        ));
    }
}