Skip to main content

tensorlogic_ir/
serialization.rs

1//! Enhanced serialization support for TensorLogic IR.
2//!
3//! This module provides improved serialization formats including:
4//! - Human-readable JSON with version tagging
//! - Compact binary format using `oxicode`
6//! - Versioned format support for backward compatibility
7
8use crate::{EinsumGraph, TLExpr};
9use serde::{Deserialize, Serialize};
10
11/// Current serialization format version
12pub const FORMAT_VERSION: &str = "1.0.0";
13
14/// Versioned wrapper for TLExpr serialization
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VersionedExpr {
17    /// Format version (semver)
18    pub version: String,
19    /// Creation timestamp (ISO 8601)
20    pub created_at: Option<String>,
21    /// Optional metadata
22    pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
23    /// The expression
24    pub expr: TLExpr,
25}
26
27impl VersionedExpr {
28    /// Create a new versioned expression
29    pub fn new(expr: TLExpr) -> Self {
30        VersionedExpr {
31            version: FORMAT_VERSION.to_string(),
32            created_at: Some(chrono::Utc::now().to_rfc3339()),
33            metadata: None,
34            expr,
35        }
36    }
37
38    /// Create with custom metadata
39    pub fn with_metadata(
40        expr: TLExpr,
41        metadata: serde_json::Map<String, serde_json::Value>,
42    ) -> Self {
43        VersionedExpr {
44            version: FORMAT_VERSION.to_string(),
45            created_at: Some(chrono::Utc::now().to_rfc3339()),
46            metadata: Some(metadata),
47            expr,
48        }
49    }
50
51    /// Serialize to pretty JSON
52    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
53        serde_json::to_string_pretty(self)
54    }
55
56    /// Serialize to compact JSON
57    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
58        serde_json::to_string(self)
59    }
60
61    /// Deserialize from JSON
62    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
63        serde_json::from_str(json)
64    }
65
66    /// Serialize to binary format
67    pub fn to_binary(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
68        oxicode::serde::encode_to_vec(self, oxicode::config::standard())
69            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
70    }
71
72    /// Deserialize from binary format
73    pub fn from_binary(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
74        let (result, _): (Self, usize) =
75            oxicode::serde::decode_from_slice(bytes, oxicode::config::standard())
76                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
77        Ok(result)
78    }
79}
80
81/// Versioned wrapper for EinsumGraph serialization
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct VersionedGraph {
84    /// Format version (semver)
85    pub version: String,
86    /// Creation timestamp (ISO 8601)
87    pub created_at: Option<String>,
88    /// Optional metadata
89    pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
90    /// The graph
91    pub graph: EinsumGraph,
92}
93
94impl VersionedGraph {
95    /// Create a new versioned graph
96    pub fn new(graph: EinsumGraph) -> Self {
97        VersionedGraph {
98            version: FORMAT_VERSION.to_string(),
99            created_at: Some(chrono::Utc::now().to_rfc3339()),
100            metadata: None,
101            graph,
102        }
103    }
104
105    /// Create with custom metadata
106    pub fn with_metadata(
107        graph: EinsumGraph,
108        metadata: serde_json::Map<String, serde_json::Value>,
109    ) -> Self {
110        VersionedGraph {
111            version: FORMAT_VERSION.to_string(),
112            created_at: Some(chrono::Utc::now().to_rfc3339()),
113            metadata: Some(metadata),
114            graph,
115        }
116    }
117
118    /// Serialize to pretty JSON
119    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
120        serde_json::to_string_pretty(self)
121    }
122
123    /// Serialize to compact JSON
124    pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
125        serde_json::to_string(self)
126    }
127
128    /// Deserialize from JSON
129    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
130        serde_json::from_str(json)
131    }
132
133    /// Serialize to binary format
134    pub fn to_binary(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135        oxicode::serde::encode_to_vec(self, oxicode::config::standard())
136            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
137    }
138
139    /// Deserialize from binary format
140    pub fn from_binary(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
141        let (result, _): (Self, usize) =
142            oxicode::serde::decode_from_slice(bytes, oxicode::config::standard())
143                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
144        Ok(result)
145    }
146
147    /// Check if version is compatible with current version
148    pub fn is_compatible(&self) -> bool {
149        // Simple check: major version must match
150        let self_major = self
151            .version
152            .split('.')
153            .next()
154            .and_then(|s| s.parse::<u32>().ok());
155        let current_major = FORMAT_VERSION
156            .split('.')
157            .next()
158            .and_then(|s| s.parse::<u32>().ok());
159
160        self_major == current_major
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TLExpr, Term};

    /// An `EinsumGraph` with no tensors, nodes, inputs, outputs or metadata.
    fn empty_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: Vec::new(),
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            tensor_metadata: std::collections::HashMap::new(),
        }
    }

    /// The single-atom expression `test(x)`.
    fn atom() -> TLExpr {
        TLExpr::pred("test", vec![Term::var("x")])
    }

    /// The conjunction `p(x) AND q(y)`.
    fn conjunction() -> TLExpr {
        TLExpr::and(
            TLExpr::pred("p", vec![Term::var("x")]),
            TLExpr::pred("q", vec![Term::var("y")]),
        )
    }

    #[test]
    fn test_versioned_expr_creation() {
        let original = atom();
        let wrapped = VersionedExpr::new(original.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert!(wrapped.metadata.is_none());
        assert_eq!(wrapped.expr, original);
    }

    #[test]
    fn test_versioned_expr_with_metadata() {
        let original = atom();
        let metadata: serde_json::Map<String, serde_json::Value> =
            std::iter::once(("author".to_string(), serde_json::json!("test"))).collect();

        let wrapped = VersionedExpr::with_metadata(original.clone(), metadata.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert_eq!(wrapped.metadata.as_ref(), Some(&metadata));
        assert_eq!(wrapped.expr, original);
    }

    #[test]
    fn test_versioned_expr_json_roundtrip() {
        let wrapped = VersionedExpr::new(atom());

        let text = wrapped.to_json_pretty().unwrap();
        let restored = VersionedExpr::from_json(&text).unwrap();

        assert_eq!(restored.version, wrapped.version);
        assert_eq!(restored.expr, wrapped.expr);
    }

    #[test]
    fn test_versioned_expr_binary_roundtrip() {
        let wrapped = VersionedExpr::new(atom());

        let bytes = wrapped.to_binary().unwrap();
        let restored = VersionedExpr::from_binary(&bytes).unwrap();

        assert_eq!(restored.version, wrapped.version);
        assert_eq!(restored.expr, wrapped.expr);
    }

    #[test]
    fn test_versioned_graph_creation() {
        let original = empty_graph();
        let wrapped = VersionedGraph::new(original.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert!(wrapped.metadata.is_none());
        assert_eq!(wrapped.graph, original);
    }

    #[test]
    fn test_versioned_graph_json_roundtrip() {
        let wrapped = VersionedGraph::new(empty_graph());

        let text = wrapped.to_json_pretty().unwrap();
        let restored = VersionedGraph::from_json(&text).unwrap();

        assert_eq!(restored.version, wrapped.version);
        assert_eq!(restored.graph, wrapped.graph);
    }

    #[test]
    fn test_versioned_graph_binary_roundtrip() {
        let wrapped = VersionedGraph::new(empty_graph());

        let bytes = wrapped.to_binary().unwrap();
        let restored = VersionedGraph::from_binary(&bytes).unwrap();

        assert_eq!(restored.version, wrapped.version);
        assert_eq!(restored.graph, wrapped.graph);
    }

    #[test]
    fn test_version_compatibility() {
        let current = VersionedGraph::new(empty_graph());
        assert!(current.is_compatible());

        // A different major version must be rejected.
        let future = VersionedGraph {
            version: "2.0.0".to_string(),
            ..current
        };
        assert!(!future.is_compatible());
    }

    #[test]
    fn test_json_is_human_readable() {
        let json = VersionedExpr::new(conjunction()).to_json_pretty().unwrap();

        // Field names and the variant tag should appear verbatim.
        for needle in ["version", "created_at", "expr", "And"] {
            assert!(json.contains(needle));
        }
    }

    #[test]
    fn test_binary_smaller_than_json() {
        let wrapped = VersionedExpr::new(conjunction());

        let json = wrapped.to_json_compact().unwrap();
        let binary = wrapped.to_binary().unwrap();

        // Binary is usually smaller than JSON, but that is not guaranteed for
        // tiny payloads, so only require it to stay within twice the JSON size.
        assert!(binary.len() <= 2 * json.len());
    }
}