Skip to main content

shard_core/
metadata.rs

1use anyhow::Result;
2use serde::de::DeserializeOwned;
3use serde::Serialize;
4
/// Leading tag byte for CBOR-encoded metadata.
///
/// `serialize` writes this byte in front of every CBOR payload, and
/// `deserialize` inspects the first input byte against it to decide whether
/// the remainder is CBOR or the whole buffer is JSON.
const CBOR_MARKER: u8 = 0x02;
6
/// Encoding that can be selected for metadata serialization.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the format can be
/// used as a key in hashed or ordered-by-equality collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MetadataFormat {
    /// JSON encoding; also the fallback when the configuration does not
    /// request CBOR.
    Json,
    /// CBOR encoding, requested with the configuration value `"cbor"`.
    Cbor,
}

impl MetadataFormat {
    /// Picks the format from the `"serialization_format"` entry of `config`.
    ///
    /// Only the exact value `"cbor"` selects [`MetadataFormat::Cbor`]. A
    /// missing key, `"json"`, or any unrecognised value (the comparison is
    /// case-sensitive) yields [`MetadataFormat::Json`].
    pub fn from_config(config: &std::collections::BTreeMap<String, String>) -> Self {
        match config.get("serialization_format").map(String::as_str) {
            Some("cbor") => Self::Cbor,
            _ => Self::Json,
        }
    }

    /// Returns the configuration string for this format (`"json"` or
    /// `"cbor"`), i.e. a value that [`MetadataFormat::from_config`] maps back
    /// to the same variant.
    pub fn config_value(&self) -> &'static str {
        match *self {
            Self::Json => "json",
            Self::Cbor => "cbor",
        }
    }
}
28
29fn sort_json_keys(value: serde_json::Value) -> serde_json::Value {
30    match value {
31        serde_json::Value::Object(map) => {
32            let mut keys: Vec<String> = map.keys().cloned().collect();
33            keys.sort();
34            let mut sorted = serde_json::Map::with_capacity(keys.len());
35            for key in keys {
36                if let Some(val) = map.get(&key) {
37                    sorted.insert(key, sort_json_keys(val.clone()));
38                }
39            }
40            serde_json::Value::Object(sorted)
41        }
42        serde_json::Value::Array(arr) => {
43            serde_json::Value::Array(arr.into_iter().map(sort_json_keys).collect())
44        }
45        other => other,
46    }
47}
48
49pub fn serialize<T: Serialize>(data: &T, format: &MetadataFormat) -> Vec<u8> {
50    match format {
51        MetadataFormat::Json => {
52            let value = serde_json::to_value(data).expect("JSON serialization failed");
53            let sorted = sort_json_keys(value);
54            serde_json::to_vec(&sorted).expect("canonical JSON serialization failed")
55        }
56        MetadataFormat::Cbor => {
57            let mut buf = vec![CBOR_MARKER];
58            ciborium::into_writer(data, &mut buf).expect("CBOR serialization failed");
59            buf
60        }
61    }
62}
63
64pub fn deserialize<T: DeserializeOwned>(data: &[u8]) -> Result<T> {
65    if data.is_empty() {
66        anyhow::bail!("empty metadata");
67    }
68    if data[0] == CBOR_MARKER {
69        return Ok(ciborium::from_reader(&data[1..])?);
70    }
71    Ok(serde_json::from_slice(data)?)
72}
73
74pub fn serialize_for_signing<T: Serialize>(data: &T) -> Vec<u8> {
75    let value = serde_json::to_value(data).expect("JSON serialization failed");
76    let sorted = sort_json_keys(value);
77    serde_json::to_vec(&sorted).expect("canonical JSON serialization failed")
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fixture with a string, an integer and a sequence field. The fields are
    /// declared out of alphabetical order so key sorting is observable.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct TestData {
        name: String,
        value: u64,
        items: Vec<String>,
    }

    /// Builds the fixture value shared by the tests below.
    fn test_obj() -> TestData {
        TestData {
            name: "test".into(),
            value: 42,
            items: vec!["a".into(), "b".into()],
        }
    }

    #[test]
    fn test_json_roundtrip() {
        let obj = test_obj();
        let bytes = serialize(&obj, &MetadataFormat::Json);
        // JSON has no marker byte
        assert!(!bytes.is_empty());
        assert_ne!(bytes[0], CBOR_MARKER);
        let decoded: TestData = deserialize(&bytes).unwrap();
        assert_eq!(decoded, obj);
    }

    #[test]
    fn test_cbor_roundtrip() {
        let obj = test_obj();
        let bytes = serialize(&obj, &MetadataFormat::Cbor);
        assert_eq!(bytes[0], CBOR_MARKER);
        let decoded: TestData = deserialize(&bytes).unwrap();
        assert_eq!(decoded, obj);
    }

    #[test]
    fn test_cbor_backward_compat() {
        // CBOR-marker data should be readable. The payload is built without
        // going through `serialize` so the framing (marker byte + CBOR body)
        // is checked independently of the encoder under test.
        let obj = test_obj();
        let mut cbor_bytes = vec![CBOR_MARKER];
        ciborium::into_writer(&obj, &mut cbor_bytes).unwrap();
        let decoded: TestData = deserialize(&cbor_bytes).unwrap();
        assert_eq!(decoded, obj);
    }

    #[test]
    fn test_json_backward_compat_no_marker() {
        // Legacy JSON (no marker) must still be readable
        let obj = test_obj();
        let json_bytes = serde_json::to_vec(&obj).unwrap();
        let decoded: TestData = deserialize(&json_bytes).unwrap();
        assert_eq!(decoded, obj);
    }

    #[test]
    fn test_empty_data_fails() {
        let result: Result<TestData> = deserialize(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_serialize_for_signing_is_json() {
        let obj = test_obj();
        let bytes = serialize_for_signing(&obj);
        // Should be parseable as JSON
        let decoded: TestData = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(decoded, obj);
        // Keys must come out sorted, not in struct declaration order.
        assert_eq!(
            std::str::from_utf8(&bytes).unwrap(),
            r#"{"items":["a","b"],"name":"test","value":42}"#
        );
    }

    #[test]
    fn test_format_from_config() {
        let mut config = std::collections::BTreeMap::new();
        assert_eq!(MetadataFormat::from_config(&config), MetadataFormat::Json);
        config.insert("serialization_format".into(), "json".into());
        assert_eq!(MetadataFormat::from_config(&config), MetadataFormat::Json);
        config.insert("serialization_format".into(), "cbor".into());
        assert_eq!(MetadataFormat::from_config(&config), MetadataFormat::Cbor);
        config.insert("serialization_format".into(), "invalid".into());
        assert_eq!(MetadataFormat::from_config(&config), MetadataFormat::Json);
    }

    #[test]
    fn test_cbor_compactness() {
        let obj = test_obj();
        let json_bytes = serialize(&obj, &MetadataFormat::Json);
        let cbor_bytes = serialize(&obj, &MetadataFormat::Cbor);
        // CBOR should be smaller than JSON for this struct
        assert!(cbor_bytes.len() < json_bytes.len());
    }

    #[test]
    fn test_cbor_marker_byte() {
        let obj = test_obj();
        let bytes = serialize(&obj, &MetadataFormat::Cbor);
        // First byte must be 0x02
        assert_eq!(bytes[0], CBOR_MARKER);
        // Must have content after marker
        assert!(bytes.len() > 1);
    }

    #[test]
    fn test_cbor_btreemap_roundtrip() {
        let mut map = std::collections::BTreeMap::new();
        map.insert("key1".to_string(), "value1".to_string());
        map.insert("key2".to_string(), "value2".to_string());
        let bytes = serialize(&map, &MetadataFormat::Cbor);
        assert_eq!(bytes[0], CBOR_MARKER);
        let decoded: std::collections::BTreeMap<String, String> = deserialize(&bytes).unwrap();
        assert_eq!(decoded["key1"], "value1");
        assert_eq!(decoded["key2"], "value2");
        assert_eq!(decoded.len(), 2);
    }
}