// torsh_fx — crate root (src/lib.rs)
1//! Graph transformation framework for ToRSh
2//!
3//! This crate provides a comprehensive graph transformation framework built on a modular architecture.
4//! The FX graph system has been refactored into specialized modules for maintainability and performance.
5
6// Version information
7pub const VERSION: &str = env!("CARGO_PKG_VERSION");
8pub const VERSION_MAJOR: u32 = 0;
9pub const VERSION_MINOR: u32 = 1;
10pub const VERSION_PATCH: u32 = 0;
11
12use torsh_core::Result;
13
14/// Convenience type alias for Results in this crate
15pub type TorshResult<T> = Result<T>;
16
// FX Graph modular system: the core graph representation, split into
// focused submodules for maintainability.
pub mod fx;

// Re-export the core FX Graph types so callers can write `torsh_fx::FxGraph`
// without naming the `fx` module.
pub use fx::{Edge, FxGraph, GraphStats, MemoryEstimate, Node, SerializableGraph};

// Re-export key types for convenience from other modules.
// NOTE: the `pub mod` declarations for these modules appear later in this
// file; Rust resolves module paths regardless of declaration order.
pub use benchmarking::{BenchmarkResult, GraphBenchmarkSuite, RegressionTester};
pub use checkpointing::{
    create_checkpoint, load_checkpoint, save_checkpoint, CheckpointData, CheckpointFormat,
    CheckpointManager, CheckpointMetadata, CheckpointOptions, ResumableInterpreter,
};
// Code generation backends. NOTE(review): `CacheStats` here comes from
// `codegen`; the similarly named `CacheStatistics` below comes from
// `performance` — they are distinct types.
pub use codegen::{
    CacheStats, CodeGenBackend, CodeGenerator, CompiledCode, CppCodeGen, LazyCompiler,
    PythonCodeGen,
};
pub use custom_backends::{
    execute_with_auto_backend, execute_with_backend, get_backend, list_available_backends,
    register_backend_factory, BackendCapability, BackendContext, BackendExecutor, BackendFactory,
    BackendInfo, BackendRegistry, BackendResult, BackendSelectionStrategy, CustomBackend,
};
pub use custom_operations::{
    register_example_operations, CustomInt16AddOperation, CustomInt16MulOperation,
    CustomInt16SubOperation, CustomTypeUnifyOperation, TypeConversionOperation,
};
pub use custom_types::{
    global_extended_registry, register_extended_operation, CustomTypeUtils,
    ExtendedCustomOperation, ExtendedOperationRegistry, ExtendedShapeInferenceContext,
    ExtendedShapeInfo,
};
pub use distributed::{
    create_execution_plan, execute_distributed, init_distributed, CollectiveOp,
    CommunicationBackendType, DistributedConfig, DistributedExecutionPlan, DistributedExecutor,
    DistributionStrategy, ReduceOp,
};
pub use dynamic_shapes::{
    DynamicDim, DynamicShape, DynamicShapeInferenceContext, DynamicShapeInfo, ShapeConstraint,
};
pub use graph_analysis::{
    calculate_graph_metrics, DetectedPattern, GraphDiff, GraphDifference, GraphLinter,
    GraphMetrics, LintIssue, LintReport, LintSeverity, PatternDetector,
};
pub use graph_partitioning::{
    DeviceInfo, DeviceType, GraphPartition, GraphPartitioner, PartitionedGraph,
    PartitioningStrategy,
};
pub use heterogeneous_computing::{
    DeviceCapability, ExecutionPlan, HeterogeneousExecutor, OperationSpecialization,
    PlacementStrategy, SimpleDevice,
};
pub use memory_optimization::{
    AdaptiveMemoryManager, AllocationStrategy, GraphMemoryLayout, MemoryAnalyzer,
    MemoryMappedGraph, MemoryUsageReport,
};
pub use onnx_export::{export_to_onnx, OnnxExporter, OnnxModel};
pub use performance::{
    CacheStatistics, GraphCache, GraphCompression, ParallelTraversal, PerformanceBottleneck,
    PerformanceProfiler, PerformanceReport,
};
pub use torchscript_compat::{
    TorchScriptExporter, TorchScriptGraph, TorchScriptImporter, TorchScriptModel,
};
pub use tracer::{Module, ModuleTracer, SymbolicTensor, TracingProxy};

// Re-export additional types for convenience (specialized/experimental
// subsystems: emerging hardware, NAS, quantization, quantum, etc.).
pub use emerging_hardware::{
    create_dna_backend, create_neuromorphic_backend, create_photonic_backend, AdaptationStrategy,
    CompatibilityReport, EmergingHardware, EmergingHardwareBackend, EmergingHardwareResult,
    ErrorCorrectionScheme, HardwareCapabilities, HardwareConstraint, HardwareSpecifications,
    NeuromorphicProcessor, OptimizationObjective, PhotonicProcessor, PrecisionType,
    QuantumInspiredProcessor, SpecializedOperation,
};
pub use interactive_editor::{
    launch_interactive_editor, AutoSaveConfig, CollaborativeEdit, EditOperation, ExportFormat,
    ImportFormat, InteractiveGraphEditor, PerformanceMetrics, UserSession, VisualizationConfig,
};
pub use neural_architecture_search::{
    create_default_search_space, create_mobile_constraints, start_neural_architecture_search,
    ArchitectureSearchSpace, CandidateArchitecture, HardwareConstraints, HardwarePlatform,
    LayerType, NeuralArchitectureSearch, ObjectiveWeights, SearchResults, SearchStrategy,
};
pub use neuromorphic_optimization::{
    create_loihi_optimizer, optimize_for_mobile_neuromorphic, EnergyEstimate, NeuromorphicHardware,
    NeuromorphicOptimizationResult, NeuromorphicOptimizer, NeuronModel, OptimizationConfig,
    SNNConversionParams, SpikeEncoding,
};
pub use python_integration::{
    create_jax_integration, create_pytorch_integration, generate_python_api, graph_to_pytorch_code,
    DeploymentPackage, GeneratedPythonCode, PyTorchModelMetadata, PythonBindingConfig,
    PythonCodeGenOptions, PythonDeploymentTarget, PythonFramework, PythonIntegrationService,
    TrainingInfo,
};
pub use quantization::{
    apply_automatic_precision, prepare_graph_for_qat, quantize_graph_post_training,
    select_automatic_precision, AutomaticPrecisionSelector, CalibrationData, PTQUtils,
    PrecisionCriteria, PrecisionProfile, PrecisionRecommendation, PrecisionStrategy, QATUtils,
    QuantizationAnnotation, QuantizationBenchmark, QuantizationContext, QuantizationParams,
    QuantizationScheme,
};
pub use quantum_computing::{
    create_local_quantum_backend, create_qaoa_circuit, create_qiskit_backend, create_vqe_circuit,
    integrate_quantum_computing, CloudProvider, DataTransferType, ErrorMitigation,
    HybridOptimizationStrategy, HybridWorkflow, NoiseModel, QuantumBackend, QuantumCircuit,
    QuantumComputingBackend, QuantumExecutionResult, QuantumGate, QuantumPrecision, StateEncoding,
    SynchronizationType,
};
123
124// Module declarations for the comprehensive graph transformation framework
125pub mod checkpointing;
126pub mod cloud_deployment;
127pub mod codegen;
128pub mod custom_backends;
129pub mod custom_operations;
130pub mod custom_types;
131pub mod distributed;
132pub mod dynamic_shapes;
133pub mod emerging_hardware;
134// pub mod graph_module;  // Temporarily commented - depends on torsh-nn
135pub mod benchmarking;
136pub mod graph_analysis;
137pub mod graph_partitioning;
138pub mod heterogeneous_computing;
139pub mod interactive_editor;
140pub mod interpreter;
141pub mod memory_optimization;
142pub mod model_zoo;
143pub mod neural_architecture_search;
144pub mod neuromorphic_optimization;
145pub mod node;
146pub mod onnx_export;
147pub mod passes;
148pub mod performance;
149pub mod python_integration;
150pub mod quantization;
151pub mod quantum_computing;
152pub mod subgraph_rewriter;
153pub mod torchscript_compat;
154pub mod tracer;
155pub mod visualization;
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the small `input -> relu -> output` graph shared by the
    /// serialization round-trip tests.
    fn relu_graph() -> FxGraph {
        let mut g = FxGraph::new();
        let input = g.graph.add_node(Node::Input("x".to_string()));
        let relu = g
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = g.graph.add_node(Node::Output);

        g.graph.add_edge(
            input,
            relu,
            Edge {
                name: "x".to_string(),
            },
        );
        g.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );
        g.inputs.push(input);
        g.outputs.push(output);
        g
    }

    #[test]
    fn test_graph_serialization_json() {
        let graph = relu_graph();

        // Serialize to JSON and spot-check the payload.
        let json = graph.to_json().unwrap();
        assert!(json.contains("Input"));
        assert!(json.contains("relu"));

        // Round-trip back and confirm the structure survives.
        let restored = FxGraph::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), graph.node_count());
        assert_eq!(restored.edge_count(), graph.edge_count());
    }

    #[test]
    fn test_graph_serialization_binary() {
        let graph = relu_graph();

        // Serialize to the binary format.
        let bytes = graph.to_binary().unwrap();
        assert!(!bytes.is_empty());

        // Round-trip back and confirm node/edge counts are preserved.
        let restored = FxGraph::from_binary(&bytes).unwrap();
        assert_eq!(restored.node_count(), graph.node_count());
        assert_eq!(restored.edge_count(), graph.edge_count());
    }

    #[test]
    fn test_single_op_graph() {
        let graph = FxGraph::single_op("relu", vec!["input".to_string()]);

        // Expect exactly input -> op -> output.
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());

        // One node of each kind.
        let inputs = graph.input_nodes();
        let calls = graph.call_nodes();
        let outputs = graph.output_nodes();
        assert_eq!(inputs.len(), 1);
        assert_eq!(calls.len(), 1);
        assert_eq!(outputs.len(), 1);

        // The sole call node must be the requested operation.
        match &calls[0].1 {
            Node::Call(op, _) => assert_eq!(op, "relu"),
            _ => panic!("Expected Call node"),
        }
    }

    #[test]
    fn test_sequential_ops_graph() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);

        assert_eq!(graph.node_count(), 5); // input + 3 ops + output
        assert_eq!(graph.edge_count(), 4); // chain of 4 edges
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());

        let calls = graph.call_nodes();
        assert_eq!(calls.len(), 3);

        // Node iteration order is unspecified, so only check membership,
        // not sequence.
        let names: Vec<String> = calls
            .into_iter()
            .filter_map(|(_, node)| match node {
                Node::Call(op, _) => Some(op.clone()),
                _ => None,
            })
            .collect();
        for expected in ["relu", "sigmoid", "tanh"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_empty_sequential_ops() {
        // An empty op list must yield a completely empty graph.
        let graph = FxGraph::sequential_ops(&[]);
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.inputs().len(), 0);
        assert_eq!(graph.outputs().len(), 0);
    }

    #[test]
    fn test_modular_architecture() {
        // The refactored module layout must preserve the original interface.
        let graph = FxGraph::single_op("test_op", vec!["input".to_string()]);

        // Basic functionality.
        assert!(graph.node_count() > 0);
        assert!(graph.edge_count() > 0);
        assert!(graph.validate().is_ok());

        // Analysis functionality.
        assert!(graph.summary().contains("FX Graph Summary"));

        // Construction utilities.
        assert!(FxGraph::debug_minimal().validate().is_ok());

        // Serialization round-trip.
        let json = graph.to_json().unwrap();
        assert!(!json.is_empty());
        let restored = FxGraph::from_json(&json).unwrap();
        assert_eq!(restored.node_count(), graph.node_count());
    }

    #[test]
    fn test_graph_validation() {
        // A well-formed graph validates.
        let valid = FxGraph::single_op("relu", vec!["input".to_string()]);
        assert!(valid.validate().is_ok());

        // A graph with outputs but no inputs must be rejected.
        let mut no_inputs = FxGraph::new();
        let out = no_inputs.add_node(Node::Output);
        no_inputs.add_output(out);
        assert!(no_inputs.validate().is_err());

        // A graph with inputs but no outputs must be rejected.
        let mut no_outputs = FxGraph::new();
        let inp = no_outputs.add_node(Node::Input("x".to_string()));
        no_outputs.add_input(inp);
        assert!(no_outputs.validate().is_err());
    }

    #[test]
    fn test_performance_recommendations() {
        // Even a simple op chain should yield at least one recommendation.
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
        assert!(!graph.performance_recommendations().is_empty());
    }

    #[test]
    fn test_operation_analysis() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "relu"]);

        // Unique operation names ("relu" appears twice but is listed once).
        let names = graph.get_operation_names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"relu".to_string()));
        assert!(names.contains(&"sigmoid".to_string()));

        // Membership queries.
        assert!(graph.contains_operation("relu"));
        assert!(graph.contains_operation("sigmoid"));
        assert!(!graph.contains_operation("tanh"));

        // Per-operation occurrence counts.
        let counts = graph.operation_counts();
        assert_eq!(counts.get("relu"), Some(&2));
        assert_eq!(counts.get("sigmoid"), Some(&1));
        assert_eq!(counts.get("tanh"), None);
    }
}
377
378/// Prelude module for convenient imports
379pub mod prelude {
380    pub use crate::fx::*;
381    pub use crate::{
382        benchmarking::*, checkpointing::*, codegen::*, custom_backends::*, distributed::*,
383        graph_analysis::*, tracer::*, Edge, FxGraph, GraphStats, MemoryEstimate, Node,
384        SerializableGraph,
385    };
386}