1pub const VERSION: &str = env!("CARGO_PKG_VERSION");
8pub const VERSION_MAJOR: u32 = 0;
9pub const VERSION_MINOR: u32 = 1;
10pub const VERSION_PATCH: u32 = 0;
11
12use torsh_core::Result;
13
14pub type TorshResult<T> = Result<T>;
16
17pub mod fx;
19
20pub use fx::{Edge, FxGraph, GraphStats, MemoryEstimate, Node, SerializableGraph};
22
23pub use benchmarking::{BenchmarkResult, GraphBenchmarkSuite, RegressionTester};
25pub use checkpointing::{
26 create_checkpoint, load_checkpoint, save_checkpoint, CheckpointData, CheckpointFormat,
27 CheckpointManager, CheckpointMetadata, CheckpointOptions, ResumableInterpreter,
28};
29pub use codegen::{
30 CacheStats, CodeGenBackend, CodeGenerator, CompiledCode, CppCodeGen, LazyCompiler,
31 PythonCodeGen,
32};
33pub use custom_backends::{
34 execute_with_auto_backend, execute_with_backend, get_backend, list_available_backends,
35 register_backend_factory, BackendCapability, BackendContext, BackendExecutor, BackendFactory,
36 BackendInfo, BackendRegistry, BackendResult, BackendSelectionStrategy, CustomBackend,
37};
38pub use custom_operations::{
39 register_example_operations, CustomInt16AddOperation, CustomInt16MulOperation,
40 CustomInt16SubOperation, CustomTypeUnifyOperation, TypeConversionOperation,
41};
42pub use custom_types::{
43 global_extended_registry, register_extended_operation, CustomTypeUtils,
44 ExtendedCustomOperation, ExtendedOperationRegistry, ExtendedShapeInferenceContext,
45 ExtendedShapeInfo,
46};
47pub use distributed::{
48 create_execution_plan, execute_distributed, init_distributed, CollectiveOp,
49 CommunicationBackendType, DistributedConfig, DistributedExecutionPlan, DistributedExecutor,
50 DistributionStrategy, ReduceOp,
51};
52pub use dynamic_shapes::{
53 DynamicDim, DynamicShape, DynamicShapeInferenceContext, DynamicShapeInfo, ShapeConstraint,
54};
55pub use graph_analysis::{
56 calculate_graph_metrics, DetectedPattern, GraphDiff, GraphDifference, GraphLinter,
57 GraphMetrics, LintIssue, LintReport, LintSeverity, PatternDetector,
58};
59pub use graph_partitioning::{
60 DeviceInfo, DeviceType, GraphPartition, GraphPartitioner, PartitionedGraph,
61 PartitioningStrategy,
62};
63pub use heterogeneous_computing::{
64 DeviceCapability, ExecutionPlan, HeterogeneousExecutor, OperationSpecialization,
65 PlacementStrategy, SimpleDevice,
66};
67pub use memory_optimization::{
68 AdaptiveMemoryManager, AllocationStrategy, GraphMemoryLayout, MemoryAnalyzer,
69 MemoryMappedGraph, MemoryUsageReport,
70};
71pub use onnx_export::{export_to_onnx, OnnxExporter, OnnxModel};
72pub use performance::{
73 CacheStatistics, GraphCache, GraphCompression, ParallelTraversal, PerformanceBottleneck,
74 PerformanceProfiler, PerformanceReport,
75};
76pub use torchscript_compat::{
77 TorchScriptExporter, TorchScriptGraph, TorchScriptImporter, TorchScriptModel,
78};
79pub use tracer::{Module, ModuleTracer, SymbolicTensor, TracingProxy};
80
81pub use emerging_hardware::{
83 create_dna_backend, create_neuromorphic_backend, create_photonic_backend, AdaptationStrategy,
84 CompatibilityReport, EmergingHardware, EmergingHardwareBackend, EmergingHardwareResult,
85 ErrorCorrectionScheme, HardwareCapabilities, HardwareConstraint, HardwareSpecifications,
86 NeuromorphicProcessor, OptimizationObjective, PhotonicProcessor, PrecisionType,
87 QuantumInspiredProcessor, SpecializedOperation,
88};
89pub use interactive_editor::{
90 launch_interactive_editor, AutoSaveConfig, CollaborativeEdit, EditOperation, ExportFormat,
91 ImportFormat, InteractiveGraphEditor, PerformanceMetrics, UserSession, VisualizationConfig,
92};
93pub use neural_architecture_search::{
94 create_default_search_space, create_mobile_constraints, start_neural_architecture_search,
95 ArchitectureSearchSpace, CandidateArchitecture, HardwareConstraints, HardwarePlatform,
96 LayerType, NeuralArchitectureSearch, ObjectiveWeights, SearchResults, SearchStrategy,
97};
98pub use neuromorphic_optimization::{
99 create_loihi_optimizer, optimize_for_mobile_neuromorphic, EnergyEstimate, NeuromorphicHardware,
100 NeuromorphicOptimizationResult, NeuromorphicOptimizer, NeuronModel, OptimizationConfig,
101 SNNConversionParams, SpikeEncoding,
102};
103pub use python_integration::{
104 create_jax_integration, create_pytorch_integration, generate_python_api, graph_to_pytorch_code,
105 DeploymentPackage, GeneratedPythonCode, PyTorchModelMetadata, PythonBindingConfig,
106 PythonCodeGenOptions, PythonDeploymentTarget, PythonFramework, PythonIntegrationService,
107 TrainingInfo,
108};
109pub use quantization::{
110 apply_automatic_precision, prepare_graph_for_qat, quantize_graph_post_training,
111 select_automatic_precision, AutomaticPrecisionSelector, CalibrationData, PTQUtils,
112 PrecisionCriteria, PrecisionProfile, PrecisionRecommendation, PrecisionStrategy, QATUtils,
113 QuantizationAnnotation, QuantizationBenchmark, QuantizationContext, QuantizationParams,
114 QuantizationScheme,
115};
116pub use quantum_computing::{
117 create_local_quantum_backend, create_qaoa_circuit, create_qiskit_backend, create_vqe_circuit,
118 integrate_quantum_computing, CloudProvider, DataTransferType, ErrorMitigation,
119 HybridOptimizationStrategy, HybridWorkflow, NoiseModel, QuantumBackend, QuantumCircuit,
120 QuantumComputingBackend, QuantumExecutionResult, QuantumGate, QuantumPrecision, StateEncoding,
121 SynchronizationType,
122};
123
124pub mod checkpointing;
126pub mod cloud_deployment;
127pub mod codegen;
128pub mod custom_backends;
129pub mod custom_operations;
130pub mod custom_types;
131pub mod distributed;
132pub mod dynamic_shapes;
133pub mod emerging_hardware;
134pub mod benchmarking;
136pub mod graph_analysis;
137pub mod graph_partitioning;
138pub mod heterogeneous_computing;
139pub mod interactive_editor;
140pub mod interpreter;
141pub mod memory_optimization;
142pub mod model_zoo;
143pub mod neural_architecture_search;
144pub mod neuromorphic_optimization;
145pub mod node;
146pub mod onnx_export;
147pub mod passes;
148pub mod performance;
149pub mod python_integration;
150pub mod quantization;
151pub mod quantum_computing;
152pub mod subgraph_rewriter;
153pub mod torchscript_compat;
154pub mod tracer;
155pub mod visualization;
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the three-node graph `Input("x") -> relu -> Output` shared by
    /// the serialization round-trip tests below.
    fn relu_passthrough_graph() -> FxGraph {
        let mut g = FxGraph::new();
        let input = g.graph.add_node(Node::Input("x".to_string()));
        let relu = g
            .graph
            .add_node(Node::Call("relu".to_string(), vec!["x".to_string()]));
        let output = g.graph.add_node(Node::Output);

        g.graph.add_edge(
            input,
            relu,
            Edge {
                name: "x".to_string(),
            },
        );
        g.graph.add_edge(
            relu,
            output,
            Edge {
                name: "relu_out".to_string(),
            },
        );
        g.inputs.push(input);
        g.outputs.push(output);
        g
    }

    #[test]
    fn test_graph_serialization_json() {
        let graph = relu_passthrough_graph();

        // Serialize and spot-check that node data made it into the payload.
        let json = graph.to_json().unwrap();
        assert!(json.contains("Input"));
        assert!(json.contains("relu"));

        // A JSON round trip must preserve the topology.
        let deserialized = FxGraph::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
        assert_eq!(deserialized.edge_count(), graph.edge_count());
    }

    #[test]
    fn test_graph_serialization_binary() {
        let graph = relu_passthrough_graph();

        let binary = graph.to_binary().unwrap();
        assert!(!binary.is_empty());

        // A binary round trip must preserve the topology.
        let deserialized = FxGraph::from_binary(&binary).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
        assert_eq!(deserialized.edge_count(), graph.edge_count());
    }

    #[test]
    fn test_single_op_graph() {
        let graph = FxGraph::single_op("relu", vec!["input".to_string()]);

        // Input + Call + Output nodes, wired by two edges.
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);

        assert!(graph.validate().is_ok());

        let input_nodes = graph.input_nodes();
        let call_nodes = graph.call_nodes();
        let output_nodes = graph.output_nodes();

        assert_eq!(input_nodes.len(), 1);
        assert_eq!(call_nodes.len(), 1);
        assert_eq!(output_nodes.len(), 1);

        match &call_nodes[0].1 {
            Node::Call(op_name, _) => assert_eq!(op_name, "relu"),
            _ => panic!("Expected Call node"),
        }
    }

    #[test]
    fn test_sequential_ops_graph() {
        let ops = vec!["relu", "sigmoid", "tanh"];
        let graph = FxGraph::sequential_ops(&ops);

        // One input, three calls, one output => 5 nodes / 4 edges.
        assert_eq!(graph.node_count(), 5);
        assert_eq!(graph.edge_count(), 4);
        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.outputs().len(), 1);

        assert!(graph.validate().is_ok());

        let call_nodes = graph.call_nodes();
        assert_eq!(call_nodes.len(), 3);

        let op_names: Vec<String> = call_nodes
            .into_iter()
            .filter_map(|(_, node)| {
                if let Node::Call(op_name, _) = node {
                    Some(op_name.clone())
                } else {
                    None
                }
            })
            .collect();

        assert!(op_names.contains(&"relu".to_string()));
        assert!(op_names.contains(&"sigmoid".to_string()));
        assert!(op_names.contains(&"tanh".to_string()));
    }

    #[test]
    fn test_empty_sequential_ops() {
        // No ops => completely empty graph.
        let graph = FxGraph::sequential_ops(&[]);
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.inputs().len(), 0);
        assert_eq!(graph.outputs().len(), 0);
    }

    #[test]
    fn test_modular_architecture() {
        let graph = FxGraph::single_op("test_op", vec!["input".to_string()]);

        assert!(graph.node_count() > 0);
        assert!(graph.edge_count() > 0);
        assert!(graph.validate().is_ok());

        let summary = graph.summary();
        assert!(summary.contains("FX Graph Summary"));

        let debug_graph = FxGraph::debug_minimal();
        assert!(debug_graph.validate().is_ok());

        let json = graph.to_json().unwrap();
        assert!(!json.is_empty());

        let deserialized = FxGraph::from_json(&json).unwrap();
        assert_eq!(deserialized.node_count(), graph.node_count());
    }

    #[test]
    fn test_graph_validation() {
        // A well-formed single-op graph passes validation.
        let graph = FxGraph::single_op("relu", vec!["input".to_string()]);
        assert!(graph.validate().is_ok());

        // A graph containing only an Output node must fail validation.
        let mut invalid_graph = FxGraph::new();
        let output = invalid_graph.add_node(Node::Output);
        invalid_graph.add_output(output);
        assert!(invalid_graph.validate().is_err());

        // A graph containing only an Input node must also fail validation.
        let mut invalid_graph2 = FxGraph::new();
        let input = invalid_graph2.add_node(Node::Input("x".to_string()));
        invalid_graph2.add_input(input);
        assert!(invalid_graph2.validate().is_err());
    }

    #[test]
    fn test_performance_recommendations() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "tanh"]);
        let recommendations = graph.performance_recommendations();
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_operation_analysis() {
        let graph = FxGraph::sequential_ops(&["relu", "sigmoid", "relu"]);

        // Unique operation names: "relu" appears twice but is reported once.
        let op_names = graph.get_operation_names();
        assert_eq!(op_names.len(), 2);
        assert!(op_names.contains(&"relu".to_string()));
        assert!(op_names.contains(&"sigmoid".to_string()));

        assert!(graph.contains_operation("relu"));
        assert!(graph.contains_operation("sigmoid"));
        assert!(!graph.contains_operation("tanh"));

        // Per-operation frequencies.
        let counts = graph.operation_counts();
        assert_eq!(counts.get("relu"), Some(&2));
        assert_eq!(counts.get("sigmoid"), Some(&1));
        assert_eq!(counts.get("tanh"), None);
    }
}
377
/// Convenience re-exports for downstream users.
///
/// `use torsh_fx::prelude::*;` brings in the core graph types (`FxGraph`,
/// `Node`, `Edge`, …) plus the glob contents of the most commonly used
/// submodules in one line.
pub mod prelude {
    pub use crate::fx::*;
    pub use crate::{
        benchmarking::*, checkpointing::*, codegen::*, custom_backends::*, distributed::*,
        graph_analysis::*, tracer::*, Edge, FxGraph, GraphStats, MemoryEstimate, Node,
        SerializableGraph,
    };
}