// tensorlogic_ir/graph/export.rs

//! Graph export to standard ML formats.
//!
//! This module provides export functionality for EinsumGraph to various
//! machine learning interchange formats:
//! - ONNX (Open Neural Network Exchange) text representation
//! - TorchScript text representation
//! - Textual IR representations
//!
//! # Examples
//!
//! ```no_run
//! use tensorlogic_ir::{EinsumGraph, EinsumNode};
//! use tensorlogic_ir::{export_to_onnx_text, export_to_torchscript_text};
//!
//! let mut graph = EinsumGraph::new();
//! let a = graph.add_tensor("A");
//! let b = graph.add_tensor("B");
//! let c = graph.add_tensor("C");
//!
//! graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c])).unwrap();
//! graph.add_output(c).unwrap();
//!
//! // Export to ONNX text format
//! let onnx_text = export_to_onnx_text(&graph).unwrap();
//! println!("{}", onnx_text);
//!
//! // Export to TorchScript text format
//! let torchscript = export_to_torchscript_text(&graph).unwrap();
//! println!("{}", torchscript);
//! ```
31
use crate::error::IrError;
use crate::graph::{EinsumGraph, OpType};
use std::fmt::Write as FmtWrite;
35
/// Export options for ONNX format.
///
/// Controls the metadata written into the header of the textual ONNX
/// model produced by [`export_to_onnx_text_with_options`].
#[derive(Clone, Debug)]
pub struct OnnxExportOptions {
    /// ONNX opset version to target (default: 13).
    pub opset_version: i64,
    /// Include metadata in export.
    ///
    /// NOTE(review): this flag is currently not consulted anywhere in the
    /// exporter — confirm whether it should gate the `#` header comments.
    pub include_metadata: bool,
    /// Producer name, written into the exported model header.
    pub producer_name: String,
    /// Model version, written into the exported model header.
    pub model_version: i64,
}
48
49impl Default for OnnxExportOptions {
50    fn default() -> Self {
51        Self {
52            opset_version: 13,
53            include_metadata: true,
54            producer_name: "TensorLogic".to_string(),
55            model_version: 1,
56        }
57    }
58}
59
/// Export options for TorchScript format.
///
/// Controls the shape of the Python module produced by
/// [`export_to_torchscript_text_with_options`].
#[derive(Clone, Debug)]
pub struct TorchScriptExportOptions {
    /// Include type annotations.
    ///
    /// NOTE(review): currently not consulted by the exporter — confirm
    /// whether annotations were intended to be emitted in `forward`.
    pub include_types: bool,
    /// Include `#` comments in the generated script.
    pub include_comments: bool,
    /// Optimize for inference (freeze parameters).
    ///
    /// NOTE(review): currently not consulted by the exporter.
    pub optimize_for_inference: bool,
}
70
71impl Default for TorchScriptExportOptions {
72    fn default() -> Self {
73        Self {
74            include_types: true,
75            include_comments: true,
76            optimize_for_inference: false,
77        }
78    }
79}
80
81/// Export EinsumGraph to ONNX text representation.
82///
83/// This creates a textual representation of the ONNX model that describes
84/// the computation graph structure. This can be used for debugging or
85/// converted to binary ONNX format using ONNX tools.
86///
87/// # Examples
88///
89/// ```no_run
90/// use tensorlogic_ir::{EinsumGraph, EinsumNode};
91/// use tensorlogic_ir::export_to_onnx_text;
92///
93/// let mut graph = EinsumGraph::new();
94/// let x = graph.add_tensor("X");
95/// let y = graph.add_tensor("Y");
96/// let z = graph.add_tensor("Z");
97///
98/// graph.add_node(EinsumNode::elem_binary("add", x, y, z)).unwrap();
99/// graph.add_output(z).unwrap();
100///
101/// let onnx = export_to_onnx_text(&graph).unwrap();
102/// assert!(onnx.contains("ir_version"));
103/// assert!(onnx.contains("Add"));
104/// ```
105pub fn export_to_onnx_text(graph: &EinsumGraph) -> Result<String, IrError> {
106    export_to_onnx_text_with_options(graph, &OnnxExportOptions::default())
107}
108
109/// Export EinsumGraph to ONNX text representation with custom options.
110pub fn export_to_onnx_text_with_options(
111    graph: &EinsumGraph,
112    options: &OnnxExportOptions,
113) -> Result<String, IrError> {
114    let mut output = String::new();
115
116    // ONNX header
117    writeln!(output, "# ONNX Model: TensorLogic Computation Graph")?;
118    writeln!(output, "# Producer: {}", options.producer_name)?;
119    writeln!(output, "# Model Version: {}", options.model_version)?;
120    writeln!(output)?;
121    writeln!(output, "ir_version: 7")?;
122    writeln!(output, "opset_import {{")?;
123    writeln!(output, "  domain: \"\"")?;
124    writeln!(output, "  version: {}", options.opset_version)?;
125    writeln!(output, "}}")?;
126    writeln!(output)?;
127
128    // Model graph
129    writeln!(output, "graph {{")?;
130    writeln!(output, "  name: \"tensorlogic_graph\"")?;
131    writeln!(output)?;
132
133    // Inputs
134    writeln!(output, "  # Inputs")?;
135    for &input_idx in &graph.inputs {
136        let tensor_name = &graph.tensors[input_idx];
137        writeln!(output, "  input {{")?;
138        writeln!(output, "    name: \"{}\"", tensor_name)?;
139        writeln!(output, "    type {{")?;
140        writeln!(output, "      tensor_type {{")?;
141        writeln!(output, "        elem_type: 1  # FLOAT")?;
142        writeln!(output, "        shape {{")?;
143        writeln!(output, "          dim {{ dim_param: \"batch\" }}")?;
144        writeln!(output, "          dim {{ dim_param: \"dynamic\" }}")?;
145        writeln!(output, "        }}")?;
146        writeln!(output, "      }}")?;
147        writeln!(output, "    }}")?;
148        writeln!(output, "  }}")?;
149    }
150    writeln!(output)?;
151
152    // Nodes (operations)
153    writeln!(output, "  # Operations")?;
154    for (node_idx, node) in graph.nodes.iter().enumerate() {
155        export_node_to_onnx(&mut output, node, node_idx, graph)?;
156    }
157    writeln!(output)?;
158
159    // Outputs
160    writeln!(output, "  # Outputs")?;
161    for &output_idx in &graph.outputs {
162        let tensor_name = &graph.tensors[output_idx];
163        writeln!(output, "  output {{")?;
164        writeln!(output, "    name: \"{}\"", tensor_name)?;
165        writeln!(output, "    type {{")?;
166        writeln!(output, "      tensor_type {{")?;
167        writeln!(output, "        elem_type: 1  # FLOAT")?;
168        writeln!(output, "        shape {{")?;
169        writeln!(output, "          dim {{ dim_param: \"batch\" }}")?;
170        writeln!(output, "          dim {{ dim_param: \"dynamic\" }}")?;
171        writeln!(output, "        }}")?;
172        writeln!(output, "      }}")?;
173        writeln!(output, "    }}")?;
174        writeln!(output, "  }}")?;
175    }
176
177    writeln!(output, "}}")?;
178
179    Ok(output)
180}
181
182/// Helper to export a single node to ONNX format.
183fn export_node_to_onnx(
184    output: &mut String,
185    node: &crate::graph::EinsumNode,
186    node_idx: usize,
187    graph: &EinsumGraph,
188) -> Result<(), IrError> {
189    writeln!(output, "  node {{")?;
190
191    // Input tensors
192    for &input_idx in &node.inputs {
193        writeln!(output, "    input: \"{}\"", graph.tensors[input_idx])?;
194    }
195
196    // Output tensors
197    for &output_idx in &node.outputs {
198        writeln!(output, "    output: \"{}\"", graph.tensors[output_idx])?;
199    }
200
201    // Operation type
202    let op_name = match &node.op {
203        OpType::Einsum { spec } => {
204            writeln!(output, "    op_type: \"Einsum\"")?;
205            writeln!(output, "    attribute {{")?;
206            writeln!(output, "      name: \"equation\"")?;
207            writeln!(output, "      s: \"{}\"", spec)?;
208            writeln!(output, "      type: STRING")?;
209            writeln!(output, "    }}")?;
210            "Einsum"
211        }
212        OpType::ElemBinary { op } => {
213            let onnx_op = match op.as_str() {
214                "add" => "Add",
215                "sub" => "Sub",
216                "mul" => "Mul",
217                "div" => "Div",
218                _ => "Unknown",
219            };
220            writeln!(output, "    op_type: \"{}\"", onnx_op)?;
221            onnx_op
222        }
223        OpType::ElemUnary { op } => {
224            let onnx_op = match op.as_str() {
225                "neg" => "Neg",
226                "exp" => "Exp",
227                "log" => "Log",
228                "relu" => "Relu",
229                "sigmoid" => "Sigmoid",
230                "tanh" => "Tanh",
231                _ => "Unknown",
232            };
233            writeln!(output, "    op_type: \"{}\"", onnx_op)?;
234            onnx_op
235        }
236        OpType::Reduce { op, axes } => {
237            let onnx_op = match op.as_str() {
238                "sum" => "ReduceSum",
239                "max" => "ReduceMax",
240                "min" => "ReduceMin",
241                "mean" => "ReduceMean",
242                "prod" => "ReduceProd",
243                _ => "Unknown",
244            };
245            writeln!(output, "    op_type: \"{}\"", onnx_op)?;
246            if !axes.is_empty() {
247                writeln!(output, "    attribute {{")?;
248                writeln!(output, "      name: \"axes\"")?;
249                write!(output, "      ints: ")?;
250                for (i, axis) in axes.iter().enumerate() {
251                    if i > 0 {
252                        write!(output, ", ")?;
253                    }
254                    write!(output, "{}", axis)?;
255                }
256                writeln!(output)?;
257                writeln!(output, "      type: INTS")?;
258                writeln!(output, "    }}")?;
259            }
260            onnx_op
261        }
262    };
263
264    writeln!(output, "    name: \"node_{}\"", node_idx)?;
265    writeln!(output, "    doc_string: \"{} operation\"", op_name)?;
266    writeln!(output, "  }}")?;
267
268    Ok(())
269}
270
271/// Export EinsumGraph to TorchScript text representation.
272///
273/// This creates a PyTorch TorchScript representation that can be loaded
274/// and executed by PyTorch's JIT compiler.
275///
276/// # Examples
277///
278/// ```no_run
279/// use tensorlogic_ir::{EinsumGraph, EinsumNode};
280/// use tensorlogic_ir::export_to_torchscript_text;
281///
282/// let mut graph = EinsumGraph::new();
283/// let x = graph.add_tensor("X");
284/// let w = graph.add_tensor("W");
285/// let y = graph.add_tensor("Y");
286///
287/// graph.add_node(EinsumNode::einsum("ij,jk->ik", vec![x, w], vec![y])).unwrap();
288/// graph.add_output(y).unwrap();
289///
290/// let script = export_to_torchscript_text(&graph).unwrap();
291/// assert!(script.contains("torch.einsum"));
292/// ```
293pub fn export_to_torchscript_text(graph: &EinsumGraph) -> Result<String, IrError> {
294    export_to_torchscript_text_with_options(graph, &TorchScriptExportOptions::default())
295}
296
297/// Export EinsumGraph to TorchScript text representation with custom options.
298pub fn export_to_torchscript_text_with_options(
299    graph: &EinsumGraph,
300    options: &TorchScriptExportOptions,
301) -> Result<String, IrError> {
302    let mut output = String::new();
303
304    // Header
305    if options.include_comments {
306        writeln!(
307            output,
308            "# TorchScript representation of TensorLogic computation graph"
309        )?;
310        writeln!(output, "# Generated by TensorLogic IR")?;
311        writeln!(output)?;
312    }
313
314    writeln!(output, "import torch")?;
315    writeln!(output, "import torch.nn as nn")?;
316    writeln!(output)?;
317
318    // Module class
319    writeln!(output, "class TensorLogicGraph(nn.Module):")?;
320    writeln!(output, "    def __init__(self):")?;
321    writeln!(output, "        super(TensorLogicGraph, self).__init__()")?;
322    writeln!(output)?;
323
324    // Forward method
325    write!(output, "    def forward(self")?;
326
327    // Input parameters
328    for &input_idx in &graph.inputs {
329        write!(output, ", {}", graph.tensors[input_idx])?;
330    }
331    writeln!(output, "):")?;
332
333    if options.include_comments {
334        writeln!(output, "        # Computation graph")?;
335    }
336
337    // Generate operations
338    for node in &graph.nodes {
339        export_node_to_torchscript(&mut output, node, graph, options)?;
340    }
341
342    // Return statement
343    writeln!(output)?;
344    write!(output, "        return ")?;
345    if graph.outputs.len() == 1 {
346        writeln!(output, "{}", graph.tensors[graph.outputs[0]])?;
347    } else {
348        write!(output, "(")?;
349        for (i, &output_idx) in graph.outputs.iter().enumerate() {
350            if i > 0 {
351                write!(output, ", ")?;
352            }
353            write!(output, "{}", graph.tensors[output_idx])?;
354        }
355        writeln!(output, ")")?;
356    }
357
358    Ok(output)
359}
360
361/// Helper to export a single node to TorchScript format.
362fn export_node_to_torchscript(
363    output: &mut String,
364    node: &crate::graph::EinsumNode,
365    graph: &EinsumGraph,
366    options: &TorchScriptExportOptions,
367) -> Result<(), IrError> {
368    let output_tensor = graph.tensors[node.outputs[0]].clone();
369
370    match &node.op {
371        OpType::Einsum { spec } => {
372            write!(
373                output,
374                "        {} = torch.einsum('{}', ",
375                output_tensor, spec
376            )?;
377            for (i, &input_idx) in node.inputs.iter().enumerate() {
378                if i > 0 {
379                    write!(output, ", ")?;
380                }
381                write!(output, "{}", graph.tensors[input_idx])?;
382            }
383            writeln!(output, ")")?;
384        }
385        OpType::ElemBinary { op } => {
386            let input_tensors = &node.inputs;
387            let torch_op = match op.as_str() {
388                "add" => "torch.add",
389                "sub" => "torch.sub",
390                "mul" => "torch.mul",
391                "div" => "torch.div",
392                _ => "torch.unknown",
393            };
394
395            if options.include_comments {
396                writeln!(output, "        # Element-wise binary operation: {}", op)?;
397            }
398
399            writeln!(
400                output,
401                "        {} = {}({}, {})",
402                output_tensor,
403                torch_op,
404                graph.tensors[input_tensors[0]],
405                graph.tensors[input_tensors[1]]
406            )?;
407        }
408        OpType::ElemUnary { op } => {
409            let input_tensor = graph.tensors[node.inputs[0]].clone();
410            let torch_op = match op.as_str() {
411                "neg" => "torch.neg",
412                "exp" => "torch.exp",
413                "log" => "torch.log",
414                "relu" => "torch.relu",
415                "sigmoid" => "torch.sigmoid",
416                "tanh" => "torch.tanh",
417                _ => "torch.unknown",
418            };
419
420            if options.include_comments {
421                writeln!(output, "        # Element-wise unary operation: {}", op)?;
422            }
423
424            writeln!(
425                output,
426                "        {} = {}({})",
427                output_tensor, torch_op, input_tensor
428            )?;
429        }
430        OpType::Reduce { op, axes } => {
431            let input_tensor = graph.tensors[node.inputs[0]].clone();
432            let torch_op = match op.as_str() {
433                "sum" => "sum",
434                "max" => "max",
435                "min" => "min",
436                "mean" => "mean",
437                "prod" => "prod",
438                _ => "unknown",
439            };
440
441            if options.include_comments {
442                writeln!(output, "        # Reduction operation: {}", op)?;
443            }
444
445            if axes.is_empty() {
446                writeln!(
447                    output,
448                    "        {} = {}.{}()",
449                    output_tensor, input_tensor, torch_op
450                )?;
451            } else {
452                write!(
453                    output,
454                    "        {} = {}.{}(dim=[",
455                    output_tensor, input_tensor, torch_op
456                )?;
457                for (i, axis) in axes.iter().enumerate() {
458                    if i > 0 {
459                        write!(output, ", ")?;
460                    }
461                    write!(output, "{}", axis)?;
462                }
463                writeln!(output, "])")?;
464            }
465        }
466    }
467
468    Ok(())
469}
470
#[cfg(test)]
mod tests {
    //! Round-trip smoke tests: each builds a tiny graph, exports it, and
    //! asserts that characteristic substrings appear in the emitted text.
    use super::*;
    use crate::graph::{EinsumGraph, EinsumNode};

    /// ONNX export of a single element-wise add mentions the op and tensors.
    #[test]
    fn test_onnx_export_simple() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");
        let z = graph.add_tensor("Z");

        graph
            .add_node(EinsumNode::elem_binary("add", x, y, z))
            .unwrap();
        graph.add_output(z).unwrap();

        let onnx = export_to_onnx_text(&graph).unwrap();

        assert!(onnx.contains("ir_version"));
        assert!(onnx.contains("Add"));
        assert!(onnx.contains("X"));
        assert!(onnx.contains("Y"));
        assert!(onnx.contains("Z"));
    }

    /// Einsum nodes carry their equation as an ONNX attribute.
    #[test]
    fn test_onnx_export_einsum() {
        let mut graph = EinsumGraph::new();
        let a = graph.add_tensor("A");
        let b = graph.add_tensor("B");
        let c = graph.add_tensor("C");

        graph
            .add_node(EinsumNode::einsum("ij,jk->ik", vec![a, b], vec![c]))
            .unwrap();
        graph.add_output(c).unwrap();

        let onnx = export_to_onnx_text(&graph).unwrap();

        assert!(onnx.contains("Einsum"));
        assert!(onnx.contains("ij,jk->ik"));
    }

    /// TorchScript export produces a Python module with the mapped torch op.
    #[test]
    fn test_torchscript_export_simple() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");
        let z = graph.add_tensor("Z");

        graph
            .add_node(EinsumNode::elem_binary("mul", x, y, z))
            .unwrap();
        graph.add_output(z).unwrap();

        let script = export_to_torchscript_text(&graph).unwrap();

        assert!(script.contains("import torch"));
        assert!(script.contains("class TensorLogicGraph"));
        assert!(script.contains("torch.mul"));
    }

    /// Einsum nodes become torch.einsum calls with the quoted equation.
    #[test]
    fn test_torchscript_export_einsum() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let w = graph.add_tensor("W");
        let y = graph.add_tensor("Y");

        graph
            .add_node(EinsumNode::einsum("ij,jk->ik", vec![x, w], vec![y]))
            .unwrap();
        graph.add_output(y).unwrap();

        let script = export_to_torchscript_text(&graph).unwrap();

        assert!(script.contains("torch.einsum"));
        assert!(script.contains("'ij,jk->ik'"));
    }

    /// Reductions with explicit axes emit a ReduceSum plus an axes attribute.
    #[test]
    fn test_onnx_export_reduction() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");

        graph
            .add_node(EinsumNode::reduce("sum", vec![0, 1], x, y))
            .unwrap();
        graph.add_output(y).unwrap();

        let onnx = export_to_onnx_text(&graph).unwrap();

        assert!(onnx.contains("ReduceSum"));
        assert!(onnx.contains("axes"));
    }

    /// Unary ops map onto torch.<op> calls.
    #[test]
    fn test_torchscript_export_unary() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");

        graph
            .add_node(EinsumNode::elem_unary("relu", x, y))
            .unwrap();
        graph.add_output(y).unwrap();

        let script = export_to_torchscript_text(&graph).unwrap();

        assert!(script.contains("torch.relu"));
    }

    /// Custom ONNX options flow through to the emitted header/opset block.
    #[test]
    fn test_onnx_export_with_options() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");

        graph.add_node(EinsumNode::elem_unary("exp", x, y)).unwrap();
        graph.add_output(y).unwrap();

        let options = OnnxExportOptions {
            opset_version: 14,
            producer_name: "CustomProducer".to_string(),
            ..Default::default()
        };

        let onnx = export_to_onnx_text_with_options(&graph, &options).unwrap();

        assert!(onnx.contains("version: 14"));
        assert!(onnx.contains("CustomProducer"));
    }

    /// Disabling comments strips every "# " line from the generated script.
    #[test]
    fn test_torchscript_export_without_comments() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");

        graph
            .add_node(EinsumNode::elem_unary("tanh", x, y))
            .unwrap();
        graph.add_output(y).unwrap();

        let options = TorchScriptExportOptions {
            include_comments: false,
            ..Default::default()
        };

        let script = export_to_torchscript_text_with_options(&graph, &options).unwrap();

        assert!(!script.contains("# "));
        assert!(script.contains("torch.tanh"));
    }

    /// Multiple graph outputs are returned as a Python tuple.
    #[test]
    fn test_export_multiple_outputs() {
        let mut graph = EinsumGraph::new();
        let x = graph.add_tensor("X");
        let y = graph.add_tensor("Y");
        let z = graph.add_tensor("Z");

        graph.add_node(EinsumNode::elem_unary("exp", x, y)).unwrap();
        graph.add_node(EinsumNode::elem_unary("log", x, z)).unwrap();
        graph.add_output(y).unwrap();
        graph.add_output(z).unwrap();

        let script = export_to_torchscript_text(&graph).unwrap();

        assert!(script.contains("return (Y, Z)"));
    }
}