Skip to main content

torsh_fx/fx/
serialization.rs

1//! Serialization support for FX graphs
2
3use crate::fx::types::{Edge, Node};
4use crate::graph_analysis::GraphMetrics;
5use crate::FxGraph;
6use petgraph::graph::Graph;
7use petgraph::visit::EdgeRef;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use torsh_core::{Result, TorshError};
11
/// Convenience alias for the `Result` type imported from `torsh_core`,
/// used as the return type of the fallible functions in this module.
pub type TorshResult<T> = Result<T>;
14
/// Serializable representation of FxGraph
///
/// All node references are stored as plain `usize` indices so the structure
/// can be encoded with serde without depending on graph index types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableGraph {
    /// `(index, node)` pairs; edges, inputs and outputs refer to these indices.
    nodes: Vec<(usize, Node)>,
    /// `(source index, target index, edge)` triples.
    edges: Vec<(usize, usize, Edge)>,
    /// Indices of the nodes registered as graph inputs.
    inputs: Vec<usize>,
    /// Indices of the nodes registered as graph outputs.
    outputs: Vec<usize>,
}
23
24impl SerializableGraph {
25    /// Convert FxGraph to serializable format
26    pub fn from_graph(graph: &FxGraph) -> Self {
27        let mut nodes = Vec::new();
28        let mut edges = Vec::new();
29
30        // Collect nodes
31        for (idx, node) in graph.nodes() {
32            nodes.push((idx.index(), node.clone()));
33        }
34
35        // Collect edges
36        for edge_ref in graph.graph.edge_references() {
37            edges.push((
38                edge_ref.source().index(),
39                edge_ref.target().index(),
40                edge_ref.weight().clone(),
41            ));
42        }
43
44        Self {
45            nodes,
46            edges,
47            inputs: graph.inputs.iter().map(|idx| idx.index()).collect(),
48            outputs: graph.outputs.iter().map(|idx| idx.index()).collect(),
49        }
50    }
51
52    /// Convert serializable format to FxGraph
53    pub fn to_graph(self) -> FxGraph {
54        let mut graph = Graph::new();
55        let mut node_mapping = std::collections::HashMap::new();
56
57        // Add nodes
58        for (original_idx, node) in self.nodes {
59            let new_idx = graph.add_node(node);
60            node_mapping.insert(original_idx, new_idx);
61        }
62
63        // Add edges
64        for (src_idx, target_idx, edge) in self.edges {
65            if let (Some(&src), Some(&target)) =
66                (node_mapping.get(&src_idx), node_mapping.get(&target_idx))
67            {
68                graph.add_edge(src, target, edge);
69            }
70        }
71
72        // Map input and output indices
73        let inputs = self
74            .inputs
75            .into_iter()
76            .filter_map(|idx| node_mapping.get(&idx).copied())
77            .collect();
78        let outputs = self
79            .outputs
80            .into_iter()
81            .filter_map(|idx| node_mapping.get(&idx).copied())
82            .collect();
83
84        FxGraph {
85            graph,
86            inputs,
87            outputs,
88        }
89    }
90
91    /// Get the number of nodes in the graph
92    pub fn node_count(&self) -> usize {
93        self.nodes.len()
94    }
95
96    /// Get the number of edges in the graph
97    pub fn edge_count(&self) -> usize {
98        self.edges.len()
99    }
100
101    /// Basic validation of the graph structure
102    pub fn validate(&self) -> TorshResult<()> {
103        // Check that all edge endpoints refer to valid nodes
104        let node_indices: std::collections::HashSet<usize> =
105            self.nodes.iter().map(|(idx, _)| *idx).collect();
106
107        for (src, target, _) in &self.edges {
108            if !node_indices.contains(src) {
109                return Err(TorshError::InvalidArgument(format!(
110                    "Edge source {src} not found in nodes"
111                )));
112            }
113            if !node_indices.contains(target) {
114                return Err(TorshError::InvalidArgument(format!(
115                    "Edge target {target} not found in nodes"
116                )));
117            }
118        }
119
120        // Check that input/output indices are valid
121        for &input_idx in &self.inputs {
122            if !node_indices.contains(&input_idx) {
123                return Err(TorshError::InvalidArgument(format!(
124                    "Input index {input_idx} not found in nodes"
125                )));
126            }
127        }
128
129        for &output_idx in &self.outputs {
130            if !node_indices.contains(&output_idx) {
131                return Err(TorshError::InvalidArgument(format!(
132                    "Output index {output_idx} not found in nodes"
133                )));
134            }
135        }
136
137        Ok(())
138    }
139
140    /// Count operations by type
141    pub fn operation_counts(&self) -> HashMap<String, usize> {
142        let mut counts = HashMap::new();
143
144        for (_, node) in &self.nodes {
145            let op_name = match node {
146                Node::Input(_) => "input".to_string(),
147                Node::Call(op, _) => op.clone(),
148                Node::Output => "output".to_string(),
149                Node::Conditional { .. } => "conditional".to_string(),
150                Node::Loop { .. } => "loop".to_string(),
151                Node::Merge { .. } => "merge".to_string(),
152                Node::GetAttr { .. } => "getattr".to_string(),
153            };
154
155            *counts.entry(op_name).or_insert(0) += 1;
156        }
157
158        counts
159    }
160
161    /// Check if the graph is a linear chain
162    pub fn is_linear_chain(&self) -> bool {
163        if self.nodes.len() <= 1 {
164            return true;
165        }
166
167        // Build adjacency list
168        let mut outgoing: HashMap<usize, Vec<usize>> = HashMap::new();
169        let mut incoming: HashMap<usize, Vec<usize>> = HashMap::new();
170
171        for (src, target, _) in &self.edges {
172            outgoing.entry(*src).or_default().push(*target);
173            incoming.entry(*target).or_default().push(*src);
174        }
175
176        // Check that each node has at most 1 outgoing and 1 incoming edge
177        for (idx, _) in &self.nodes {
178            let out_count = outgoing.get(idx).map_or(0, |v| v.len());
179            let in_count = incoming.get(idx).map_or(0, |v| v.len());
180
181            if out_count > 1 || in_count > 1 {
182                return false;
183            }
184        }
185
186        true
187    }
188
189    /// Check if the graph has cycles
190    pub fn has_cycles(&self) -> bool {
191        let mut visited = std::collections::HashSet::new();
192        let mut rec_stack = std::collections::HashSet::new();
193
194        // Build adjacency list
195        let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
196        for (src, target, _) in &self.edges {
197            adj.entry(*src).or_default().push(*target);
198        }
199
200        fn dfs_has_cycle(
201            node: usize,
202            adj: &HashMap<usize, Vec<usize>>,
203            visited: &mut std::collections::HashSet<usize>,
204            rec_stack: &mut std::collections::HashSet<usize>,
205        ) -> bool {
206            visited.insert(node);
207            rec_stack.insert(node);
208
209            if let Some(neighbors) = adj.get(&node) {
210                for &neighbor in neighbors {
211                    if !visited.contains(&neighbor) {
212                        if dfs_has_cycle(neighbor, adj, visited, rec_stack) {
213                            return true;
214                        }
215                    } else if rec_stack.contains(&neighbor) {
216                        return true;
217                    }
218                }
219            }
220
221            rec_stack.remove(&node);
222            false
223        }
224
225        for (idx, _) in &self.nodes {
226            if !visited.contains(idx) && dfs_has_cycle(*idx, &adj, &mut visited, &mut rec_stack) {
227                return true;
228            }
229        }
230
231        false
232    }
233
234    /// Get the maximum depth of the graph
235    pub fn get_depth(&self) -> usize {
236        if self.nodes.is_empty() {
237            return 0;
238        }
239
240        // Build adjacency list
241        let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
242        for (src, target, _) in &self.edges {
243            adj.entry(*src).or_default().push(*target);
244        }
245
246        fn dfs_depth(
247            node: usize,
248            adj: &HashMap<usize, Vec<usize>>,
249            visited: &mut std::collections::HashSet<usize>,
250        ) -> usize {
251            if visited.contains(&node) {
252                return 0; // Avoid infinite recursion in cycles
253            }
254            visited.insert(node);
255
256            let mut max_depth = 0;
257            if let Some(neighbors) = adj.get(&node) {
258                for &neighbor in neighbors {
259                    let depth = dfs_depth(neighbor, adj, visited);
260                    max_depth = max_depth.max(depth);
261                }
262            }
263
264            visited.remove(&node);
265            max_depth + 1
266        }
267
268        let mut max_depth = 0;
269        for (idx, _) in &self.nodes {
270            let mut visited = std::collections::HashSet::new();
271            let depth = dfs_depth(*idx, &adj, &mut visited);
272            max_depth = max_depth.max(depth);
273        }
274
275        max_depth
276    }
277
278    /// Find orphaned nodes (nodes with no incoming or outgoing edges)
279    pub fn find_orphaned_nodes(&self) -> Vec<usize> {
280        let mut connected_nodes = std::collections::HashSet::new();
281
282        for (src, target, _) in &self.edges {
283            connected_nodes.insert(*src);
284            connected_nodes.insert(*target);
285        }
286
287        self.nodes
288            .iter()
289            .filter_map(|(idx, _)| {
290                if !connected_nodes.contains(idx) {
291                    Some(*idx)
292                } else {
293                    None
294                }
295            })
296            .collect()
297    }
298
299    /// Find dead-end nodes (nodes that don't lead to any output)
300    pub fn find_dead_end_nodes(&self) -> Vec<usize> {
301        if self.outputs.is_empty() {
302            return Vec::new();
303        }
304
305        // Build reverse adjacency list (incoming edges)
306        let mut incoming: HashMap<usize, Vec<usize>> = HashMap::new();
307        for (src, target, _) in &self.edges {
308            incoming.entry(*target).or_default().push(*src);
309        }
310
311        // BFS from all output nodes to find reachable nodes
312        let mut reachable = std::collections::HashSet::new();
313        let mut queue = std::collections::VecDeque::new();
314
315        for &output in &self.outputs {
316            queue.push_back(output);
317            reachable.insert(output);
318        }
319
320        while let Some(node) = queue.pop_front() {
321            if let Some(predecessors) = incoming.get(&node) {
322                for &pred in predecessors {
323                    if !reachable.contains(&pred) {
324                        reachable.insert(pred);
325                        queue.push_back(pred);
326                    }
327                }
328            }
329        }
330
331        // Return nodes that are not reachable from any output
332        self.nodes
333            .iter()
334            .filter_map(|(idx, _)| {
335                if !reachable.contains(idx) {
336                    Some(*idx)
337                } else {
338                    None
339                }
340            })
341            .collect()
342    }
343
344    /// Get all call nodes
345    pub fn call_nodes(&self) -> Vec<usize> {
346        self.nodes
347            .iter()
348            .filter_map(|(idx, node)| match node {
349                Node::Call(_, _) => Some(*idx),
350                _ => None,
351            })
352            .collect()
353    }
354
355    /// Graph metrics for analysis
356    pub fn metrics(&self) -> GraphMetrics {
357        let node_count = self.node_count();
358        let edge_count = self.edge_count();
359        let depth = self.get_depth();
360        let has_cycles = self.has_cycles();
361        let is_linear = self.is_linear_chain();
362
363        // Simple complexity score based on various factors
364        let complexity_score = (node_count as f32 * 0.1)
365            + (edge_count as f32 * 0.15)
366            + (depth as f32 * 0.2)
367            + if has_cycles { 10.0 } else { 0.0 }
368            + if is_linear { -2.0 } else { 5.0 };
369
370        GraphMetrics {
371            node_count,
372            edge_count,
373            input_count: self.inputs.len(),
374            output_count: self.outputs.len(),
375            max_depth: depth,
376            average_fanout: if node_count > 0 {
377                edge_count as f64 / node_count as f64
378            } else {
379                0.0
380            },
381            connectivity_ratio: if node_count > 1 {
382                edge_count as f64 / ((node_count * (node_count - 1)) as f64)
383            } else {
384                0.0
385            },
386            complexity_score: complexity_score as f64,
387            operation_distribution: self
388                .operation_counts()
389                .into_iter()
390                .map(|(k, v)| (k, v as u32))
391                .collect(),
392            critical_path_length: depth,
393        }
394    }
395
396    /// Create a new empty graph
397    pub fn new() -> Self {
398        Self {
399            nodes: Vec::new(),
400            edges: Vec::new(),
401            inputs: Vec::new(),
402            outputs: Vec::new(),
403        }
404    }
405
406    /// Add a node to the graph
407    pub fn add_node(&mut self, node: Node) -> usize {
408        let idx = self.nodes.len();
409        self.nodes.push((idx, node));
410        idx
411    }
412
413    /// Add an input node index
414    pub fn add_input(&mut self, idx: usize) {
415        self.inputs.push(idx);
416    }
417
418    /// Add an output node index
419    pub fn add_output(&mut self, idx: usize) {
420        self.outputs.push(idx);
421    }
422
423    /// Add an edge between nodes
424    pub fn add_edge(&mut self, src: usize, target: usize, edge: Edge) {
425        self.edges.push((src, target, edge));
426    }
427
428    /// Create a sequential chain of operations
429    pub fn sequential_ops(ops: &[&str]) -> Self {
430        let mut graph = Self::new();
431
432        if ops.is_empty() {
433            return graph;
434        }
435
436        let input = graph.add_node(Node::Input("x".to_string()));
437        graph.add_input(input);
438
439        let mut prev = input;
440        for (i, &op) in ops.iter().enumerate() {
441            let node = graph.add_node(Node::Call(op.to_string(), vec![format!("arg_{i}")]));
442            graph.add_edge(
443                prev,
444                node,
445                Edge {
446                    name: format!("edge_{i}"),
447                },
448            );
449            prev = node;
450        }
451
452        let output = graph.add_node(Node::Output);
453        graph.add_edge(
454            prev,
455            output,
456            Edge {
457                name: "final".to_string(),
458            },
459        );
460        graph.add_output(output);
461
462        graph
463    }
464}
465
466impl FxGraph {
467    /// Serialize graph to JSON
468    pub fn to_json(&self) -> TorshResult<String> {
469        let serializable = SerializableGraph::from_graph(self);
470        serde_json::to_string_pretty(&serializable).map_err(|e| {
471            torsh_core::error::TorshError::SerializationError(format!(
472                "Failed to serialize graph to JSON: {}",
473                e
474            ))
475        })
476    }
477
478    /// Deserialize graph from JSON
479    pub fn from_json(json: &str) -> TorshResult<Self> {
480        let serializable: SerializableGraph = serde_json::from_str(json).map_err(|e| {
481            torsh_core::error::TorshError::SerializationError(format!(
482                "Failed to deserialize graph from JSON: {}",
483                e
484            ))
485        })?;
486        Ok(serializable.to_graph())
487    }
488
489    /// Serialize graph to binary format
490    pub fn to_binary(&self) -> TorshResult<Vec<u8>> {
491        let serializable = SerializableGraph::from_graph(self);
492        oxicode::serde::encode_to_vec(&serializable, oxicode::config::standard()).map_err(|e| {
493            torsh_core::error::TorshError::SerializationError(format!(
494                "Failed to serialize graph to binary: {}",
495                e
496            ))
497        })
498    }
499
500    /// Deserialize graph from binary format
501    pub fn from_binary(data: &[u8]) -> TorshResult<Self> {
502        let (serializable, _): (SerializableGraph, usize) =
503            oxicode::serde::decode_from_slice(data, oxicode::config::standard()).map_err(|e| {
504                torsh_core::error::TorshError::SerializationError(format!(
505                    "Failed to deserialize graph from binary: {}",
506                    e
507                ))
508            })?;
509        Ok(serializable.to_graph())
510    }
511}