1use crate::{EinsumGraph, TLExpr};
9use serde::{Deserialize, Serialize};
10
/// Version string stamped into the `version` field of every envelope built by
/// the constructors in this file.
///
/// `VersionedGraph::is_compatible` compares only the leading dot-separated
/// component of this value against the one stored in an envelope.
pub const FORMAT_VERSION: &str = "1.0.0";
13
/// Serializable envelope that pairs a [`TLExpr`] with format-version
/// information, an optional creation timestamp and optional metadata.
///
/// Field order is part of the binary encoding produced through serde, so the
/// fields must not be reordered casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedExpr {
    /// Format version string; the constructors set it to [`FORMAT_VERSION`].
    pub version: String,
    /// Creation timestamp; the constructors fill it with the current UTC time
    /// formatted as RFC 3339. `None` is accepted when deserializing.
    pub created_at: Option<String>,
    /// Optional free-form JSON object attached by the caller.
    pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
    /// The wrapped expression.
    pub expr: TLExpr,
}
26
27impl VersionedExpr {
28 pub fn new(expr: TLExpr) -> Self {
30 VersionedExpr {
31 version: FORMAT_VERSION.to_string(),
32 created_at: Some(chrono::Utc::now().to_rfc3339()),
33 metadata: None,
34 expr,
35 }
36 }
37
38 pub fn with_metadata(
40 expr: TLExpr,
41 metadata: serde_json::Map<String, serde_json::Value>,
42 ) -> Self {
43 VersionedExpr {
44 version: FORMAT_VERSION.to_string(),
45 created_at: Some(chrono::Utc::now().to_rfc3339()),
46 metadata: Some(metadata),
47 expr,
48 }
49 }
50
51 pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
53 serde_json::to_string_pretty(self)
54 }
55
56 pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
58 serde_json::to_string(self)
59 }
60
61 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
63 serde_json::from_str(json)
64 }
65
66 pub fn to_binary(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
68 oxicode::serde::encode_to_vec(self, oxicode::config::standard())
69 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
70 }
71
72 pub fn from_binary(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
74 let (result, _): (Self, usize) =
75 oxicode::serde::decode_from_slice(bytes, oxicode::config::standard())
76 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
77 Ok(result)
78 }
79}
80
/// Serializable envelope that pairs an [`EinsumGraph`] with format-version
/// information, an optional creation timestamp and optional metadata.
///
/// Field order is part of the binary encoding produced through serde, so the
/// fields must not be reordered casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedGraph {
    /// Format version string; the constructors set it to [`FORMAT_VERSION`].
    pub version: String,
    /// Creation timestamp; the constructors fill it with the current UTC time
    /// formatted as RFC 3339. `None` is accepted when deserializing.
    pub created_at: Option<String>,
    /// Optional free-form JSON object attached by the caller.
    pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
    /// The wrapped graph.
    pub graph: EinsumGraph,
}
93
94impl VersionedGraph {
95 pub fn new(graph: EinsumGraph) -> Self {
97 VersionedGraph {
98 version: FORMAT_VERSION.to_string(),
99 created_at: Some(chrono::Utc::now().to_rfc3339()),
100 metadata: None,
101 graph,
102 }
103 }
104
105 pub fn with_metadata(
107 graph: EinsumGraph,
108 metadata: serde_json::Map<String, serde_json::Value>,
109 ) -> Self {
110 VersionedGraph {
111 version: FORMAT_VERSION.to_string(),
112 created_at: Some(chrono::Utc::now().to_rfc3339()),
113 metadata: Some(metadata),
114 graph,
115 }
116 }
117
118 pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
120 serde_json::to_string_pretty(self)
121 }
122
123 pub fn to_json_compact(&self) -> Result<String, serde_json::Error> {
125 serde_json::to_string(self)
126 }
127
128 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
130 serde_json::from_str(json)
131 }
132
133 pub fn to_binary(&self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135 oxicode::serde::encode_to_vec(self, oxicode::config::standard())
136 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
137 }
138
139 pub fn from_binary(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
141 let (result, _): (Self, usize) =
142 oxicode::serde::decode_from_slice(bytes, oxicode::config::standard())
143 .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
144 Ok(result)
145 }
146
147 pub fn is_compatible(&self) -> bool {
149 let self_major = self
151 .version
152 .split('.')
153 .next()
154 .and_then(|s| s.parse::<u32>().ok());
155 let current_major = FORMAT_VERSION
156 .split('.')
157 .next()
158 .and_then(|s| s.parse::<u32>().ok());
159
160 self_major == current_major
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TLExpr, Term};

    /// Single-predicate expression `test(x)` shared by the expression tests.
    fn unary_pred() -> TLExpr {
        TLExpr::pred("test", vec![Term::var("x")])
    }

    /// `TLExpr::and` of `p(x)` and `q(y)`, used by the format-oriented tests.
    fn two_pred_and() -> TLExpr {
        let lhs = TLExpr::pred("p", vec![Term::var("x")]);
        let rhs = TLExpr::pred("q", vec![Term::var("y")]);
        TLExpr::and(lhs, rhs)
    }

    /// Graph with every collection left empty.
    fn empty_graph() -> EinsumGraph {
        EinsumGraph {
            tensors: Vec::new(),
            nodes: Vec::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            tensor_metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_versioned_expr_creation() {
        let expr = unary_pred();
        let wrapped = VersionedExpr::new(expr.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert!(wrapped.metadata.is_none());
        assert_eq!(wrapped.expr, expr);
    }

    #[test]
    fn test_versioned_expr_with_metadata() {
        let expr = unary_pred();
        let mut extra = serde_json::Map::new();
        extra.insert("author".to_string(), serde_json::json!("test"));

        let wrapped = VersionedExpr::with_metadata(expr.clone(), extra.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert_eq!(wrapped.metadata, Some(extra));
        assert_eq!(wrapped.expr, expr);
    }

    #[test]
    fn test_versioned_expr_json_roundtrip() {
        let original = VersionedExpr::new(unary_pred());

        let text = original.to_json_pretty().unwrap();
        let restored = VersionedExpr::from_json(&text).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.expr, original.expr);
    }

    #[test]
    fn test_versioned_expr_binary_roundtrip() {
        let original = VersionedExpr::new(unary_pred());

        let bytes = original.to_binary().unwrap();
        let restored = VersionedExpr::from_binary(&bytes).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.expr, original.expr);
    }

    #[test]
    fn test_versioned_graph_creation() {
        let graph = empty_graph();
        let wrapped = VersionedGraph::new(graph.clone());

        assert_eq!(wrapped.version, FORMAT_VERSION);
        assert!(wrapped.created_at.is_some());
        assert!(wrapped.metadata.is_none());
        assert_eq!(wrapped.graph, graph);
    }

    #[test]
    fn test_versioned_graph_json_roundtrip() {
        let original = VersionedGraph::new(empty_graph());

        let text = original.to_json_pretty().unwrap();
        let restored = VersionedGraph::from_json(&text).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.graph, original.graph);
    }

    #[test]
    fn test_versioned_graph_binary_roundtrip() {
        let original = VersionedGraph::new(empty_graph());

        let bytes = original.to_binary().unwrap();
        let restored = VersionedGraph::from_binary(&bytes).unwrap();

        assert_eq!(restored.version, original.version);
        assert_eq!(restored.graph, original.graph);
    }

    #[test]
    fn test_version_compatibility() {
        let current = VersionedGraph::new(empty_graph());
        assert!(current.is_compatible());

        // Same envelope, but claiming a different major version.
        let bumped = VersionedGraph {
            version: "2.0.0".to_string(),
            ..current
        };
        assert!(!bumped.is_compatible());
    }

    #[test]
    fn test_json_is_human_readable() {
        let text = VersionedExpr::new(two_pred_and())
            .to_json_pretty()
            .unwrap();

        // Field names and the `And` variant name must appear verbatim.
        for needle in ["version", "created_at", "expr", "And"] {
            assert!(text.contains(needle));
        }
    }

    #[test]
    fn test_binary_smaller_than_json() {
        let wrapped = VersionedExpr::new(two_pred_and());

        let json_len = wrapped.to_json_compact().unwrap().len();
        let binary_len = wrapped.to_binary().unwrap().len();

        // Deliberately loose bound: the binary form only has to stay within
        // twice the size of the compact JSON form.
        assert!(binary_len <= json_len * 2);
    }
}