Skip to main content

ruvector_dag/dag/
serialization.rs

1//! DAG serialization and deserialization
2
3use serde::{Deserialize, Serialize};
4
5use super::operator_node::OperatorNode;
6use super::query_dag::{DagError, QueryDag};
7
8/// Serializable representation of a DAG
9#[derive(Debug, Clone, Serialize, Deserialize)]
10struct SerializableDag {
11    nodes: Vec<OperatorNode>,
12    edges: Vec<(usize, usize)>, // (parent, child) pairs
13    root: Option<usize>,
14}
15
16/// Trait for DAG serialization
17pub trait DagSerializer {
18    /// Serialize to JSON string
19    fn to_json(&self) -> Result<String, serde_json::Error>;
20
21    /// Serialize to bytes (using bincode-like format via JSON for now)
22    fn to_bytes(&self) -> Vec<u8>;
23}
24
25/// Trait for DAG deserialization
26pub trait DagDeserializer {
27    /// Deserialize from JSON string
28    fn from_json(json: &str) -> Result<Self, serde_json::Error>
29    where
30        Self: Sized;
31
32    /// Deserialize from bytes
33    fn from_bytes(bytes: &[u8]) -> Result<Self, DagError>
34    where
35        Self: Sized;
36}
37
38impl DagSerializer for QueryDag {
39    fn to_json(&self) -> Result<String, serde_json::Error> {
40        let nodes: Vec<OperatorNode> = self.nodes.values().cloned().collect();
41
42        let mut edges = Vec::new();
43        for (&parent, children) in &self.edges {
44            for &child in children {
45                edges.push((parent, child));
46            }
47        }
48
49        let serializable = SerializableDag {
50            nodes,
51            edges,
52            root: self.root,
53        };
54
55        serde_json::to_string_pretty(&serializable)
56    }
57
58    fn to_bytes(&self) -> Vec<u8> {
59        // For now, use JSON as bytes. In production, use bincode or similar
60        self.to_json().unwrap_or_default().into_bytes()
61    }
62}
63
64impl DagDeserializer for QueryDag {
65    fn from_json(json: &str) -> Result<Self, serde_json::Error> {
66        let serializable: SerializableDag = serde_json::from_str(json)?;
67
68        let mut dag = QueryDag::new();
69
70        // Create a mapping from old IDs to new IDs
71        let mut id_map = std::collections::HashMap::new();
72
73        // Add all nodes
74        for node in serializable.nodes {
75            let old_id = node.id;
76            let new_id = dag.add_node(node);
77            id_map.insert(old_id, new_id);
78        }
79
80        // Add all edges using mapped IDs
81        for (parent, child) in serializable.edges {
82            if let (Some(&new_parent), Some(&new_child)) = (id_map.get(&parent), id_map.get(&child))
83            {
84                // Ignore errors from edge addition during deserialization
85                let _ = dag.add_edge(new_parent, new_child);
86            }
87        }
88
89        // Map root if it exists
90        if let Some(old_root) = serializable.root {
91            dag.root = id_map.get(&old_root).copied();
92        }
93
94        Ok(dag)
95    }
96
97    fn from_bytes(bytes: &[u8]) -> Result<Self, DagError> {
98        let json = String::from_utf8(bytes.to_vec())
99            .map_err(|e| DagError::InvalidOperation(format!("Invalid UTF-8: {}", e)))?;
100
101        Self::from_json(&json)
102            .map_err(|e| DagError::InvalidOperation(format!("Deserialization failed: {}", e)))
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OperatorNode;

    /// A linear three-node plan survives a JSON round trip.
    #[test]
    fn test_json_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));
        let id3 = dag.add_node(OperatorNode::sort(0, vec!["name".to_string()]));

        dag.add_edge(id1, id2).unwrap();
        dag.add_edge(id2, id3).unwrap();

        // Serialize
        let json = dag.to_json().unwrap();
        assert!(!json.is_empty());

        // Deserialize
        let deserialized = QueryDag::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), 3);
        assert_eq!(deserialized.edge_count(), 2);
    }

    /// A two-node plan survives a byte-buffer round trip.
    #[test]
    fn test_bytes_serialization() {
        let mut dag = QueryDag::new();
        let id1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let id2 = dag.add_node(OperatorNode::filter(0, "age > 18"));

        dag.add_edge(id1, id2).unwrap();

        // Serialize to bytes
        let bytes = dag.to_bytes();
        assert!(!bytes.is_empty());

        // Deserialize from bytes
        let deserialized = QueryDag::from_bytes(&bytes).unwrap();
        assert_eq!(deserialized.node_count(), 2);
        assert_eq!(deserialized.edge_count(), 1);
    }

    /// A plan with a join (two parents feeding one child) keeps its shape.
    #[test]
    fn test_complex_dag_roundtrip() {
        let mut dag = QueryDag::new();

        // Create a more complex DAG
        let scan1 = dag.add_node(OperatorNode::seq_scan(0, "users"));
        let scan2 = dag.add_node(OperatorNode::seq_scan(0, "orders"));
        let join = dag.add_node(OperatorNode::hash_join(0, "user_id"));
        let filter = dag.add_node(OperatorNode::filter(0, "total > 100"));
        let sort = dag.add_node(OperatorNode::sort(0, vec!["date".to_string()]));
        let limit = dag.add_node(OperatorNode::limit(0, 10));

        dag.add_edge(scan1, join).unwrap();
        dag.add_edge(scan2, join).unwrap();
        dag.add_edge(join, filter).unwrap();
        dag.add_edge(filter, sort).unwrap();
        dag.add_edge(sort, limit).unwrap();

        // Round trip
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();

        assert_eq!(restored.node_count(), dag.node_count());
        assert_eq!(restored.edge_count(), dag.edge_count());
    }

    /// An empty DAG round-trips to an empty DAG.
    #[test]
    fn test_empty_dag_serialization() {
        let dag = QueryDag::new();
        let json = dag.to_json().unwrap();
        let restored = QueryDag::from_json(&json).unwrap();

        assert_eq!(restored.node_count(), 0);
        assert_eq!(restored.edge_count(), 0);
    }

    /// Bytes that are not valid UTF-8 are rejected instead of panicking.
    #[test]
    fn test_from_bytes_rejects_invalid_utf8() {
        assert!(QueryDag::from_bytes(&[0xff, 0xfe, 0xfd]).is_err());
    }

    /// Malformed JSON is reported as an error through both entry points.
    #[test]
    fn test_malformed_json_is_rejected() {
        assert!(QueryDag::from_json("not json").is_err());
        assert!(QueryDag::from_bytes(b"{ \"nodes\": ").is_err());
    }
}