trailbase_extension/
jsonschema.rs

1use jsonschema::Validator;
2use mini_moka::sync::Cache;
3use parking_lot::Mutex;
4use rusqlite::Error;
5use rusqlite::functions::Context;
6use std::collections::HashMap;
7use std::sync::{Arc, LazyLock};
8
9// NOTE:: Validation error is very large, we thus Box it.
10pub type ValidationError = Box<jsonschema::ValidationError<'static>>;
11
12type CustomValidatorFn = Arc<dyn Fn(&serde_json::Value, Option<&str>) -> bool + Send + Sync>;
13
14#[derive(Clone)]
15pub struct SchemaEntry {
16  schema: serde_json::Value,
17  validator: Arc<Validator>,
18  custom_validator: Option<CustomValidatorFn>,
19}
20
21impl SchemaEntry {
22  pub fn from(
23    schema: serde_json::Value,
24    custom_validator: Option<CustomValidatorFn>,
25  ) -> Result<Self, ValidationError> {
26    let validator = Validator::new(&schema)?;
27
28    return Ok(Self {
29      schema,
30      validator: validator.into(),
31      custom_validator,
32    });
33  }
34}
35
36static SCHEMA_REGISTRY: LazyLock<Mutex<HashMap<String, SchemaEntry>>> =
37  LazyLock::new(|| Mutex::new(HashMap::<String, SchemaEntry>::new()));
38
39pub fn set_schemas(schema_entries: Option<Vec<(String, SchemaEntry)>>) {
40  let mut lock = SCHEMA_REGISTRY.lock();
41  lock.clear();
42
43  if let Some(entries) = schema_entries {
44    for (name, entry) in entries {
45      lock.insert(name, entry);
46    }
47  }
48}
49
50pub fn set_schema(name: &str, entry: Option<SchemaEntry>) {
51  if let Some(entry) = entry {
52    SCHEMA_REGISTRY.lock().insert(name.to_string(), entry);
53  } else {
54    SCHEMA_REGISTRY.lock().remove(name);
55  }
56}
57
58pub fn get_schema(name: &str) -> Option<serde_json::Value> {
59  SCHEMA_REGISTRY.lock().get(name).map(|s| s.schema.clone())
60}
61
62pub fn get_compiled_schema(name: &str) -> Option<Arc<Validator>> {
63  SCHEMA_REGISTRY
64    .lock()
65    .get(name)
66    .map(|s| s.validator.clone())
67}
68
69pub fn get_schemas() -> Vec<(String, serde_json::Value)> {
70  SCHEMA_REGISTRY
71    .lock()
72    .iter()
73    .map(|(name, schema)| (name.clone(), schema.schema.clone()))
74    .collect()
75}
76
77pub(crate) fn jsonschema_by_name(context: &Context) -> Result<bool, Error> {
78  let schema_name = context.get_raw(0).as_str()?;
79
80  // Get and parse the JSON contents. If it's invalid JSON to start with, there's not much
81  // we can validate.
82  let Some(contents) = context.get_raw(1).as_str_or_null()? else {
83    return Ok(true);
84  };
85
86  let json = serde_json::from_str(contents)
87    .map_err(|err| Error::UserFunctionError(format!("Invalid JSON: {contents} => {err}").into()))?;
88
89  // Then get/build the schema validator for the given pattern.
90  let Some(entry) = SCHEMA_REGISTRY.lock().get(schema_name).cloned() else {
91    return Err(Error::UserFunctionError(
92      format!("Schema {schema_name} not found").into(),
93    ));
94  };
95
96  if !entry.validator.is_valid(&json) {
97    return Ok(false);
98  }
99
100  if let Some(validator) = entry.custom_validator {
101    if !validator(&json, None) {
102      return Ok(false);
103    }
104  }
105
106  return Ok(true);
107}
108
109pub(crate) fn jsonschema_by_name_with_extra_args(context: &Context) -> Result<bool, Error> {
110  let schema_name = context.get_raw(0).as_str()?;
111  let extra_args = context.get_raw(2).as_str()?;
112
113  // Get and parse the JSON contents. If it's invalid JSON to start with, there's not much
114  // we can validate.
115  let Some(contents) = context.get_raw(1).as_str_or_null()? else {
116    return Ok(true);
117  };
118  let json = serde_json::from_str(contents)
119    .map_err(|err| Error::UserFunctionError(format!("Invalid JSON: {contents} => {err}").into()))?;
120
121  // Then get/build the schema validator for the given pattern.
122  let Some(entry) = SCHEMA_REGISTRY.lock().get(schema_name).cloned() else {
123    return Err(Error::UserFunctionError(
124      format!("Schema {schema_name} not found").into(),
125    ));
126  };
127
128  if !entry.validator.is_valid(&json) {
129    return Ok(false);
130  }
131
132  if let Some(validator) = entry.custom_validator {
133    if !validator(&json, Some(extra_args)) {
134      return Ok(false);
135    }
136  }
137
138  return Ok(true);
139}
140
141static SCHEMA_CACHE: LazyLock<Cache<String, Arc<Validator>>> = LazyLock::new(|| Cache::new(256));
142
143pub(crate) fn jsonschema_matches(context: &Context) -> Result<bool, Error> {
144  // First, get and parse the JSON contents. If it's invalid JSON to start with, there's not much
145  // we can validate.
146  let Some(contents) = context.get_raw(1).as_str_or_null()? else {
147    return Ok(true);
148  };
149  let json = serde_json::from_str(contents).map_err(|err| {
150    Error::UserFunctionError(format!("Invalid JSON: '{contents}' => {err}").into())
151  })?;
152
153  let pattern = context.get_raw(0).as_str()?.to_string();
154
155  // Then get/build the schema validator for the given pattern.
156  let valid = match SCHEMA_CACHE.get(&pattern) {
157    Some(validator) => validator.is_valid(&json),
158    None => {
159      let schema = serde_json::from_str(&pattern)
160        .map_err(|err| Error::UserFunctionError(format!("Invalid JSON Schema: {err}").into()))?;
161      let validator = Validator::new(&schema).map_err(|err| {
162        Error::UserFunctionError(format!("Failed to compile Schema: {err}").into())
163      })?;
164
165      let valid = validator.is_valid(&json);
166      SCHEMA_CACHE.insert(pattern, Arc::new(validator));
167      valid
168    }
169  };
170
171  return Ok(valid);
172}
173
#[cfg(test)]
mod tests {
  use super::*;
  use rusqlite::params;

  /// Inline schemas via `jsonschema_matches` in CHECK constraints.
  #[test]
  fn test_explicit_jsonschema() {
    let conn = crate::connect_sqlite(None, None).unwrap();

    let text0_schema = r#"
        {
          "type": "object",
          "properties": {
            "name": { "type": "string" },
            "age": { "type": "integer", "minimum": 0 }
          },
          "required": ["name"]
        }
    "#;

    let text1_schema = r#"{ "type": "string" }"#;

    let create_table = format!(
      r#"
        CREATE TABLE test (
          text0    TEXT NOT NULL CHECK(jsonschema_matches('{text0_schema}', text0)),
          text1    TEXT NOT NULL CHECK(jsonschema_matches('{text1_schema}', text1))
        ) STRICT;
      "#
    );
    conn.execute(&create_table, ()).unwrap();

    // Documents matching both schemas are accepted.
    conn
      .execute(
        r#"INSERT INTO test (text0, text1) VALUES ('{"name": "foo"}', '"text"')"#,
        params!(),
      )
      .unwrap();

    // Schema violation: "age" must be a non-negative integer.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0, text1) VALUES ('{"name": "foo", "age": -5}', '"text"')"#,
          params!(),
        )
        .is_err()
    );

    // Schema violation: "name" is required.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0, text1) VALUES ('{"age": 5}', '"text"')"#,
          params!(),
        )
        .is_err()
    );

    // Schema violation in the second column: a JSON number is not a string.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0, text1) VALUES ('{"name": "foo"}', '5')"#,
          params!(),
        )
        .is_err()
    );

    // Malformed JSON is rejected with an error.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0, text1) VALUES ('{"name": ', '"text"')"#,
          params!(),
        )
        .is_err()
    );
  }

  /// Registered schemas plus a custom validator via the three-argument `jsonschema(...)`.
  #[test]
  fn test_registered_jsonschema() {
    let conn = crate::connect_sqlite(None, None).unwrap();

    let text0_schema = r#"
        {
          "type": "object",
          "properties": {
            "name": { "type": "string" },
            "age": { "type": "integer", "minimum": 0 }
          },
          "required": ["name"]
        }
    "#;

    // Custom validator: accepts only objects whose "name" starts with the extra argument.
    fn starts_with(v: &serde_json::Value, param: Option<&str>) -> bool {
      let (Some(param), Some(serde_json::Value::String(name))) = (param, v.get("name")) else {
        return false;
      };
      return name.starts_with(param);
    }

    let schema: serde_json::Value = serde_json::from_str(text0_schema).unwrap();
    set_schema(
      "name0",
      Some(SchemaEntry::from(schema.clone(), Some(Arc::new(starts_with))).unwrap()),
    );

    // The registry hands back what was registered.
    assert_eq!(get_schema("name0"), Some(schema));
    assert!(get_compiled_schema("name0").is_some());

    let create_table = r#"
        CREATE TABLE test (
          text0    TEXT NOT NULL CHECK(jsonschema('name0', text0, 'prefix'))
        ) STRICT;
      "#;
    conn.execute(create_table, ()).unwrap();

    conn
      .execute(
        r#"INSERT INTO test (text0) VALUES ('{"name": "prefix_foo"}')"#,
        params!(),
      )
      .unwrap();

    // Rejected by the custom validator: wrong prefix.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0) VALUES ('{"name": "WRONG_PREFIX_foo"}')"#,
          params!(),
        )
        .is_err()
    );

    // Rejected by the schema even though the custom validator alone would accept it.
    assert!(
      conn
        .execute(
          r#"INSERT INTO test (text0) VALUES ('{"name": "prefix_foo", "age": -1}')"#,
          params!(),
        )
        .is_err()
    );

    // Referencing a schema name that was never registered fails at validation time.
    let create_unknown = r#"
        CREATE TABLE test_unknown (
          text0    TEXT NOT NULL CHECK(jsonschema('does_not_exist', text0, 'prefix'))
        ) STRICT;
      "#;
    conn.execute(create_unknown, ()).unwrap();
    assert!(
      conn
        .execute(
          r#"INSERT INTO test_unknown (text0) VALUES ('{"name": "prefix_foo"}')"#,
          params!(),
        )
        .is_err()
    );
  }
}