1use crate::fx::types::{Edge, Node};
4use crate::graph_analysis::GraphMetrics;
5use crate::FxGraph;
6use petgraph::graph::Graph;
7use petgraph::visit::EdgeRef;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use torsh_core::{Result, TorshError};
11
12pub type TorshResult<T> = Result<T>;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct SerializableGraph {
18 nodes: Vec<(usize, Node)>,
19 edges: Vec<(usize, usize, Edge)>,
20 inputs: Vec<usize>,
21 outputs: Vec<usize>,
22}
23
24impl SerializableGraph {
25 pub fn from_graph(graph: &FxGraph) -> Self {
27 let mut nodes = Vec::new();
28 let mut edges = Vec::new();
29
30 for (idx, node) in graph.nodes() {
32 nodes.push((idx.index(), node.clone()));
33 }
34
35 for edge_ref in graph.graph.edge_references() {
37 edges.push((
38 edge_ref.source().index(),
39 edge_ref.target().index(),
40 edge_ref.weight().clone(),
41 ));
42 }
43
44 Self {
45 nodes,
46 edges,
47 inputs: graph.inputs.iter().map(|idx| idx.index()).collect(),
48 outputs: graph.outputs.iter().map(|idx| idx.index()).collect(),
49 }
50 }
51
52 pub fn to_graph(self) -> FxGraph {
54 let mut graph = Graph::new();
55 let mut node_mapping = std::collections::HashMap::new();
56
57 for (original_idx, node) in self.nodes {
59 let new_idx = graph.add_node(node);
60 node_mapping.insert(original_idx, new_idx);
61 }
62
63 for (src_idx, target_idx, edge) in self.edges {
65 if let (Some(&src), Some(&target)) =
66 (node_mapping.get(&src_idx), node_mapping.get(&target_idx))
67 {
68 graph.add_edge(src, target, edge);
69 }
70 }
71
72 let inputs = self
74 .inputs
75 .into_iter()
76 .filter_map(|idx| node_mapping.get(&idx).copied())
77 .collect();
78 let outputs = self
79 .outputs
80 .into_iter()
81 .filter_map(|idx| node_mapping.get(&idx).copied())
82 .collect();
83
84 FxGraph {
85 graph,
86 inputs,
87 outputs,
88 }
89 }
90
91 pub fn node_count(&self) -> usize {
93 self.nodes.len()
94 }
95
96 pub fn edge_count(&self) -> usize {
98 self.edges.len()
99 }
100
101 pub fn validate(&self) -> TorshResult<()> {
103 let node_indices: std::collections::HashSet<usize> =
105 self.nodes.iter().map(|(idx, _)| *idx).collect();
106
107 for (src, target, _) in &self.edges {
108 if !node_indices.contains(src) {
109 return Err(TorshError::InvalidArgument(format!(
110 "Edge source {src} not found in nodes"
111 )));
112 }
113 if !node_indices.contains(target) {
114 return Err(TorshError::InvalidArgument(format!(
115 "Edge target {target} not found in nodes"
116 )));
117 }
118 }
119
120 for &input_idx in &self.inputs {
122 if !node_indices.contains(&input_idx) {
123 return Err(TorshError::InvalidArgument(format!(
124 "Input index {input_idx} not found in nodes"
125 )));
126 }
127 }
128
129 for &output_idx in &self.outputs {
130 if !node_indices.contains(&output_idx) {
131 return Err(TorshError::InvalidArgument(format!(
132 "Output index {output_idx} not found in nodes"
133 )));
134 }
135 }
136
137 Ok(())
138 }
139
140 pub fn operation_counts(&self) -> HashMap<String, usize> {
142 let mut counts = HashMap::new();
143
144 for (_, node) in &self.nodes {
145 let op_name = match node {
146 Node::Input(_) => "input".to_string(),
147 Node::Call(op, _) => op.clone(),
148 Node::Output => "output".to_string(),
149 Node::Conditional { .. } => "conditional".to_string(),
150 Node::Loop { .. } => "loop".to_string(),
151 Node::Merge { .. } => "merge".to_string(),
152 Node::GetAttr { .. } => "getattr".to_string(),
153 };
154
155 *counts.entry(op_name).or_insert(0) += 1;
156 }
157
158 counts
159 }
160
161 pub fn is_linear_chain(&self) -> bool {
163 if self.nodes.len() <= 1 {
164 return true;
165 }
166
167 let mut outgoing: HashMap<usize, Vec<usize>> = HashMap::new();
169 let mut incoming: HashMap<usize, Vec<usize>> = HashMap::new();
170
171 for (src, target, _) in &self.edges {
172 outgoing.entry(*src).or_default().push(*target);
173 incoming.entry(*target).or_default().push(*src);
174 }
175
176 for (idx, _) in &self.nodes {
178 let out_count = outgoing.get(idx).map_or(0, |v| v.len());
179 let in_count = incoming.get(idx).map_or(0, |v| v.len());
180
181 if out_count > 1 || in_count > 1 {
182 return false;
183 }
184 }
185
186 true
187 }
188
189 pub fn has_cycles(&self) -> bool {
191 let mut visited = std::collections::HashSet::new();
192 let mut rec_stack = std::collections::HashSet::new();
193
194 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
196 for (src, target, _) in &self.edges {
197 adj.entry(*src).or_default().push(*target);
198 }
199
200 fn dfs_has_cycle(
201 node: usize,
202 adj: &HashMap<usize, Vec<usize>>,
203 visited: &mut std::collections::HashSet<usize>,
204 rec_stack: &mut std::collections::HashSet<usize>,
205 ) -> bool {
206 visited.insert(node);
207 rec_stack.insert(node);
208
209 if let Some(neighbors) = adj.get(&node) {
210 for &neighbor in neighbors {
211 if !visited.contains(&neighbor) {
212 if dfs_has_cycle(neighbor, adj, visited, rec_stack) {
213 return true;
214 }
215 } else if rec_stack.contains(&neighbor) {
216 return true;
217 }
218 }
219 }
220
221 rec_stack.remove(&node);
222 false
223 }
224
225 for (idx, _) in &self.nodes {
226 if !visited.contains(idx) && dfs_has_cycle(*idx, &adj, &mut visited, &mut rec_stack) {
227 return true;
228 }
229 }
230
231 false
232 }
233
234 pub fn get_depth(&self) -> usize {
236 if self.nodes.is_empty() {
237 return 0;
238 }
239
240 let mut adj: HashMap<usize, Vec<usize>> = HashMap::new();
242 for (src, target, _) in &self.edges {
243 adj.entry(*src).or_default().push(*target);
244 }
245
246 fn dfs_depth(
247 node: usize,
248 adj: &HashMap<usize, Vec<usize>>,
249 visited: &mut std::collections::HashSet<usize>,
250 ) -> usize {
251 if visited.contains(&node) {
252 return 0; }
254 visited.insert(node);
255
256 let mut max_depth = 0;
257 if let Some(neighbors) = adj.get(&node) {
258 for &neighbor in neighbors {
259 let depth = dfs_depth(neighbor, adj, visited);
260 max_depth = max_depth.max(depth);
261 }
262 }
263
264 visited.remove(&node);
265 max_depth + 1
266 }
267
268 let mut max_depth = 0;
269 for (idx, _) in &self.nodes {
270 let mut visited = std::collections::HashSet::new();
271 let depth = dfs_depth(*idx, &adj, &mut visited);
272 max_depth = max_depth.max(depth);
273 }
274
275 max_depth
276 }
277
278 pub fn find_orphaned_nodes(&self) -> Vec<usize> {
280 let mut connected_nodes = std::collections::HashSet::new();
281
282 for (src, target, _) in &self.edges {
283 connected_nodes.insert(*src);
284 connected_nodes.insert(*target);
285 }
286
287 self.nodes
288 .iter()
289 .filter_map(|(idx, _)| {
290 if !connected_nodes.contains(idx) {
291 Some(*idx)
292 } else {
293 None
294 }
295 })
296 .collect()
297 }
298
299 pub fn find_dead_end_nodes(&self) -> Vec<usize> {
301 if self.outputs.is_empty() {
302 return Vec::new();
303 }
304
305 let mut incoming: HashMap<usize, Vec<usize>> = HashMap::new();
307 for (src, target, _) in &self.edges {
308 incoming.entry(*target).or_default().push(*src);
309 }
310
311 let mut reachable = std::collections::HashSet::new();
313 let mut queue = std::collections::VecDeque::new();
314
315 for &output in &self.outputs {
316 queue.push_back(output);
317 reachable.insert(output);
318 }
319
320 while let Some(node) = queue.pop_front() {
321 if let Some(predecessors) = incoming.get(&node) {
322 for &pred in predecessors {
323 if !reachable.contains(&pred) {
324 reachable.insert(pred);
325 queue.push_back(pred);
326 }
327 }
328 }
329 }
330
331 self.nodes
333 .iter()
334 .filter_map(|(idx, _)| {
335 if !reachable.contains(idx) {
336 Some(*idx)
337 } else {
338 None
339 }
340 })
341 .collect()
342 }
343
344 pub fn call_nodes(&self) -> Vec<usize> {
346 self.nodes
347 .iter()
348 .filter_map(|(idx, node)| match node {
349 Node::Call(_, _) => Some(*idx),
350 _ => None,
351 })
352 .collect()
353 }
354
355 pub fn metrics(&self) -> GraphMetrics {
357 let node_count = self.node_count();
358 let edge_count = self.edge_count();
359 let depth = self.get_depth();
360 let has_cycles = self.has_cycles();
361 let is_linear = self.is_linear_chain();
362
363 let complexity_score = (node_count as f32 * 0.1)
365 + (edge_count as f32 * 0.15)
366 + (depth as f32 * 0.2)
367 + if has_cycles { 10.0 } else { 0.0 }
368 + if is_linear { -2.0 } else { 5.0 };
369
370 GraphMetrics {
371 node_count,
372 edge_count,
373 input_count: self.inputs.len(),
374 output_count: self.outputs.len(),
375 max_depth: depth,
376 average_fanout: if node_count > 0 {
377 edge_count as f64 / node_count as f64
378 } else {
379 0.0
380 },
381 connectivity_ratio: if node_count > 1 {
382 edge_count as f64 / ((node_count * (node_count - 1)) as f64)
383 } else {
384 0.0
385 },
386 complexity_score: complexity_score as f64,
387 operation_distribution: self
388 .operation_counts()
389 .into_iter()
390 .map(|(k, v)| (k, v as u32))
391 .collect(),
392 critical_path_length: depth,
393 }
394 }
395
396 pub fn new() -> Self {
398 Self {
399 nodes: Vec::new(),
400 edges: Vec::new(),
401 inputs: Vec::new(),
402 outputs: Vec::new(),
403 }
404 }
405
406 pub fn add_node(&mut self, node: Node) -> usize {
408 let idx = self.nodes.len();
409 self.nodes.push((idx, node));
410 idx
411 }
412
413 pub fn add_input(&mut self, idx: usize) {
415 self.inputs.push(idx);
416 }
417
418 pub fn add_output(&mut self, idx: usize) {
420 self.outputs.push(idx);
421 }
422
423 pub fn add_edge(&mut self, src: usize, target: usize, edge: Edge) {
425 self.edges.push((src, target, edge));
426 }
427
428 pub fn sequential_ops(ops: &[&str]) -> Self {
430 let mut graph = Self::new();
431
432 if ops.is_empty() {
433 return graph;
434 }
435
436 let input = graph.add_node(Node::Input("x".to_string()));
437 graph.add_input(input);
438
439 let mut prev = input;
440 for (i, &op) in ops.iter().enumerate() {
441 let node = graph.add_node(Node::Call(op.to_string(), vec![format!("arg_{i}")]));
442 graph.add_edge(
443 prev,
444 node,
445 Edge {
446 name: format!("edge_{i}"),
447 },
448 );
449 prev = node;
450 }
451
452 let output = graph.add_node(Node::Output);
453 graph.add_edge(
454 prev,
455 output,
456 Edge {
457 name: "final".to_string(),
458 },
459 );
460 graph.add_output(output);
461
462 graph
463 }
464}
465
466impl FxGraph {
467 pub fn to_json(&self) -> TorshResult<String> {
469 let serializable = SerializableGraph::from_graph(self);
470 serde_json::to_string_pretty(&serializable).map_err(|e| {
471 torsh_core::error::TorshError::SerializationError(format!(
472 "Failed to serialize graph to JSON: {}",
473 e
474 ))
475 })
476 }
477
478 pub fn from_json(json: &str) -> TorshResult<Self> {
480 let serializable: SerializableGraph = serde_json::from_str(json).map_err(|e| {
481 torsh_core::error::TorshError::SerializationError(format!(
482 "Failed to deserialize graph from JSON: {}",
483 e
484 ))
485 })?;
486 Ok(serializable.to_graph())
487 }
488
489 pub fn to_binary(&self) -> TorshResult<Vec<u8>> {
491 let serializable = SerializableGraph::from_graph(self);
492 oxicode::serde::encode_to_vec(&serializable, oxicode::config::standard()).map_err(|e| {
493 torsh_core::error::TorshError::SerializationError(format!(
494 "Failed to serialize graph to binary: {}",
495 e
496 ))
497 })
498 }
499
500 pub fn from_binary(data: &[u8]) -> TorshResult<Self> {
502 let (serializable, _): (SerializableGraph, usize) =
503 oxicode::serde::decode_from_slice(data, oxicode::config::standard()).map_err(|e| {
504 torsh_core::error::TorshError::SerializationError(format!(
505 "Failed to deserialize graph from binary: {}",
506 e
507 ))
508 })?;
509 Ok(serializable.to_graph())
510 }
511}