1use crate::{Edge, FxGraph, Node};
7use petgraph::graph::NodeIndex;
8use petgraph::visit::EdgeRef;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use torsh_core::Result;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TorchScriptModel {
16 pub name: String,
17 pub version: String,
18 pub producer_name: String,
19 pub code: String,
20 pub constants: HashMap<String, TorchScriptConstant>,
21 pub parameters: Vec<TorchScriptParameter>,
22 pub methods: Vec<TorchScriptMethod>,
23 pub metadata: HashMap<String, String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub enum TorchScriptConstant {
29 Integer(i64),
30 Float(f64),
31 String(String),
32 Boolean(bool),
33 Tensor(TensorConstant),
34 List(Vec<TorchScriptConstant>),
35 Dict(HashMap<String, TorchScriptConstant>),
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct TensorConstant {
41 pub shape: Vec<i64>,
42 pub dtype: String,
43 pub data: Vec<u8>, }
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct TorchScriptParameter {
49 pub name: String,
50 pub dtype: String,
51 pub shape: Vec<i64>,
52 pub requires_grad: bool,
53 pub is_buffer: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TorchScriptMethod {
59 pub name: String,
60 pub code: String,
61 pub schema: MethodSchema,
62 pub graph: Option<TorchScriptGraph>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct MethodSchema {
68 pub arguments: Vec<Argument>,
69 pub returns: Vec<Return>,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct Argument {
75 pub name: String,
76 pub arg_type: String,
77 pub default_value: Option<TorchScriptConstant>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct Return {
83 pub name: Option<String>,
84 pub return_type: String,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct TorchScriptGraph {
90 pub nodes: Vec<TorchScriptNode>,
91 pub inputs: Vec<String>,
92 pub outputs: Vec<String>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct TorchScriptNode {
98 pub name: String,
99 pub op_type: String,
100 pub inputs: Vec<String>,
101 pub outputs: Vec<String>,
102 pub attributes: HashMap<String, TorchScriptConstant>,
103 pub source_range: Option<SourceRange>,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SourceRange {
109 pub filename: String,
110 pub start_line: u32,
111 pub start_col: u32,
112 pub end_line: u32,
113 pub end_col: u32,
114}
115
116pub struct TorchScriptImporter {
118 operator_mapping: HashMap<String, String>,
119 #[allow(dead_code)]
120 type_mapping: HashMap<String, String>,
121}
122
123impl Default for TorchScriptImporter {
124 fn default() -> Self {
125 let mut operator_mapping = HashMap::new();
126
127 operator_mapping.insert("aten::add".to_string(), "add".to_string());
129 operator_mapping.insert("aten::sub".to_string(), "sub".to_string());
130 operator_mapping.insert("aten::mul".to_string(), "mul".to_string());
131 operator_mapping.insert("aten::div".to_string(), "div".to_string());
132 operator_mapping.insert("aten::relu".to_string(), "relu".to_string());
133 operator_mapping.insert("aten::sigmoid".to_string(), "sigmoid".to_string());
134 operator_mapping.insert("aten::tanh".to_string(), "tanh".to_string());
135 operator_mapping.insert("aten::softmax".to_string(), "softmax".to_string());
136
137 operator_mapping.insert("aten::mm".to_string(), "matmul".to_string());
139 operator_mapping.insert("aten::bmm".to_string(), "batch_matmul".to_string());
140 operator_mapping.insert("aten::addmm".to_string(), "linear".to_string());
141
142 operator_mapping.insert("aten::conv2d".to_string(), "conv2d".to_string());
144 operator_mapping.insert("aten::conv1d".to_string(), "conv1d".to_string());
145 operator_mapping.insert("aten::conv3d".to_string(), "conv3d".to_string());
146
147 operator_mapping.insert("aten::max_pool2d".to_string(), "max_pool2d".to_string());
149 operator_mapping.insert("aten::avg_pool2d".to_string(), "avg_pool2d".to_string());
150 operator_mapping.insert(
151 "aten::adaptive_avg_pool2d".to_string(),
152 "adaptive_avg_pool2d".to_string(),
153 );
154
155 operator_mapping.insert("aten::batch_norm".to_string(), "batch_norm".to_string());
157 operator_mapping.insert("aten::layer_norm".to_string(), "layer_norm".to_string());
158 operator_mapping.insert("aten::group_norm".to_string(), "group_norm".to_string());
159
160 operator_mapping.insert("aten::view".to_string(), "reshape".to_string());
162 operator_mapping.insert("aten::reshape".to_string(), "reshape".to_string());
163 operator_mapping.insert("aten::transpose".to_string(), "transpose".to_string());
164 operator_mapping.insert("aten::permute".to_string(), "permute".to_string());
165 operator_mapping.insert("aten::squeeze".to_string(), "squeeze".to_string());
166 operator_mapping.insert("aten::unsqueeze".to_string(), "unsqueeze".to_string());
167
168 let mut type_mapping = HashMap::new();
169 type_mapping.insert("Tensor".to_string(), "tensor".to_string());
170 type_mapping.insert("int".to_string(), "i64".to_string());
171 type_mapping.insert("float".to_string(), "f64".to_string());
172 type_mapping.insert("bool".to_string(), "bool".to_string());
173 type_mapping.insert("str".to_string(), "string".to_string());
174
175 Self {
176 operator_mapping,
177 type_mapping,
178 }
179 }
180}
181
182impl TorchScriptImporter {
183 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn import_model(&self, model: &TorchScriptModel) -> Result<FxGraph> {
189 if let Some(forward_method) = model.methods.iter().find(|m| m.name == "forward") {
190 if let Some(graph) = &forward_method.graph {
191 self.import_graph(graph)
192 } else {
193 self.parse_code_to_graph(&forward_method.code)
195 }
196 } else {
197 Err(torsh_core::error::TorshError::InvalidArgument(
198 "No forward method found in TorchScript model".to_string(),
199 ))
200 }
201 }
202
203 pub fn import_graph(&self, ts_graph: &TorchScriptGraph) -> Result<FxGraph> {
205 let mut fx_graph = FxGraph::new();
206 let mut node_mapping = HashMap::new();
207 let mut value_to_node = HashMap::new();
208
209 for input_name in &ts_graph.inputs {
211 let node = fx_graph.graph.add_node(Node::Input(input_name.clone()));
212 fx_graph.inputs.push(node);
213 value_to_node.insert(input_name.clone(), node);
214 }
215
216 for ts_node in &ts_graph.nodes {
218 let fx_node = self.convert_torchscript_node(ts_node)?;
219 let node_idx = fx_graph.graph.add_node(fx_node);
220 node_mapping.insert(ts_node.name.clone(), node_idx);
221
222 for output in &ts_node.outputs {
224 value_to_node.insert(output.clone(), node_idx);
225 }
226 }
227
228 for output_name in &ts_graph.outputs {
230 let output_node = fx_graph.graph.add_node(Node::Output);
231 fx_graph.outputs.push(output_node);
232
233 if let Some(&producer_node) = value_to_node.get(output_name) {
235 fx_graph.graph.add_edge(
236 producer_node,
237 output_node,
238 Edge {
239 name: output_name.clone(),
240 },
241 );
242 }
243 }
244
245 for ts_node in &ts_graph.nodes {
247 if let Some(&target_node) = node_mapping.get(&ts_node.name) {
248 for input_name in &ts_node.inputs {
249 if let Some(&source_node) = value_to_node.get(input_name) {
250 if source_node != target_node {
251 fx_graph.graph.add_edge(
252 source_node,
253 target_node,
254 Edge {
255 name: input_name.clone(),
256 },
257 );
258 }
259 }
260 }
261 }
262 }
263
264 Ok(fx_graph)
265 }
266
267 fn convert_torchscript_node(&self, ts_node: &TorchScriptNode) -> Result<Node> {
268 let op_name = self
269 .operator_mapping
270 .get(&ts_node.op_type)
271 .unwrap_or(&ts_node.op_type)
272 .clone();
273
274 match ts_node.op_type.as_str() {
276 "prim::Constant" => {
277 let node_name = &ts_node.name;
279 Ok(Node::Input(format!("constant_{node_name}")))
280 }
281 "prim::If" => Ok(Node::Conditional {
282 condition: ts_node
283 .inputs
284 .first()
285 .unwrap_or(&"condition".to_string())
286 .clone(),
287 then_branch: vec!["true_branch".to_string()],
288 else_branch: vec!["false_branch".to_string()],
289 }),
290 "prim::Loop" => Ok(Node::Loop {
291 condition: ts_node
292 .inputs
293 .first()
294 .unwrap_or(&"condition".to_string())
295 .clone(),
296 body: vec!["loop_body".to_string()],
297 loop_vars: ts_node.inputs.iter().skip(1).cloned().collect(),
298 }),
299 "prim::GetAttr" => {
300 let attr_name = ts_node
301 .attributes
302 .get("name")
303 .and_then(|v| {
304 if let TorchScriptConstant::String(s) = v {
305 Some(s.clone())
306 } else {
307 None
308 }
309 })
310 .unwrap_or_else(|| "attr".to_string());
311
312 Ok(Node::GetAttr {
313 target: ts_node
314 .inputs
315 .first()
316 .unwrap_or(&"self".to_string())
317 .clone(),
318 attr: attr_name,
319 })
320 }
321 _ => Ok(Node::Call(op_name, ts_node.inputs.clone())),
322 }
323 }
324
325 fn parse_code_to_graph(&self, _code: &str) -> Result<FxGraph> {
326 let mut graph = FxGraph::new();
329 let input = graph.graph.add_node(Node::Input("input".to_string()));
330 let output = graph.graph.add_node(Node::Output);
331
332 graph.graph.add_edge(
333 input,
334 output,
335 Edge {
336 name: "passthrough".to_string(),
337 },
338 );
339 graph.inputs = vec![input];
340 graph.outputs = vec![output];
341
342 Ok(graph)
343 }
344
345 pub fn add_operator_mapping(&mut self, torchscript_op: String, fx_op: String) {
347 self.operator_mapping.insert(torchscript_op, fx_op);
348 }
349}
350
351pub struct TorchScriptExporter {
353 operator_mapping: HashMap<String, String>,
354 export_parameters: bool,
355 optimization_level: OptimizationLevel,
356}
357
358#[derive(Debug, Clone, Copy, PartialEq, Eq)]
359pub enum OptimizationLevel {
360 None,
361 Basic,
362 Aggressive,
363}
364
365impl Default for TorchScriptExporter {
366 fn default() -> Self {
367 let mut operator_mapping = HashMap::new();
368
369 operator_mapping.insert("add".to_string(), "aten::add".to_string());
371 operator_mapping.insert("sub".to_string(), "aten::sub".to_string());
372 operator_mapping.insert("mul".to_string(), "aten::mul".to_string());
373 operator_mapping.insert("div".to_string(), "aten::div".to_string());
374 operator_mapping.insert("relu".to_string(), "aten::relu".to_string());
375 operator_mapping.insert("sigmoid".to_string(), "aten::sigmoid".to_string());
376 operator_mapping.insert("tanh".to_string(), "aten::tanh".to_string());
377 operator_mapping.insert("softmax".to_string(), "aten::softmax".to_string());
378 operator_mapping.insert("matmul".to_string(), "aten::mm".to_string());
379 operator_mapping.insert("conv2d".to_string(), "aten::conv2d".to_string());
380 operator_mapping.insert("max_pool2d".to_string(), "aten::max_pool2d".to_string());
381 operator_mapping.insert("avg_pool2d".to_string(), "aten::avg_pool2d".to_string());
382 operator_mapping.insert("batch_norm".to_string(), "aten::batch_norm".to_string());
383 operator_mapping.insert("reshape".to_string(), "aten::view".to_string());
384 operator_mapping.insert("transpose".to_string(), "aten::transpose".to_string());
385 operator_mapping.insert("permute".to_string(), "aten::permute".to_string());
386
387 Self {
388 operator_mapping,
389 export_parameters: true,
390 optimization_level: OptimizationLevel::Basic,
391 }
392 }
393}
394
395impl TorchScriptExporter {
396 pub fn new() -> Self {
397 Self::default()
398 }
399
400 pub fn with_parameters(mut self, export_parameters: bool) -> Self {
401 self.export_parameters = export_parameters;
402 self
403 }
404
405 pub fn with_optimization_level(mut self, level: OptimizationLevel) -> Self {
406 self.optimization_level = level;
407 self
408 }
409
410 pub fn export_model(&self, graph: &FxGraph, model_name: &str) -> Result<TorchScriptModel> {
412 let torchscript_graph = self.export_graph(graph)?;
413 let forward_method = self.create_forward_method(&torchscript_graph)?;
414
415 let model = TorchScriptModel {
416 name: model_name.to_string(),
417 version: "1.0".to_string(),
418 producer_name: "torsh-fx".to_string(),
419 code: self.generate_torchscript_code(&torchscript_graph)?,
420 constants: HashMap::new(),
421 parameters: if self.export_parameters {
422 self.extract_parameters(graph)?
423 } else {
424 Vec::new()
425 },
426 methods: vec![forward_method],
427 metadata: HashMap::new(),
428 };
429
430 Ok(model)
431 }
432
433 pub fn export_graph(&self, fx_graph: &FxGraph) -> Result<TorchScriptGraph> {
435 let mut nodes = Vec::new();
436 let mut inputs = Vec::new();
437 let mut outputs = Vec::new();
438 let mut node_name_counter = 0;
439 let mut value_names = HashMap::new();
440
441 for &input_idx in &fx_graph.inputs {
443 if let Some(node) = fx_graph.get_node(input_idx) {
444 if let Node::Input(input_name) = node {
445 inputs.push(input_name.clone());
446 value_names.insert(input_idx, input_name.clone());
447 }
448 }
449 }
450
451 let mut visited = std::collections::HashSet::new();
453 let mut queue = VecDeque::new();
454
455 for &input_idx in &fx_graph.inputs {
457 queue.push_back(input_idx);
458 }
459
460 while let Some(current_idx) = queue.pop_front() {
461 if visited.contains(¤t_idx) {
462 continue;
463 }
464 visited.insert(current_idx);
465
466 if let Some(node) = fx_graph.get_node(current_idx) {
467 if !matches!(node, Node::Input(_)) {
468 let ts_node = self.convert_fx_node(
469 node,
470 current_idx,
471 &mut node_name_counter,
472 &value_names,
473 )?;
474
475 for output in &ts_node.outputs {
477 value_names.insert(current_idx, output.clone());
478 }
479
480 nodes.push(ts_node);
481 }
482
483 for edge_ref in fx_graph
485 .graph
486 .edges_directed(current_idx, petgraph::Direction::Outgoing)
487 {
488 queue.push_back(edge_ref.target());
489 }
490 }
491 }
492
493 for &output_idx in &fx_graph.outputs {
495 for edge_ref in fx_graph
497 .graph
498 .edges_directed(output_idx, petgraph::Direction::Incoming)
499 {
500 let source_idx = edge_ref.source();
501 if let Some(output_name) = value_names.get(&source_idx) {
502 outputs.push(output_name.clone());
503 break;
504 }
505 }
506 }
507
508 Ok(TorchScriptGraph {
509 nodes,
510 inputs,
511 outputs,
512 })
513 }
514
515 fn convert_fx_node(
516 &self,
517 fx_node: &Node,
518 _node_idx: NodeIndex,
519 name_counter: &mut usize,
520 _value_names: &HashMap<NodeIndex, String>,
521 ) -> Result<TorchScriptNode> {
522 let counter = *name_counter;
523 let node_name = format!("node_{counter}");
524 *name_counter += 1;
525
526 match fx_node {
527 Node::Call(op_name, args) => {
528 let ts_op_type = self
529 .operator_mapping
530 .get(op_name)
531 .unwrap_or(op_name)
532 .clone();
533
534 Ok(TorchScriptNode {
535 name: node_name.clone(),
536 op_type: ts_op_type,
537 inputs: args.clone(),
538 outputs: vec![format!("{node_name}_output")],
539 attributes: HashMap::new(),
540 source_range: None,
541 })
542 }
543
544 Node::Conditional {
545 condition,
546 then_branch,
547 else_branch,
548 } => {
549 let mut attributes = HashMap::new();
550 attributes.insert(
551 "then_block".to_string(),
552 TorchScriptConstant::List(
553 then_branch
554 .iter()
555 .map(|s| TorchScriptConstant::String(s.clone()))
556 .collect(),
557 ),
558 );
559 attributes.insert(
560 "else_block".to_string(),
561 TorchScriptConstant::List(
562 else_branch
563 .iter()
564 .map(|s| TorchScriptConstant::String(s.clone()))
565 .collect(),
566 ),
567 );
568
569 Ok(TorchScriptNode {
570 name: node_name.clone(),
571 op_type: "prim::If".to_string(),
572 inputs: vec![condition.clone()],
573 outputs: vec![format!("{node_name}_output")],
574 attributes,
575 source_range: None,
576 })
577 }
578
579 Node::Loop {
580 condition,
581 body,
582 loop_vars,
583 } => {
584 let mut attributes = HashMap::new();
585 attributes.insert(
586 "body".to_string(),
587 TorchScriptConstant::List(
588 body.iter()
589 .map(|s| TorchScriptConstant::String(s.clone()))
590 .collect(),
591 ),
592 );
593
594 let mut inputs = vec![condition.clone()];
595 inputs.extend(loop_vars.iter().cloned());
596
597 Ok(TorchScriptNode {
598 name: node_name.clone(),
599 op_type: "prim::Loop".to_string(),
600 inputs,
601 outputs: vec![format!("{node_name}_output")],
602 attributes,
603 source_range: None,
604 })
605 }
606
607 Node::GetAttr { target, attr } => {
608 let mut attributes = HashMap::new();
609 attributes.insert(
610 "name".to_string(),
611 TorchScriptConstant::String(attr.clone()),
612 );
613
614 Ok(TorchScriptNode {
615 name: node_name.clone(),
616 op_type: "prim::GetAttr".to_string(),
617 inputs: vec![target.clone()],
618 outputs: vec![format!("{node_name}_output")],
619 attributes,
620 source_range: None,
621 })
622 }
623
624 Node::Merge { inputs } => Ok(TorchScriptNode {
625 name: node_name.clone(),
626 op_type: "prim::TupleConstruct".to_string(),
627 inputs: inputs.clone(),
628 outputs: vec![format!("{}_output", node_name)],
629 attributes: HashMap::new(),
630 source_range: None,
631 }),
632
633 _ => Ok(TorchScriptNode {
634 name: node_name.clone(),
635 op_type: "prim::Constant".to_string(),
636 inputs: Vec::new(),
637 outputs: vec![format!("{}_output", node_name)],
638 attributes: HashMap::new(),
639 source_range: None,
640 }),
641 }
642 }
643
644 fn create_forward_method(&self, graph: &TorchScriptGraph) -> Result<TorchScriptMethod> {
645 let arguments = graph
646 .inputs
647 .iter()
648 .map(|input| Argument {
649 name: input.clone(),
650 arg_type: "Tensor".to_string(),
651 default_value: None,
652 })
653 .collect();
654
655 let returns = graph
656 .outputs
657 .iter()
658 .map(|output| Return {
659 name: Some(output.clone()),
660 return_type: "Tensor".to_string(),
661 })
662 .collect();
663
664 let schema = MethodSchema { arguments, returns };
665
666 Ok(TorchScriptMethod {
667 name: "forward".to_string(),
668 code: self.generate_torchscript_code(graph)?,
669 schema,
670 graph: Some(graph.clone()),
671 })
672 }
673
674 fn generate_torchscript_code(&self, graph: &TorchScriptGraph) -> Result<String> {
675 let mut code = String::new();
676
677 code.push_str("def forward(self");
679 for input in &graph.inputs {
680 code.push_str(&format!(", {}: Tensor", input));
681 }
682 code.push_str(") -> ");
683
684 if graph.outputs.len() == 1 {
685 code.push_str("Tensor");
686 } else {
687 code.push_str(&format!(
688 "Tuple[{}]",
689 vec!["Tensor"; graph.outputs.len()].join(", ")
690 ));
691 }
692 code.push_str(":\n");
693
694 for node in &graph.nodes {
696 code.push_str(&self.generate_node_code(node)?);
697 code.push('\n');
698 }
699
700 if graph.outputs.len() == 1 {
702 code.push_str(&format!(" return {}\n", graph.outputs[0]));
703 } else {
704 code.push_str(&format!(" return ({})\n", graph.outputs.join(", ")));
705 }
706
707 Ok(code)
708 }
709
710 fn generate_node_code(&self, node: &TorchScriptNode) -> Result<String> {
711 let indent = " ";
712
713 match node.op_type.as_str() {
714 "aten::add" => Ok(format!(
715 "{}{} = {} + {}",
716 indent,
717 node.outputs[0],
718 node.inputs.get(0).unwrap_or(&"input1".to_string()),
719 node.inputs.get(1).unwrap_or(&"input2".to_string())
720 )),
721
722 "aten::relu" => Ok(format!(
723 "{}{} = torch.relu({})",
724 indent,
725 node.outputs[0],
726 node.inputs.get(0).unwrap_or(&"input".to_string())
727 )),
728
729 "aten::mm" => Ok(format!(
730 "{}{} = torch.mm({}, {})",
731 indent,
732 node.outputs[0],
733 node.inputs.get(0).unwrap_or(&"input1".to_string()),
734 node.inputs.get(1).unwrap_or(&"input2".to_string())
735 )),
736
737 _ => Ok(format!(
738 "{}{} = {}({})",
739 indent,
740 node.outputs[0],
741 node.op_type,
742 node.inputs.join(", ")
743 )),
744 }
745 }
746
747 fn extract_parameters(&self, _graph: &FxGraph) -> Result<Vec<TorchScriptParameter>> {
748 Ok(Vec::new())
751 }
752
753 pub fn add_operator_mapping(&mut self, fx_op: String, torchscript_op: String) {
755 self.operator_mapping.insert(fx_op, torchscript_op);
756 }
757}
758
759pub mod utils {
761 use super::*;
762
763 pub fn load_torchscript_model(path: &str) -> Result<TorchScriptModel> {
765 let content = std::fs::read_to_string(path)
766 .map_err(|e| torsh_core::error::TorshError::IoError(e.to_string()))?;
767
768 serde_json::from_str(&content)
769 .map_err(|e| torsh_core::error::TorshError::SerializationError(e.to_string()))
770 }
771
772 pub fn save_torchscript_model(model: &TorchScriptModel, path: &str) -> Result<()> {
774 let content = serde_json::to_string_pretty(model)
775 .map_err(|e| torsh_core::error::TorshError::SerializationError(e.to_string()))?;
776
777 std::fs::write(path, content)
778 .map_err(|e| torsh_core::error::TorshError::IoError(e.to_string()))
779 }
780
781 pub fn validate_roundtrip(graph: &FxGraph) -> Result<bool> {
783 let exporter = TorchScriptExporter::new();
784 let model = exporter.export_model(graph, "test_model")?;
785
786 let importer = TorchScriptImporter::new();
787 let reconstructed = importer.import_model(&model)?;
788
789 Ok(graph.node_count() == reconstructed.node_count()
791 && graph.edge_count() == reconstructed.edge_count())
792 }
793}
794
795#[cfg(test)]
796mod tests {
797 use super::*;
798 use crate::{Edge, FxGraph, Node};
799
800 #[test]
801 fn test_torchscript_import_basic() {
802 let ts_graph = TorchScriptGraph {
803 nodes: vec![TorchScriptNode {
804 name: "node_0".to_string(),
805 op_type: "aten::relu".to_string(),
806 inputs: vec!["input".to_string()],
807 outputs: vec!["relu_out".to_string()],
808 attributes: HashMap::new(),
809 source_range: None,
810 }],
811 inputs: vec!["input".to_string()],
812 outputs: vec!["relu_out".to_string()],
813 };
814
815 let importer = TorchScriptImporter::new();
816 let fx_graph = importer.import_graph(&ts_graph).unwrap();
817
818 assert_eq!(fx_graph.inputs.len(), 1);
819 assert_eq!(fx_graph.outputs.len(), 1);
820 assert!(fx_graph.node_count() >= 3); }
822
823 #[test]
824 fn test_torchscript_export_basic() {
825 let mut graph = FxGraph::new();
826 let input = graph.graph.add_node(Node::Input("x".to_string()));
827 let relu = graph
828 .graph
829 .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
830 let output = graph.graph.add_node(Node::Output);
831
832 graph.graph.add_edge(
833 input,
834 relu,
835 Edge {
836 name: "x".to_string(),
837 },
838 );
839 graph.graph.add_edge(
840 relu,
841 output,
842 Edge {
843 name: "relu_out".to_string(),
844 },
845 );
846 graph.inputs = vec![input];
847 graph.outputs = vec![output];
848
849 let exporter = TorchScriptExporter::new();
850 let ts_graph = exporter.export_graph(&graph).unwrap();
851
852 assert!(!ts_graph.inputs.is_empty());
853 assert!(!ts_graph.outputs.is_empty());
854 assert!(!ts_graph.nodes.is_empty());
855
856 assert!(ts_graph
858 .nodes
859 .iter()
860 .any(|node| node.op_type == "aten::relu"));
861 }
862
863 #[test]
864 fn test_torchscript_roundtrip() {
865 let mut graph = FxGraph::new();
866 let input1 = graph.graph.add_node(Node::Input("x".to_string()));
867 let input2 = graph.graph.add_node(Node::Input("y".to_string()));
868 let add = graph.graph.add_node(Node::Call(
869 "add".to_string(),
870 vec!["x".to_string(), "y".to_string()],
871 ));
872 let relu = graph
873 .graph
874 .add_node(Node::Call("relu".to_string(), vec!["add_out".to_string()]));
875 let output = graph.graph.add_node(Node::Output);
876
877 graph.graph.add_edge(
878 input1,
879 add,
880 Edge {
881 name: "x".to_string(),
882 },
883 );
884 graph.graph.add_edge(
885 input2,
886 add,
887 Edge {
888 name: "y".to_string(),
889 },
890 );
891 graph.graph.add_edge(
892 add,
893 relu,
894 Edge {
895 name: "add_out".to_string(),
896 },
897 );
898 graph.graph.add_edge(
899 relu,
900 output,
901 Edge {
902 name: "relu_out".to_string(),
903 },
904 );
905
906 graph.inputs = vec![input1, input2];
907 graph.outputs = vec![output];
908
909 let exporter = TorchScriptExporter::new();
911 let model = exporter.export_model(&graph, "test_model").unwrap();
912
913 let importer = TorchScriptImporter::new();
915 let reconstructed = importer.import_model(&model).unwrap();
916
917 assert_eq!(graph.inputs.len(), reconstructed.inputs.len());
919 assert_eq!(graph.outputs.len(), reconstructed.outputs.len());
920 }
921
922 #[test]
923 fn test_torchscript_code_generation() {
924 let ts_graph = TorchScriptGraph {
925 nodes: vec![
926 TorchScriptNode {
927 name: "add_node".to_string(),
928 op_type: "aten::add".to_string(),
929 inputs: vec!["x".to_string(), "y".to_string()],
930 outputs: vec!["add_out".to_string()],
931 attributes: HashMap::new(),
932 source_range: None,
933 },
934 TorchScriptNode {
935 name: "relu_node".to_string(),
936 op_type: "aten::relu".to_string(),
937 inputs: vec!["add_out".to_string()],
938 outputs: vec!["result".to_string()],
939 attributes: HashMap::new(),
940 source_range: None,
941 },
942 ],
943 inputs: vec!["x".to_string(), "y".to_string()],
944 outputs: vec!["result".to_string()],
945 };
946
947 let exporter = TorchScriptExporter::new();
948 let code = exporter.generate_torchscript_code(&ts_graph).unwrap();
949
950 assert!(code.contains("def forward(self, x: Tensor, y: Tensor) -> Tensor"));
951 assert!(code.contains("add_out = x + y"));
952 assert!(code.contains("result = torch.relu(add_out)"));
953 assert!(code.contains("return result"));
954 }
955}