1use crate::error::IrError;
33use crate::graph::{EinsumGraph, OpType};
34use std::fmt::Write as FmtWrite;
35
/// Settings for the ONNX text exporter.
///
/// Build one with struct-update syntax on top of the defaults, e.g.
/// `OnnxExportOptions { opset_version: 14, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct OnnxExportOptions {
    /// Operator-set version written into the `opset_import` section.
    pub opset_version: i64,
    /// Whether the exporter should include model metadata.
    pub include_metadata: bool,
    /// Name of the producing tool, recorded in the exported text.
    pub producer_name: String,
    /// Version number of the exported model, recorded in the exported text.
    pub model_version: i64,
}

impl Default for OnnxExportOptions {
    /// Opset 13, metadata enabled, producer `"TensorLogic"`, model version 1.
    fn default() -> Self {
        let producer_name = String::from("TensorLogic");
        Self {
            opset_version: 13,
            include_metadata: true,
            producer_name,
            model_version: 1,
        }
    }
}
59
/// Settings for the TorchScript (Python source) exporter.
///
/// Build one with struct-update syntax on top of the defaults, e.g.
/// `TorchScriptExportOptions { include_comments: false, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct TorchScriptExportOptions {
    /// Whether type information should be included in the generated code.
    pub include_types: bool,
    /// Whether explanatory `#` comments are written into the generated code.
    pub include_comments: bool,
    /// Whether the generated code should be tuned for inference.
    pub optimize_for_inference: bool,
}

impl Default for TorchScriptExportOptions {
    /// Types and comments on, inference optimization off.
    fn default() -> Self {
        let (include_types, include_comments) = (true, true);
        Self {
            include_types,
            include_comments,
            optimize_for_inference: false,
        }
    }
}
80
81pub fn export_to_onnx_text(graph: &EinsumGraph) -> Result<String, IrError> {
106 export_to_onnx_text_with_options(graph, &OnnxExportOptions::default())
107}
108
109pub fn export_to_onnx_text_with_options(
111 graph: &EinsumGraph,
112 options: &OnnxExportOptions,
113) -> Result<String, IrError> {
114 let mut output = String::new();
115
116 writeln!(output, "# ONNX Model: TensorLogic Computation Graph")?;
118 writeln!(output, "# Producer: {}", options.producer_name)?;
119 writeln!(output, "# Model Version: {}", options.model_version)?;
120 writeln!(output)?;
121 writeln!(output, "ir_version: 7")?;
122 writeln!(output, "opset_import {{")?;
123 writeln!(output, " domain: \"\"")?;
124 writeln!(output, " version: {}", options.opset_version)?;
125 writeln!(output, "}}")?;
126 writeln!(output)?;
127
128 writeln!(output, "graph {{")?;
130 writeln!(output, " name: \"tensorlogic_graph\"")?;
131 writeln!(output)?;
132
133 writeln!(output, " # Inputs")?;
135 for &input_idx in &graph.inputs {
136 let tensor_name = &graph.tensors[input_idx];
137 writeln!(output, " input {{")?;
138 writeln!(output, " name: \"{}\"", tensor_name)?;
139 writeln!(output, " type {{")?;
140 writeln!(output, " tensor_type {{")?;
141 writeln!(output, " elem_type: 1 # FLOAT")?;
142 writeln!(output, " shape {{")?;
143 writeln!(output, " dim {{ dim_param: \"batch\" }}")?;
144 writeln!(output, " dim {{ dim_param: \"dynamic\" }}")?;
145 writeln!(output, " }}")?;
146 writeln!(output, " }}")?;
147 writeln!(output, " }}")?;
148 writeln!(output, " }}")?;
149 }
150 writeln!(output)?;
151
152 writeln!(output, " # Operations")?;
154 for (node_idx, node) in graph.nodes.iter().enumerate() {
155 export_node_to_onnx(&mut output, node, node_idx, graph)?;
156 }
157 writeln!(output)?;
158
159 writeln!(output, " # Outputs")?;
161 for &output_idx in &graph.outputs {
162 let tensor_name = &graph.tensors[output_idx];
163 writeln!(output, " output {{")?;
164 writeln!(output, " name: \"{}\"", tensor_name)?;
165 writeln!(output, " type {{")?;
166 writeln!(output, " tensor_type {{")?;
167 writeln!(output, " elem_type: 1 # FLOAT")?;
168 writeln!(output, " shape {{")?;
169 writeln!(output, " dim {{ dim_param: \"batch\" }}")?;
170 writeln!(output, " dim {{ dim_param: \"dynamic\" }}")?;
171 writeln!(output, " }}")?;
172 writeln!(output, " }}")?;
173 writeln!(output, " }}")?;
174 writeln!(output, " }}")?;
175 }
176
177 writeln!(output, "}}")?;
178
179 Ok(output)
180}
181
182fn export_node_to_onnx(
184 output: &mut String,
185 node: &crate::graph::EinsumNode,
186 node_idx: usize,
187 graph: &EinsumGraph,
188) -> Result<(), IrError> {
189 writeln!(output, " node {{")?;
190
191 for &input_idx in &node.inputs {
193 writeln!(output, " input: \"{}\"", graph.tensors[input_idx])?;
194 }
195
196 for &output_idx in &node.outputs {
198 writeln!(output, " output: \"{}\"", graph.tensors[output_idx])?;
199 }
200
201 let op_name = match &node.op {
203 OpType::Einsum { spec } => {
204 writeln!(output, " op_type: \"Einsum\"")?;
205 writeln!(output, " attribute {{")?;
206 writeln!(output, " name: \"equation\"")?;
207 writeln!(output, " s: \"{}\"", spec)?;
208 writeln!(output, " type: STRING")?;
209 writeln!(output, " }}")?;
210 "Einsum"
211 }
212 OpType::ElemBinary { op } => {
213 let onnx_op = match op.as_str() {
214 "add" => "Add",
215 "sub" => "Sub",
216 "mul" => "Mul",
217 "div" => "Div",
218 _ => "Unknown",
219 };
220 writeln!(output, " op_type: \"{}\"", onnx_op)?;
221 onnx_op
222 }
223 OpType::ElemUnary { op } => {
224 let onnx_op = match op.as_str() {
225 "neg" => "Neg",
226 "exp" => "Exp",
227 "log" => "Log",
228 "relu" => "Relu",
229 "sigmoid" => "Sigmoid",
230 "tanh" => "Tanh",
231 _ => "Unknown",
232 };
233 writeln!(output, " op_type: \"{}\"", onnx_op)?;
234 onnx_op
235 }
236 OpType::Reduce { op, axes } => {
237 let onnx_op = match op.as_str() {
238 "sum" => "ReduceSum",
239 "max" => "ReduceMax",
240 "min" => "ReduceMin",
241 "mean" => "ReduceMean",
242 "prod" => "ReduceProd",
243 _ => "Unknown",
244 };
245 writeln!(output, " op_type: \"{}\"", onnx_op)?;
246 if !axes.is_empty() {
247 writeln!(output, " attribute {{")?;
248 writeln!(output, " name: \"axes\"")?;
249 write!(output, " ints: ")?;
250 for (i, axis) in axes.iter().enumerate() {
251 if i > 0 {
252 write!(output, ", ")?;
253 }
254 write!(output, "{}", axis)?;
255 }
256 writeln!(output)?;
257 writeln!(output, " type: INTS")?;
258 writeln!(output, " }}")?;
259 }
260 onnx_op
261 }
262 };
263
264 writeln!(output, " name: \"node_{}\"", node_idx)?;
265 writeln!(output, " doc_string: \"{} operation\"", op_name)?;
266 writeln!(output, " }}")?;
267
268 Ok(())
269}
270
271pub fn export_to_torchscript_text(graph: &EinsumGraph) -> Result<String, IrError> {
294 export_to_torchscript_text_with_options(graph, &TorchScriptExportOptions::default())
295}
296
297pub fn export_to_torchscript_text_with_options(
299 graph: &EinsumGraph,
300 options: &TorchScriptExportOptions,
301) -> Result<String, IrError> {
302 let mut output = String::new();
303
304 if options.include_comments {
306 writeln!(
307 output,
308 "# TorchScript representation of TensorLogic computation graph"
309 )?;
310 writeln!(output, "# Generated by TensorLogic IR")?;
311 writeln!(output)?;
312 }
313
314 writeln!(output, "import torch")?;
315 writeln!(output, "import torch.nn as nn")?;
316 writeln!(output)?;
317
318 writeln!(output, "class TensorLogicGraph(nn.Module):")?;
320 writeln!(output, " def __init__(self):")?;
321 writeln!(output, " super(TensorLogicGraph, self).__init__()")?;
322 writeln!(output)?;
323
324 write!(output, " def forward(self")?;
326
327 for &input_idx in &graph.inputs {
329 write!(output, ", {}", graph.tensors[input_idx])?;
330 }
331 writeln!(output, "):")?;
332
333 if options.include_comments {
334 writeln!(output, " # Computation graph")?;
335 }
336
337 for node in &graph.nodes {
339 export_node_to_torchscript(&mut output, node, graph, options)?;
340 }
341
342 writeln!(output)?;
344 write!(output, " return ")?;
345 if graph.outputs.len() == 1 {
346 writeln!(output, "{}", graph.tensors[graph.outputs[0]])?;
347 } else {
348 write!(output, "(")?;
349 for (i, &output_idx) in graph.outputs.iter().enumerate() {
350 if i > 0 {
351 write!(output, ", ")?;
352 }
353 write!(output, "{}", graph.tensors[output_idx])?;
354 }
355 writeln!(output, ")")?;
356 }
357
358 Ok(output)
359}
360
361fn export_node_to_torchscript(
363 output: &mut String,
364 node: &crate::graph::EinsumNode,
365 graph: &EinsumGraph,
366 options: &TorchScriptExportOptions,
367) -> Result<(), IrError> {
368 let output_tensor = graph.tensors[node.outputs[0]].clone();
369
370 match &node.op {
371 OpType::Einsum { spec } => {
372 write!(
373 output,
374 " {} = torch.einsum('{}', ",
375 output_tensor, spec
376 )?;
377 for (i, &input_idx) in node.inputs.iter().enumerate() {
378 if i > 0 {
379 write!(output, ", ")?;
380 }
381 write!(output, "{}", graph.tensors[input_idx])?;
382 }
383 writeln!(output, ")")?;
384 }
385 OpType::ElemBinary { op } => {
386 let input_tensors = &node.inputs;
387 let torch_op = match op.as_str() {
388 "add" => "torch.add",
389 "sub" => "torch.sub",
390 "mul" => "torch.mul",
391 "div" => "torch.div",
392 _ => "torch.unknown",
393 };
394
395 if options.include_comments {
396 writeln!(output, " # Element-wise binary operation: {}", op)?;
397 }
398
399 writeln!(
400 output,
401 " {} = {}({}, {})",
402 output_tensor,
403 torch_op,
404 graph.tensors[input_tensors[0]],
405 graph.tensors[input_tensors[1]]
406 )?;
407 }
408 OpType::ElemUnary { op } => {
409 let input_tensor = graph.tensors[node.inputs[0]].clone();
410 let torch_op = match op.as_str() {
411 "neg" => "torch.neg",
412 "exp" => "torch.exp",
413 "log" => "torch.log",
414 "relu" => "torch.relu",
415 "sigmoid" => "torch.sigmoid",
416 "tanh" => "torch.tanh",
417 _ => "torch.unknown",
418 };
419
420 if options.include_comments {
421 writeln!(output, " # Element-wise unary operation: {}", op)?;
422 }
423
424 writeln!(
425 output,
426 " {} = {}({})",
427 output_tensor, torch_op, input_tensor
428 )?;
429 }
430 OpType::Reduce { op, axes } => {
431 let input_tensor = graph.tensors[node.inputs[0]].clone();
432 let torch_op = match op.as_str() {
433 "sum" => "sum",
434 "max" => "max",
435 "min" => "min",
436 "mean" => "mean",
437 "prod" => "prod",
438 _ => "unknown",
439 };
440
441 if options.include_comments {
442 writeln!(output, " # Reduction operation: {}", op)?;
443 }
444
445 if axes.is_empty() {
446 writeln!(
447 output,
448 " {} = {}.{}()",
449 output_tensor, input_tensor, torch_op
450 )?;
451 } else {
452 write!(
453 output,
454 " {} = {}.{}(dim=[",
455 output_tensor, input_tensor, torch_op
456 )?;
457 for (i, axis) in axes.iter().enumerate() {
458 if i > 0 {
459 write!(output, ", ")?;
460 }
461 write!(output, "{}", axis)?;
462 }
463 writeln!(output, "])")?;
464 }
465 }
466 }
467
468 Ok(())
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{EinsumGraph, EinsumNode};

    /// An element-wise `add` becomes an ONNX `Add` node that names all three tensors.
    #[test]
    fn test_onnx_export_simple() {
        let mut g = EinsumGraph::new();
        let (x, y, z) = (g.add_tensor("X"), g.add_tensor("Y"), g.add_tensor("Z"));
        g.add_node(EinsumNode::elem_binary("add", x, y, z))
            .expect("add node");
        g.add_output(z).expect("mark output");

        let text = export_to_onnx_text(&g).expect("onnx export");
        for needle in ["ir_version", "Add", "X", "Y", "Z"] {
            assert!(text.contains(needle), "missing {:?}", needle);
        }
    }

    /// Einsum nodes carry their equation string into the ONNX output.
    #[test]
    fn test_onnx_export_einsum() {
        let mut g = EinsumGraph::new();
        let (a, b, c) = (g.add_tensor("A"), g.add_tensor("B"), g.add_tensor("C"));
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .expect("einsum node");
        g.add_output(c).expect("mark output");

        let text = export_to_onnx_text(&g).expect("onnx export");
        assert!(text.contains("Einsum"));
        assert!(text.contains("ij,jk->ik"));
    }

    /// An element-wise `mul` becomes a `torch.mul` call inside the generated module.
    #[test]
    fn test_torchscript_export_simple() {
        let mut g = EinsumGraph::new();
        let (x, y, z) = (g.add_tensor("X"), g.add_tensor("Y"), g.add_tensor("Z"));
        g.add_node(EinsumNode::elem_binary("mul", x, y, z))
            .expect("mul node");
        g.add_output(z).expect("mark output");

        let py = export_to_torchscript_text(&g).expect("torchscript export");
        for needle in ["import torch", "class TensorLogicGraph", "torch.mul"] {
            assert!(py.contains(needle), "missing {:?}", needle);
        }
    }

    /// Einsum nodes are emitted as `torch.einsum` with a single-quoted equation.
    #[test]
    fn test_torchscript_export_einsum() {
        let mut g = EinsumGraph::new();
        let (x, w, y) = (g.add_tensor("X"), g.add_tensor("W"), g.add_tensor("Y"));
        g.add_node(EinsumNode::einsum("ij,jk->ik", vec![x, w], vec![y]))
            .expect("einsum node");
        g.add_output(y).expect("mark output");

        let py = export_to_torchscript_text(&g).expect("torchscript export");
        assert!(py.contains("torch.einsum"));
        assert!(py.contains("'ij,jk->ik'"));
    }

    /// A `sum` reduction over explicit axes maps to `ReduceSum` with an `axes` attribute.
    #[test]
    fn test_onnx_export_reduction() {
        let mut g = EinsumGraph::new();
        let (x, y) = (g.add_tensor("X"), g.add_tensor("Y"));
        g.add_node(EinsumNode::reduce("sum", vec![0, 1], x, y))
            .expect("reduce node");
        g.add_output(y).expect("mark output");

        let text = export_to_onnx_text(&g).expect("onnx export");
        assert!(text.contains("ReduceSum"));
        assert!(text.contains("axes"));
    }

    /// A unary `relu` is emitted as `torch.relu`.
    #[test]
    fn test_torchscript_export_unary() {
        let mut g = EinsumGraph::new();
        let (x, y) = (g.add_tensor("X"), g.add_tensor("Y"));
        g.add_node(EinsumNode::elem_unary("relu", x, y))
            .expect("relu node");
        g.add_output(y).expect("mark output");

        let py = export_to_torchscript_text(&g).expect("torchscript export");
        assert!(py.contains("torch.relu"));
    }

    /// Custom opset version and producer name appear in the ONNX text.
    #[test]
    fn test_onnx_export_with_options() {
        let mut g = EinsumGraph::new();
        let (x, y) = (g.add_tensor("X"), g.add_tensor("Y"));
        g.add_node(EinsumNode::elem_unary("exp", x, y)).expect("exp node");
        g.add_output(y).expect("mark output");

        let opts = OnnxExportOptions {
            opset_version: 14,
            producer_name: "CustomProducer".to_string(),
            ..Default::default()
        };
        let text = export_to_onnx_text_with_options(&g, &opts).expect("onnx export");

        assert!(text.contains("version: 14"));
        assert!(text.contains("CustomProducer"));
    }

    /// Disabling comments strips every `# ` comment but keeps the operation itself.
    #[test]
    fn test_torchscript_export_without_comments() {
        let mut g = EinsumGraph::new();
        let (x, y) = (g.add_tensor("X"), g.add_tensor("Y"));
        g.add_node(EinsumNode::elem_unary("tanh", x, y))
            .expect("tanh node");
        g.add_output(y).expect("mark output");

        let opts = TorchScriptExportOptions {
            include_comments: false,
            ..Default::default()
        };
        let py = export_to_torchscript_text_with_options(&g, &opts).expect("torchscript export");

        assert!(!py.contains("# "));
        assert!(py.contains("torch.tanh"));
    }

    /// Several graph outputs are returned together as a Python tuple.
    #[test]
    fn test_export_multiple_outputs() {
        let mut g = EinsumGraph::new();
        let (x, y, z) = (g.add_tensor("X"), g.add_tensor("Y"), g.add_tensor("Z"));
        g.add_node(EinsumNode::elem_unary("exp", x, y)).expect("exp node");
        g.add_node(EinsumNode::elem_unary("log", x, z)).expect("log node");
        g.add_output(y).expect("mark output Y");
        g.add_output(z).expect("mark output Z");

        let py = export_to_torchscript_text(&g).expect("torchscript export");
        assert!(py.contains("return (Y, Z)"));
    }
}