Skip to main content

zamsync_core/
validation.rs

1use crate::{ZamError, ZamResult};
2
/// Controls which payloads `ZamEngine` accepts at submit and replicate time.
///
/// `None` is the default and accepts any bytes. Use `Json` or `JsonRequired`
/// for deployments (like Bhutan ePIS) where all events must carry structured data.
///
/// A schema can be parsed from a CLI flag value via [`std::str::FromStr`]
/// (`"none"`, `"json"`, `"json+key1,key2"`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum PayloadSchema {
    /// Accept any bytes (default -- backward compatible).
    #[default]
    None,
    /// Payload must be valid JSON (any JSON value, including scalars and `null`).
    Json,
    /// Payload must be valid JSON **and** contain all listed top-level keys.
    JsonRequired(Vec<String>),
}
17
18impl std::str::FromStr for PayloadSchema {
19    type Err = String;
20
21    /// Parse a schema from a CLI flag value (`"none"`, `"json"`, `"json+key1,key2"`).
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        if s == "none" {
24            return Ok(Self::None);
25        }
26        if s == "json" {
27            return Ok(Self::Json);
28        }
29        if let Some(rest) = s.strip_prefix("json+") {
30            let fields = rest.split(',').map(str::to_owned).collect();
31            return Ok(Self::JsonRequired(fields));
32        }
33        Err(format!(
34            "unknown schema '{s}': use 'none', 'json', or 'json+field1,field2'"
35        ))
36    }
37}
38
39impl PayloadSchema {
40    pub fn is_none(&self) -> bool {
41        matches!(self, Self::None)
42    }
43
44    pub fn validate(&self, payload: &[u8]) -> ZamResult<()> {
45        match self {
46            Self::None => Ok(()),
47            Self::Json => json_parse(payload).map(|_| ()),
48            Self::JsonRequired(fields) => {
49                let v = json_parse(payload)?;
50                for field in fields {
51                    if v.get(field.as_str()).is_none() {
52                        return Err(ZamError::Validation(format!(
53                            "missing required field '{field}'"
54                        )));
55                    }
56                }
57                Ok(())
58            }
59        }
60    }
61}
62
63fn json_parse(payload: &[u8]) -> ZamResult<serde_json::Value> {
64    serde_json::from_slice(payload).map_err(|e| ZamError::Validation(format!("invalid JSON: {e}")))
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_none_accepts_anything() {
        assert!(PayloadSchema::None.validate(b"not json at all").is_ok());
        assert!(PayloadSchema::None.validate(b"").is_ok());
    }

    #[test]
    fn test_default_is_none() {
        assert!(PayloadSchema::default().is_none());
        assert!(!PayloadSchema::Json.is_none());
        assert!(!PayloadSchema::JsonRequired(vec![]).is_none());
    }

    #[test]
    fn test_json_accepts_valid() {
        assert!(PayloadSchema::Json
            .validate(br#"{"type":"patient_admitted"}"#)
            .is_ok());
        assert!(PayloadSchema::Json.validate(b"42").is_ok());
        assert!(PayloadSchema::Json.validate(b"null").is_ok());
    }

    #[test]
    fn test_json_rejects_invalid() {
        let err = PayloadSchema::Json.validate(b"not json").unwrap_err();
        assert!(matches!(err, ZamError::Validation(_)));
    }

    #[test]
    fn test_json_required_accepts_all_fields() {
        let schema = PayloadSchema::JsonRequired(vec!["type".into(), "patient_id".into()]);
        let payload = br#"{"type":"discharge","patient_id":"BT-001","ward":"3A"}"#;
        assert!(schema.validate(payload).is_ok());
    }

    #[test]
    fn test_json_required_rejects_missing_field() {
        let schema = PayloadSchema::JsonRequired(vec!["type".into(), "patient_id".into()]);
        let payload = br#"{"type":"discharge"}"#;
        let err = schema.validate(payload).unwrap_err();
        assert!(matches!(&err, ZamError::Validation(msg) if msg.contains("patient_id")));
    }

    #[test]
    fn test_json_required_rejects_invalid_json() {
        let schema = PayloadSchema::JsonRequired(vec!["type".into()]);
        let err = schema.validate(b"{not json").unwrap_err();
        assert!(matches!(&err, ZamError::Validation(msg) if msg.starts_with("invalid JSON")));
    }

    #[test]
    fn test_json_required_rejects_non_object() {
        let schema = PayloadSchema::JsonRequired(vec!["type".into()]);
        let payloads: [&[u8]; 4] = [b"[1,2]", b"42", b"null", br#""type""#];
        for payload in payloads {
            let err = schema.validate(payload).unwrap_err();
            assert!(matches!(&err, ZamError::Validation(msg) if msg.contains("type")));
        }
    }

    #[test]
    fn test_json_required_null_value_counts_as_present() {
        let schema = PayloadSchema::JsonRequired(vec!["type".into()]);
        assert!(schema.validate(br#"{"type":null}"#).is_ok());
    }

    #[test]
    fn test_json_required_empty_list_behaves_like_json() {
        let schema = PayloadSchema::JsonRequired(vec![]);
        assert!(schema.validate(b"42").is_ok());
        assert!(schema.validate(b"not json").is_err());
    }

    #[test]
    fn test_from_str_round_trip() {
        use std::str::FromStr;
        assert!(matches!(
            PayloadSchema::from_str("none").unwrap(),
            PayloadSchema::None
        ));
        assert!(matches!(
            PayloadSchema::from_str("json").unwrap(),
            PayloadSchema::Json
        ));
        let PayloadSchema::JsonRequired(fields) =
            PayloadSchema::from_str("json+type,patient_id").unwrap()
        else {
            panic!()
        };
        assert_eq!(fields, ["type", "patient_id"]);
        assert!(PayloadSchema::from_str("bad").is_err());
    }
}