Skip to main content

torsh_fx/
torchscript_compat.rs

1//! TorchScript compatibility module
2//!
3//! This module provides functionality to import from and export to TorchScript format,
4//! enabling interoperability with PyTorch models.
5
6use crate::{Edge, FxGraph, Node};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use torsh_core::Result;
12
/// TorchScript model representation
///
/// Top-level container mirroring a serialized TorchScript module:
/// generated source code, embedded constants, parameters, and methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorchScriptModel {
    /// Model name (chosen by the exporter's caller).
    pub name: String,
    /// Format/model version string (e.g. "1.0").
    pub version: String,
    /// Tool that produced the model (e.g. "torsh-fx").
    pub producer_name: String,
    /// Generated TorchScript source code for the whole model.
    pub code: String,
    /// Named constant values embedded in the model.
    pub constants: HashMap<String, TorchScriptConstant>,
    /// Learnable parameters and buffers.
    pub parameters: Vec<TorchScriptParameter>,
    /// Methods defined on the model; import looks for one named `forward`.
    pub methods: Vec<TorchScriptMethod>,
    /// Free-form key/value metadata.
    pub metadata: HashMap<String, String>,
}
25
/// TorchScript constant value
///
/// Recursive value type covering the scalar, tensor, and container
/// constants that can appear in node attributes or the model constant pool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TorchScriptConstant {
    /// 64-bit signed integer constant.
    Integer(i64),
    /// 64-bit floating-point constant.
    Float(f64),
    /// String constant.
    String(String),
    /// Boolean constant.
    Boolean(bool),
    /// Tensor constant with shape/dtype and serialized payload.
    Tensor(TensorConstant),
    /// Ordered list of nested constants.
    List(Vec<TorchScriptConstant>),
    /// String-keyed dictionary of nested constants.
    Dict(HashMap<String, TorchScriptConstant>),
}
37
/// Tensor constant representation
///
/// Shape/dtype metadata plus the raw serialized tensor bytes. The byte
/// layout of `data` is not interpreted by this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorConstant {
    /// Tensor dimensions.
    pub shape: Vec<i64>,
    /// Element type name (string form, e.g. as used by the source format).
    pub dtype: String,
    pub data: Vec<u8>, // Serialized tensor data
}
45
/// TorchScript parameter
///
/// Metadata describing one learnable parameter or non-trainable buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorchScriptParameter {
    /// Parameter name.
    pub name: String,
    /// Element type name.
    pub dtype: String,
    /// Parameter dimensions.
    pub shape: Vec<i64>,
    /// Whether gradients are tracked for this parameter.
    pub requires_grad: bool,
    /// True for buffers (state that is not optimized).
    pub is_buffer: bool,
}
55
/// TorchScript method
///
/// A callable method of the model: its source, type schema, and an
/// optional structured graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorchScriptMethod {
    /// Method name; `import_model` requires a method named `forward`.
    pub name: String,
    /// TorchScript source code of the method body.
    pub code: String,
    /// Argument/return type information.
    pub schema: MethodSchema,
    /// Structured graph when available; import falls back to parsing
    /// `code` when this is `None`.
    pub graph: Option<TorchScriptGraph>,
}
64
/// Method schema for type information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MethodSchema {
    /// Declared arguments, in call order.
    pub arguments: Vec<Argument>,
    /// Declared return values, in order.
    pub returns: Vec<Return>,
}
71
/// Method argument
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Argument {
    /// Argument name.
    pub name: String,
    /// Type name as a string (e.g. "Tensor").
    pub arg_type: String,
    /// Default value, when the argument is optional.
    pub default_value: Option<TorchScriptConstant>,
}
79
/// Method return type
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Return {
    /// Optional name of the returned value.
    pub name: Option<String>,
    /// Type name as a string (e.g. "Tensor").
    pub return_type: String,
}
86
/// TorchScript graph representation
///
/// Flat node list plus the value names that form the graph boundary.
/// Nodes are expected to be listed in topological (producer-first) order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorchScriptGraph {
    /// Graph nodes, producer-first.
    pub nodes: Vec<TorchScriptNode>,
    /// Names of the graph's input values.
    pub inputs: Vec<String>,
    /// Names of the graph's output values.
    pub outputs: Vec<String>,
}
94
/// TorchScript node
///
/// One operation in a [`TorchScriptGraph`]; data flow is expressed by
/// matching `inputs`/`outputs` value names across nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TorchScriptNode {
    /// Unique node name within the graph.
    pub name: String,
    /// Operator name (e.g. "aten::relu", "prim::If").
    pub op_type: String,
    /// Names of consumed values.
    pub inputs: Vec<String>,
    /// Names of produced values.
    pub outputs: Vec<String>,
    /// Operator attributes (e.g. branch bodies for `prim::If`).
    pub attributes: HashMap<String, TorchScriptConstant>,
    /// Original source location, when known.
    pub source_range: Option<SourceRange>,
}
105
/// Source code location information
///
/// Span in the original source file a node was produced from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceRange {
    /// Source file name.
    pub filename: String,
    /// 1-based line where the span starts. (Assumed 1-based — TODO confirm
    /// against the producer of these ranges.)
    pub start_line: u32,
    /// Column where the span starts.
    pub start_col: u32,
    /// Line where the span ends.
    pub end_line: u32,
    /// Column where the span ends.
    pub end_col: u32,
}
115
/// TorchScript importer
///
/// Converts TorchScript models/graphs into [`FxGraph`]s using a
/// configurable operator-name mapping.
pub struct TorchScriptImporter {
    /// Maps TorchScript op names (e.g. "aten::relu") to FX op names.
    operator_mapping: HashMap<String, String>,
    /// Maps TorchScript type names to torsh type names; populated by
    /// `Default` but not yet read anywhere in this module.
    #[allow(dead_code)]
    type_mapping: HashMap<String, String>,
}
122
123impl Default for TorchScriptImporter {
124    fn default() -> Self {
125        let mut operator_mapping = HashMap::new();
126
127        // Basic operators
128        operator_mapping.insert("aten::add".to_string(), "add".to_string());
129        operator_mapping.insert("aten::sub".to_string(), "sub".to_string());
130        operator_mapping.insert("aten::mul".to_string(), "mul".to_string());
131        operator_mapping.insert("aten::div".to_string(), "div".to_string());
132        operator_mapping.insert("aten::relu".to_string(), "relu".to_string());
133        operator_mapping.insert("aten::sigmoid".to_string(), "sigmoid".to_string());
134        operator_mapping.insert("aten::tanh".to_string(), "tanh".to_string());
135        operator_mapping.insert("aten::softmax".to_string(), "softmax".to_string());
136
137        // Linear algebra
138        operator_mapping.insert("aten::mm".to_string(), "matmul".to_string());
139        operator_mapping.insert("aten::bmm".to_string(), "batch_matmul".to_string());
140        operator_mapping.insert("aten::addmm".to_string(), "linear".to_string());
141
142        // Convolution
143        operator_mapping.insert("aten::conv2d".to_string(), "conv2d".to_string());
144        operator_mapping.insert("aten::conv1d".to_string(), "conv1d".to_string());
145        operator_mapping.insert("aten::conv3d".to_string(), "conv3d".to_string());
146
147        // Pooling
148        operator_mapping.insert("aten::max_pool2d".to_string(), "max_pool2d".to_string());
149        operator_mapping.insert("aten::avg_pool2d".to_string(), "avg_pool2d".to_string());
150        operator_mapping.insert(
151            "aten::adaptive_avg_pool2d".to_string(),
152            "adaptive_avg_pool2d".to_string(),
153        );
154
155        // Normalization
156        operator_mapping.insert("aten::batch_norm".to_string(), "batch_norm".to_string());
157        operator_mapping.insert("aten::layer_norm".to_string(), "layer_norm".to_string());
158        operator_mapping.insert("aten::group_norm".to_string(), "group_norm".to_string());
159
160        // Shape operations
161        operator_mapping.insert("aten::view".to_string(), "reshape".to_string());
162        operator_mapping.insert("aten::reshape".to_string(), "reshape".to_string());
163        operator_mapping.insert("aten::transpose".to_string(), "transpose".to_string());
164        operator_mapping.insert("aten::permute".to_string(), "permute".to_string());
165        operator_mapping.insert("aten::squeeze".to_string(), "squeeze".to_string());
166        operator_mapping.insert("aten::unsqueeze".to_string(), "unsqueeze".to_string());
167
168        let mut type_mapping = HashMap::new();
169        type_mapping.insert("Tensor".to_string(), "tensor".to_string());
170        type_mapping.insert("int".to_string(), "i64".to_string());
171        type_mapping.insert("float".to_string(), "f64".to_string());
172        type_mapping.insert("bool".to_string(), "bool".to_string());
173        type_mapping.insert("str".to_string(), "string".to_string());
174
175        Self {
176            operator_mapping,
177            type_mapping,
178        }
179    }
180}
181
182impl TorchScriptImporter {
183    pub fn new() -> Self {
184        Self::default()
185    }
186
187    /// Import a TorchScript model into an FX graph
188    pub fn import_model(&self, model: &TorchScriptModel) -> Result<FxGraph> {
189        if let Some(forward_method) = model.methods.iter().find(|m| m.name == "forward") {
190            if let Some(graph) = &forward_method.graph {
191                self.import_graph(graph)
192            } else {
193                // Parse from code if no graph is available
194                self.parse_code_to_graph(&forward_method.code)
195            }
196        } else {
197            Err(torsh_core::error::TorshError::InvalidArgument(
198                "No forward method found in TorchScript model".to_string(),
199            ))
200        }
201    }
202
203    /// Import a TorchScript graph into an FX graph
204    pub fn import_graph(&self, ts_graph: &TorchScriptGraph) -> Result<FxGraph> {
205        let mut fx_graph = FxGraph::new();
206        let mut node_mapping = HashMap::new();
207        let mut value_to_node = HashMap::new();
208
209        // Create input nodes
210        for input_name in &ts_graph.inputs {
211            let node = fx_graph.graph.add_node(Node::Input(input_name.clone()));
212            fx_graph.inputs.push(node);
213            value_to_node.insert(input_name.clone(), node);
214        }
215
216        // Process TorchScript nodes in topological order
217        for ts_node in &ts_graph.nodes {
218            let fx_node = self.convert_torchscript_node(ts_node)?;
219            let node_idx = fx_graph.graph.add_node(fx_node);
220            node_mapping.insert(ts_node.name.clone(), node_idx);
221
222            // Map outputs to this node
223            for output in &ts_node.outputs {
224                value_to_node.insert(output.clone(), node_idx);
225            }
226        }
227
228        // Create output nodes
229        for output_name in &ts_graph.outputs {
230            let output_node = fx_graph.graph.add_node(Node::Output);
231            fx_graph.outputs.push(output_node);
232
233            // Connect to the node that produces this output
234            if let Some(&producer_node) = value_to_node.get(output_name) {
235                fx_graph.graph.add_edge(
236                    producer_node,
237                    output_node,
238                    Edge {
239                        name: output_name.clone(),
240                    },
241                );
242            }
243        }
244
245        // Create edges between nodes
246        for ts_node in &ts_graph.nodes {
247            if let Some(&target_node) = node_mapping.get(&ts_node.name) {
248                for input_name in &ts_node.inputs {
249                    if let Some(&source_node) = value_to_node.get(input_name) {
250                        if source_node != target_node {
251                            fx_graph.graph.add_edge(
252                                source_node,
253                                target_node,
254                                Edge {
255                                    name: input_name.clone(),
256                                },
257                            );
258                        }
259                    }
260                }
261            }
262        }
263
264        Ok(fx_graph)
265    }
266
267    fn convert_torchscript_node(&self, ts_node: &TorchScriptNode) -> Result<Node> {
268        let op_name = self
269            .operator_mapping
270            .get(&ts_node.op_type)
271            .unwrap_or(&ts_node.op_type)
272            .clone();
273
274        // Handle special cases
275        match ts_node.op_type.as_str() {
276            "prim::Constant" => {
277                // Constants become inputs for now
278                let node_name = &ts_node.name;
279                Ok(Node::Input(format!("constant_{node_name}")))
280            }
281            "prim::If" => Ok(Node::Conditional {
282                condition: ts_node
283                    .inputs
284                    .first()
285                    .unwrap_or(&"condition".to_string())
286                    .clone(),
287                then_branch: vec!["true_branch".to_string()],
288                else_branch: vec!["false_branch".to_string()],
289            }),
290            "prim::Loop" => Ok(Node::Loop {
291                condition: ts_node
292                    .inputs
293                    .first()
294                    .unwrap_or(&"condition".to_string())
295                    .clone(),
296                body: vec!["loop_body".to_string()],
297                loop_vars: ts_node.inputs.iter().skip(1).cloned().collect(),
298            }),
299            "prim::GetAttr" => {
300                let attr_name = ts_node
301                    .attributes
302                    .get("name")
303                    .and_then(|v| {
304                        if let TorchScriptConstant::String(s) = v {
305                            Some(s.clone())
306                        } else {
307                            None
308                        }
309                    })
310                    .unwrap_or_else(|| "attr".to_string());
311
312                Ok(Node::GetAttr {
313                    target: ts_node
314                        .inputs
315                        .first()
316                        .unwrap_or(&"self".to_string())
317                        .clone(),
318                    attr: attr_name,
319                })
320            }
321            _ => Ok(Node::Call(op_name, ts_node.inputs.clone())),
322        }
323    }
324
325    fn parse_code_to_graph(&self, _code: &str) -> Result<FxGraph> {
326        // This would require a full TorchScript parser
327        // For now, return a simple placeholder graph
328        let mut graph = FxGraph::new();
329        let input = graph.graph.add_node(Node::Input("input".to_string()));
330        let output = graph.graph.add_node(Node::Output);
331
332        graph.graph.add_edge(
333            input,
334            output,
335            Edge {
336                name: "passthrough".to_string(),
337            },
338        );
339        graph.inputs = vec![input];
340        graph.outputs = vec![output];
341
342        Ok(graph)
343    }
344
345    /// Add custom operator mapping
346    pub fn add_operator_mapping(&mut self, torchscript_op: String, fx_op: String) {
347        self.operator_mapping.insert(torchscript_op, fx_op);
348    }
349}
350
/// TorchScript exporter
///
/// Converts [`FxGraph`]s into TorchScript models/graphs, including
/// generated `forward` source code.
pub struct TorchScriptExporter {
    /// Maps FX op names to TorchScript op names (e.g. "relu" -> "aten::relu").
    operator_mapping: HashMap<String, String>,
    /// When true, `export_model` extracts parameters from the graph.
    export_parameters: bool,
    /// Requested optimization level; stored on the exporter but not
    /// consulted anywhere in this module's visible export path.
    optimization_level: OptimizationLevel,
}
357
/// Export-time optimization level.
///
/// Currently only recorded on [`TorchScriptExporter`]; no optimization
/// pass in this module reads it yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    /// Basic optimizations (default).
    Basic,
    /// Aggressive optimizations.
    Aggressive,
}
364
365impl Default for TorchScriptExporter {
366    fn default() -> Self {
367        let mut operator_mapping = HashMap::new();
368
369        // Reverse mapping from FX to TorchScript
370        operator_mapping.insert("add".to_string(), "aten::add".to_string());
371        operator_mapping.insert("sub".to_string(), "aten::sub".to_string());
372        operator_mapping.insert("mul".to_string(), "aten::mul".to_string());
373        operator_mapping.insert("div".to_string(), "aten::div".to_string());
374        operator_mapping.insert("relu".to_string(), "aten::relu".to_string());
375        operator_mapping.insert("sigmoid".to_string(), "aten::sigmoid".to_string());
376        operator_mapping.insert("tanh".to_string(), "aten::tanh".to_string());
377        operator_mapping.insert("softmax".to_string(), "aten::softmax".to_string());
378        operator_mapping.insert("matmul".to_string(), "aten::mm".to_string());
379        operator_mapping.insert("conv2d".to_string(), "aten::conv2d".to_string());
380        operator_mapping.insert("max_pool2d".to_string(), "aten::max_pool2d".to_string());
381        operator_mapping.insert("avg_pool2d".to_string(), "aten::avg_pool2d".to_string());
382        operator_mapping.insert("batch_norm".to_string(), "aten::batch_norm".to_string());
383        operator_mapping.insert("reshape".to_string(), "aten::view".to_string());
384        operator_mapping.insert("transpose".to_string(), "aten::transpose".to_string());
385        operator_mapping.insert("permute".to_string(), "aten::permute".to_string());
386
387        Self {
388            operator_mapping,
389            export_parameters: true,
390            optimization_level: OptimizationLevel::Basic,
391        }
392    }
393}
394
395impl TorchScriptExporter {
396    pub fn new() -> Self {
397        Self::default()
398    }
399
400    pub fn with_parameters(mut self, export_parameters: bool) -> Self {
401        self.export_parameters = export_parameters;
402        self
403    }
404
405    pub fn with_optimization_level(mut self, level: OptimizationLevel) -> Self {
406        self.optimization_level = level;
407        self
408    }
409
410    /// Export an FX graph to TorchScript model
411    pub fn export_model(&self, graph: &FxGraph, model_name: &str) -> Result<TorchScriptModel> {
412        let torchscript_graph = self.export_graph(graph)?;
413        let forward_method = self.create_forward_method(&torchscript_graph)?;
414
415        let model = TorchScriptModel {
416            name: model_name.to_string(),
417            version: "1.0".to_string(),
418            producer_name: "torsh-fx".to_string(),
419            code: self.generate_torchscript_code(&torchscript_graph)?,
420            constants: HashMap::new(),
421            parameters: if self.export_parameters {
422                self.extract_parameters(graph)?
423            } else {
424                Vec::new()
425            },
426            methods: vec![forward_method],
427            metadata: HashMap::new(),
428        };
429
430        Ok(model)
431    }
432
433    /// Export an FX graph to TorchScript graph
434    pub fn export_graph(&self, fx_graph: &FxGraph) -> Result<TorchScriptGraph> {
435        let mut nodes = Vec::new();
436        let mut inputs = Vec::new();
437        let mut outputs = Vec::new();
438        let mut node_name_counter = 0;
439        let mut value_names = HashMap::new();
440
441        // Process input nodes
442        for &input_idx in &fx_graph.inputs {
443            if let Some(node) = fx_graph.get_node(input_idx) {
444                if let Node::Input(input_name) = node {
445                    inputs.push(input_name.clone());
446                    value_names.insert(input_idx, input_name.clone());
447                }
448            }
449        }
450
451        // Process all nodes in topological order
452        let mut visited = std::collections::HashSet::new();
453        let mut queue = VecDeque::new();
454
455        // Start with input nodes
456        for &input_idx in &fx_graph.inputs {
457            queue.push_back(input_idx);
458        }
459
460        while let Some(current_idx) = queue.pop_front() {
461            if visited.contains(&current_idx) {
462                continue;
463            }
464            visited.insert(current_idx);
465
466            if let Some(node) = fx_graph.get_node(current_idx) {
467                if !matches!(node, Node::Input(_)) {
468                    let ts_node = self.convert_fx_node(
469                        node,
470                        current_idx,
471                        &mut node_name_counter,
472                        &value_names,
473                    )?;
474
475                    // Update value names with outputs
476                    for output in &ts_node.outputs {
477                        value_names.insert(current_idx, output.clone());
478                    }
479
480                    nodes.push(ts_node);
481                }
482
483                // Add successors to queue
484                for edge_ref in fx_graph
485                    .graph
486                    .edges_directed(current_idx, petgraph::Direction::Outgoing)
487                {
488                    queue.push_back(edge_ref.target());
489                }
490            }
491        }
492
493        // Process output nodes
494        for &output_idx in &fx_graph.outputs {
495            // Find the input to this output node
496            for edge_ref in fx_graph
497                .graph
498                .edges_directed(output_idx, petgraph::Direction::Incoming)
499            {
500                let source_idx = edge_ref.source();
501                if let Some(output_name) = value_names.get(&source_idx) {
502                    outputs.push(output_name.clone());
503                    break;
504                }
505            }
506        }
507
508        Ok(TorchScriptGraph {
509            nodes,
510            inputs,
511            outputs,
512        })
513    }
514
515    fn convert_fx_node(
516        &self,
517        fx_node: &Node,
518        _node_idx: NodeIndex,
519        name_counter: &mut usize,
520        _value_names: &HashMap<NodeIndex, String>,
521    ) -> Result<TorchScriptNode> {
522        let counter = *name_counter;
523        let node_name = format!("node_{counter}");
524        *name_counter += 1;
525
526        match fx_node {
527            Node::Call(op_name, args) => {
528                let ts_op_type = self
529                    .operator_mapping
530                    .get(op_name)
531                    .unwrap_or(op_name)
532                    .clone();
533
534                Ok(TorchScriptNode {
535                    name: node_name.clone(),
536                    op_type: ts_op_type,
537                    inputs: args.clone(),
538                    outputs: vec![format!("{node_name}_output")],
539                    attributes: HashMap::new(),
540                    source_range: None,
541                })
542            }
543
544            Node::Conditional {
545                condition,
546                then_branch,
547                else_branch,
548            } => {
549                let mut attributes = HashMap::new();
550                attributes.insert(
551                    "then_block".to_string(),
552                    TorchScriptConstant::List(
553                        then_branch
554                            .iter()
555                            .map(|s| TorchScriptConstant::String(s.clone()))
556                            .collect(),
557                    ),
558                );
559                attributes.insert(
560                    "else_block".to_string(),
561                    TorchScriptConstant::List(
562                        else_branch
563                            .iter()
564                            .map(|s| TorchScriptConstant::String(s.clone()))
565                            .collect(),
566                    ),
567                );
568
569                Ok(TorchScriptNode {
570                    name: node_name.clone(),
571                    op_type: "prim::If".to_string(),
572                    inputs: vec![condition.clone()],
573                    outputs: vec![format!("{node_name}_output")],
574                    attributes,
575                    source_range: None,
576                })
577            }
578
579            Node::Loop {
580                condition,
581                body,
582                loop_vars,
583            } => {
584                let mut attributes = HashMap::new();
585                attributes.insert(
586                    "body".to_string(),
587                    TorchScriptConstant::List(
588                        body.iter()
589                            .map(|s| TorchScriptConstant::String(s.clone()))
590                            .collect(),
591                    ),
592                );
593
594                let mut inputs = vec![condition.clone()];
595                inputs.extend(loop_vars.iter().cloned());
596
597                Ok(TorchScriptNode {
598                    name: node_name.clone(),
599                    op_type: "prim::Loop".to_string(),
600                    inputs,
601                    outputs: vec![format!("{node_name}_output")],
602                    attributes,
603                    source_range: None,
604                })
605            }
606
607            Node::GetAttr { target, attr } => {
608                let mut attributes = HashMap::new();
609                attributes.insert(
610                    "name".to_string(),
611                    TorchScriptConstant::String(attr.clone()),
612                );
613
614                Ok(TorchScriptNode {
615                    name: node_name.clone(),
616                    op_type: "prim::GetAttr".to_string(),
617                    inputs: vec![target.clone()],
618                    outputs: vec![format!("{node_name}_output")],
619                    attributes,
620                    source_range: None,
621                })
622            }
623
624            Node::Merge { inputs } => Ok(TorchScriptNode {
625                name: node_name.clone(),
626                op_type: "prim::TupleConstruct".to_string(),
627                inputs: inputs.clone(),
628                outputs: vec![format!("{}_output", node_name)],
629                attributes: HashMap::new(),
630                source_range: None,
631            }),
632
633            _ => Ok(TorchScriptNode {
634                name: node_name.clone(),
635                op_type: "prim::Constant".to_string(),
636                inputs: Vec::new(),
637                outputs: vec![format!("{}_output", node_name)],
638                attributes: HashMap::new(),
639                source_range: None,
640            }),
641        }
642    }
643
644    fn create_forward_method(&self, graph: &TorchScriptGraph) -> Result<TorchScriptMethod> {
645        let arguments = graph
646            .inputs
647            .iter()
648            .map(|input| Argument {
649                name: input.clone(),
650                arg_type: "Tensor".to_string(),
651                default_value: None,
652            })
653            .collect();
654
655        let returns = graph
656            .outputs
657            .iter()
658            .map(|output| Return {
659                name: Some(output.clone()),
660                return_type: "Tensor".to_string(),
661            })
662            .collect();
663
664        let schema = MethodSchema { arguments, returns };
665
666        Ok(TorchScriptMethod {
667            name: "forward".to_string(),
668            code: self.generate_torchscript_code(graph)?,
669            schema,
670            graph: Some(graph.clone()),
671        })
672    }
673
674    fn generate_torchscript_code(&self, graph: &TorchScriptGraph) -> Result<String> {
675        let mut code = String::new();
676
677        // Function signature
678        code.push_str("def forward(self");
679        for input in &graph.inputs {
680            code.push_str(&format!(", {}: Tensor", input));
681        }
682        code.push_str(") -> ");
683
684        if graph.outputs.len() == 1 {
685            code.push_str("Tensor");
686        } else {
687            code.push_str(&format!(
688                "Tuple[{}]",
689                vec!["Tensor"; graph.outputs.len()].join(", ")
690            ));
691        }
692        code.push_str(":\n");
693
694        // Function body
695        for node in &graph.nodes {
696            code.push_str(&self.generate_node_code(node)?);
697            code.push('\n');
698        }
699
700        // Return statement
701        if graph.outputs.len() == 1 {
702            code.push_str(&format!("    return {}\n", graph.outputs[0]));
703        } else {
704            code.push_str(&format!("    return ({})\n", graph.outputs.join(", ")));
705        }
706
707        Ok(code)
708    }
709
710    fn generate_node_code(&self, node: &TorchScriptNode) -> Result<String> {
711        let indent = "    ";
712
713        match node.op_type.as_str() {
714            "aten::add" => Ok(format!(
715                "{}{} = {} + {}",
716                indent,
717                node.outputs[0],
718                node.inputs.get(0).unwrap_or(&"input1".to_string()),
719                node.inputs.get(1).unwrap_or(&"input2".to_string())
720            )),
721
722            "aten::relu" => Ok(format!(
723                "{}{} = torch.relu({})",
724                indent,
725                node.outputs[0],
726                node.inputs.get(0).unwrap_or(&"input".to_string())
727            )),
728
729            "aten::mm" => Ok(format!(
730                "{}{} = torch.mm({}, {})",
731                indent,
732                node.outputs[0],
733                node.inputs.get(0).unwrap_or(&"input1".to_string()),
734                node.inputs.get(1).unwrap_or(&"input2".to_string())
735            )),
736
737            _ => Ok(format!(
738                "{}{} = {}({})",
739                indent,
740                node.outputs[0],
741                node.op_type,
742                node.inputs.join(", ")
743            )),
744        }
745    }
746
747    fn extract_parameters(&self, _graph: &FxGraph) -> Result<Vec<TorchScriptParameter>> {
748        // This would require analysis of the graph to identify learnable parameters
749        // For now, return an empty list
750        Ok(Vec::new())
751    }
752
753    /// Add custom operator mapping  
754    pub fn add_operator_mapping(&mut self, fx_op: String, torchscript_op: String) {
755        self.operator_mapping.insert(fx_op, torchscript_op);
756    }
757}
758
759/// Utility functions for TorchScript compatibility
760pub mod utils {
761    use super::*;
762
763    /// Load a TorchScript model from file
764    pub fn load_torchscript_model(path: &str) -> Result<TorchScriptModel> {
765        let content = std::fs::read_to_string(path)
766            .map_err(|e| torsh_core::error::TorshError::IoError(e.to_string()))?;
767
768        serde_json::from_str(&content)
769            .map_err(|e| torsh_core::error::TorshError::SerializationError(e.to_string()))
770    }
771
772    /// Save a TorchScript model to file
773    pub fn save_torchscript_model(model: &TorchScriptModel, path: &str) -> Result<()> {
774        let content = serde_json::to_string_pretty(model)
775            .map_err(|e| torsh_core::error::TorshError::SerializationError(e.to_string()))?;
776
777        std::fs::write(path, content)
778            .map_err(|e| torsh_core::error::TorshError::IoError(e.to_string()))
779    }
780
781    /// Convert FX graph to TorchScript and back for validation
782    pub fn validate_roundtrip(graph: &FxGraph) -> Result<bool> {
783        let exporter = TorchScriptExporter::new();
784        let model = exporter.export_model(graph, "test_model")?;
785
786        let importer = TorchScriptImporter::new();
787        let reconstructed = importer.import_model(&model)?;
788
789        // Simple validation - check node counts
790        Ok(graph.node_count() == reconstructed.node_count()
791            && graph.edge_count() == reconstructed.edge_count())
792    }
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Edge, FxGraph, Node};

    /// Importing a one-node TorchScript graph yields an FX graph with
    /// matching boundary and at least input + op + output nodes.
    #[test]
    fn test_torchscript_import_basic() {
        let ts_graph = TorchScriptGraph {
            nodes: vec![TorchScriptNode {
                name: "node_0".to_string(),
                op_type: "aten::relu".to_string(),
                inputs: vec!["input".to_string()],
                outputs: vec!["relu_out".to_string()],
                attributes: HashMap::new(),
                source_range: None,
            }],
            inputs: vec!["input".to_string()],
            outputs: vec!["relu_out".to_string()],
        };

        let importer = TorchScriptImporter::new();
        let fx_graph = importer.import_graph(&ts_graph).unwrap();

        assert_eq!(fx_graph.inputs.len(), 1);
        assert_eq!(fx_graph.outputs.len(), 1);
        assert!(fx_graph.node_count() >= 3); // input, relu, output
    }

    /// Exporting x -> relu -> output produces a non-empty TorchScript
    /// graph with the FX "relu" op mapped to "aten::relu".
    #[test]
    fn test_torchscript_export_basic() {
        let mut graph = FxGraph::new();
        let input = graph.graph.add_node(Node::Input("x".to_string()));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        graph.graph.add_edge(
            input,
            relu,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );
        graph.inputs = vec![input];
        graph.outputs = vec![output];

        let exporter = TorchScriptExporter::new();
        let ts_graph = exporter.export_graph(&graph).unwrap();

        assert!(!ts_graph.inputs.is_empty());
        assert!(!ts_graph.outputs.is_empty());
        assert!(!ts_graph.nodes.is_empty());

        // Check that relu was converted to aten::relu
        assert!(ts_graph
            .nodes
            .iter()
            .any(|node| node.op_type == "aten::relu"));
    }

    /// Export then re-import a two-input add/relu graph; the boundary
    /// (input/output counts) must survive the round trip.
    #[test]
    fn test_torchscript_roundtrip() {
        // Build: (x, y) -> add -> relu -> output.
        let mut graph = FxGraph::new();
        let input1 = graph.graph.add_node(Node::Input("x".to_string()));
        let input2 = graph.graph.add_node(Node::Input("y".to_string()));
        let add = graph.graph.add_node(Node::Call(
            "add".to_string(),
            vec!["x".to_string(), "y".to_string()],
        ));
        let relu = graph
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["add_out".to_string()]));
        let output = graph.graph.add_node(Node::Output);

        graph.graph.add_edge(
            input1,
            add,
            Edge {
                name: "x".to_string(),
            },
        );
        graph.graph.add_edge(
            input2,
            add,
            Edge {
                name: "y".to_string(),
            },
        );
        graph.graph.add_edge(
            add,
            relu,
            Edge {
                name: "add_out".to_string(),
            },
        );
        graph.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );

        graph.inputs = vec![input1, input2];
        graph.outputs = vec![output];

        // Export to TorchScript
        let exporter = TorchScriptExporter::new();
        let model = exporter.export_model(&graph, "test_model").unwrap();

        // Import back to FX
        let importer = TorchScriptImporter::new();
        let reconstructed = importer.import_model(&model).unwrap();

        // Basic validation
        assert_eq!(graph.inputs.len(), reconstructed.inputs.len());
        assert_eq!(graph.outputs.len(), reconstructed.outputs.len());
    }

    /// Generated `forward` source must contain the typed signature, the
    /// per-node statements, and the return line.
    #[test]
    fn test_torchscript_code_generation() {
        let ts_graph = TorchScriptGraph {
            nodes: vec![
                TorchScriptNode {
                    name: "add_node".to_string(),
                    op_type: "aten::add".to_string(),
                    inputs: vec!["x".to_string(), "y".to_string()],
                    outputs: vec!["add_out".to_string()],
                    attributes: HashMap::new(),
                    source_range: None,
                },
                TorchScriptNode {
                    name: "relu_node".to_string(),
                    op_type: "aten::relu".to_string(),
                    inputs: vec!["add_out".to_string()],
                    outputs: vec!["result".to_string()],
                    attributes: HashMap::new(),
                    source_range: None,
                },
            ],
            inputs: vec!["x".to_string(), "y".to_string()],
            outputs: vec!["result".to_string()],
        };

        let exporter = TorchScriptExporter::new();
        let code = exporter.generate_torchscript_code(&ts_graph).unwrap();

        assert!(code.contains("def forward(self, x: Tensor, y: Tensor) -> Tensor"));
        assert!(code.contains("add_out = x + y"));
        assert!(code.contains("result = torch.relu(add_out)"));
        assert!(code.contains("return result"));
    }
}