ruvector_dag/dag/
serialization.rs1use serde::{Deserialize, Serialize};
4
5use super::operator_node::OperatorNode;
6use super::query_dag::{DagError, QueryDag};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10struct SerializableDag {
11 nodes: Vec<OperatorNode>,
12 edges: Vec<(usize, usize)>, root: Option<usize>,
14}
15
16pub trait DagSerializer {
18 fn to_json(&self) -> Result<String, serde_json::Error>;
20
21 fn to_bytes(&self) -> Vec<u8>;
23}
24
25pub trait DagDeserializer {
27 fn from_json(json: &str) -> Result<Self, serde_json::Error>
29 where
30 Self: Sized;
31
32 fn from_bytes(bytes: &[u8]) -> Result<Self, DagError>
34 where
35 Self: Sized;
36}
37
38impl DagSerializer for QueryDag {
39 fn to_json(&self) -> Result<String, serde_json::Error> {
40 let nodes: Vec<OperatorNode> = self.nodes.values().cloned().collect();
41
42 let mut edges = Vec::new();
43 for (&parent, children) in &self.edges {
44 for &child in children {
45 edges.push((parent, child));
46 }
47 }
48
49 let serializable = SerializableDag {
50 nodes,
51 edges,
52 root: self.root,
53 };
54
55 serde_json::to_string_pretty(&serializable)
56 }
57
58 fn to_bytes(&self) -> Vec<u8> {
59 self.to_json().unwrap_or_default().into_bytes()
61 }
62}
63
64impl DagDeserializer for QueryDag {
65 fn from_json(json: &str) -> Result<Self, serde_json::Error> {
66 let serializable: SerializableDag = serde_json::from_str(json)?;
67
68 let mut dag = QueryDag::new();
69
70 let mut id_map = std::collections::HashMap::new();
72
73 for node in serializable.nodes {
75 let old_id = node.id;
76 let new_id = dag.add_node(node);
77 id_map.insert(old_id, new_id);
78 }
79
80 for (parent, child) in serializable.edges {
82 if let (Some(&new_parent), Some(&new_child)) = (id_map.get(&parent), id_map.get(&child))
83 {
84 let _ = dag.add_edge(new_parent, new_child);
86 }
87 }
88
89 if let Some(old_root) = serializable.root {
91 dag.root = id_map.get(&old_root).copied();
92 }
93
94 Ok(dag)
95 }
96
97 fn from_bytes(bytes: &[u8]) -> Result<Self, DagError> {
98 let json = String::from_utf8(bytes.to_vec())
99 .map_err(|e| DagError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
100
101 Self::from_json(&json)
102 .map_err(|e| DagError::InvalidOperation(format!("Deserialization failed: {}", e)))
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    #[test]
    fn test_json_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));

        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();

        let json = dag.to_json().unwrap();
        assert!(!json.is_empty());

        let deserialized = QueryDag::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), 3);
        assert_eq!(deserialized.edge_count(), 2);
    }

    #[test]
    fn test_bytes_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));

        dag.add_edge(id1, id2).unwrap();

        let bytes = dag.to_bytes();
        assert!(!bytes.is_empty());

        let deserialized = QueryDag::from_bytes(&bytes).unwrap();
        assert_eq!(deserialized.node_count(), 2);
        assert_eq!(deserialized.edge_count(), 1);
    }

    #[test]
    fn test_complex_dag_roundtrip() {
        let mut dag = QueryDag::new();

        let scan1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let scan2 = dag.add_node(OperatorNode::seq_scan(0, "orders"));
        let join = dag.add_node(OperatorNode::hash_join(0, "user_id"));
        let filter = dag.add_node(OperatorNode::filter(0, "total > 100"));
        let sort = dag.add_node(OperatorNode::sort(0, vec!["date".to_string()]));
        let limit = dag.add_node(OperatorNode::limit(0, 10));

        dag.add_edge(scan1, join).unwrap();
        dag.add_edge(scan2, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(filter, sort).unwrap();
        dag.add_edge(sort, limit).unwrap();

        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();

        assert_eq!(restored.node_count(), dag.node_count());
        assert_eq!(restored.edge_count(), dag.edge_count());
    }

    #[test]
    fn test_empty_dag_serialization() {
        let dag = QueryDag::new();
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();

        assert_eq!(restored.node_count(), 0);
        assert_eq!(restored.edge_count(), 0);
    }

    /// Non-UTF-8 input must surface as `DagError::InvalidOperation`, not a panic.
    #[test]
    fn test_from_bytes_rejects_invalid_utf8() {
        let result = QueryDag::from_bytes(&[0xff, 0xfe, 0xfd]);
        assert!(matches!(result, Err(DagError::InvalidOperation(_))));
    }

    /// Malformed or truncated input must be rejected rather than yield an empty DAG.
    #[test]
    fn test_from_json_rejects_malformed_input() {
        assert!(QueryDag::from_json("").is_err());
        assert!(QueryDag::from_json("{ not json").is_err());
        assert!(QueryDag::from_json(r#"{"nodes": 1, "edges": [], "root": null}"#).is_err());
        assert!(QueryDag::from_bytes(&[]).is_err());
    }
}