1#![allow(dead_code)] #![allow(unused_variables)] use thiserror::Error;
33use torsh_core::{DType, DeviceType, TorshError};
34
35pub mod abstract_interpretation;
36pub mod adaptive_compilation;
37pub mod advisor;
38pub mod analysis;
39pub mod benchmarking;
40pub mod codegen;
41pub mod const_eval;
42pub mod cranelift_backend;
43pub mod custom_ops;
44pub mod debug_symbols;
45pub mod debugger;
46pub mod differentiable_compilation;
47pub mod error_diagnostics;
48pub mod fusion;
49pub mod generics;
50pub mod graph;
51pub mod hardware_tuning;
52pub mod ir;
53pub mod llvm_backend;
54pub mod lowering;
55pub mod metaprogramming;
56pub mod mlir_backend;
57pub mod neural_compilation;
58pub mod optimization_advisor;
59pub mod optimizer;
60pub mod partial_evaluation;
61pub mod pgo;
62pub mod plugin_system;
63pub mod polyhedral_optimization;
64pub mod probabilistic_compilation;
65pub mod profiler;
66pub mod program_synthesis;
67pub mod runtime;
68pub mod script;
69pub mod specialization;
70pub mod speculative_optimization;
71pub mod symbolic_execution;
72pub mod trace_viz;
73pub mod tracing;
74pub mod type_inference;
75
76#[cfg(test)]
77pub mod compilation_test;
78
79pub use abstract_interpretation::{
81 AbstractAnalysisResult, AbstractDomain, AbstractInterpretationConfig, AbstractInterpreter,
82 AbstractValue, ConstantDomain, IntervalDomain, SignDomain,
83};
84pub use adaptive_compilation::{
85 AdaptiveCompiler, AdaptiveConfig, CompilationStrategy, PerformanceMetrics,
86};
87pub use codegen::CodeGenerator;
88pub use const_eval::{ConstEvalConfig, ConstantEvaluator, ConstantValue, EvaluationResult};
89pub use custom_ops::{get_custom_op, list_custom_ops, register_custom_op, CustomOpBuilder};
90pub use debug_symbols::{DebugSymbolConfig, DebugSymbolManager, SourceLocation, SymbolTable};
91pub use debugger::{
92 BreakpointLocation, DebugCommand, DebugSession, DebugState, DebugValue, DebuggerConfig,
93 ExecutionLocation, InspectionTarget, JitDebugger,
94};
95pub use differentiable_compilation::{
96 CompilationParams, CompilationTrainer, DiffCompilationResult, DifferentiableCompiler,
97 GumbelSoftmax, PerformanceMetrics as DiffPerformanceMetrics, SoftDecision,
98};
99pub use error_diagnostics::{
100 DiagnosticError, ErrorCategory, ErrorDiagnosticsManager, ErrorSeverity,
101};
102pub use fusion::{FusionStrategy, KernelFusion};
103pub use generics::{
104 create_type_param, shape_constraint, trait_constraint, GenericFunctionManager,
105 GenericFunctionTemplate, ParameterKind, TypeConstraint, TypeParameter,
106};
107pub use graph::{ComputationGraph, Edge, Node, NodeId};
108pub use hardware_tuning::{
109 Architecture, HardwareInfo, HardwareTuner, HardwareTuningConfig, TuningRecommendation,
110};
111pub use llvm_backend::{LlvmBackend, LlvmOptimizer};
112pub use metaprogramming::{
113 CodeTemplate, DynamicCodeGenerator, GeneratedCode, GraphReflection, MacroDefinition,
114 MetaprogrammingEngine, RuntimeReflector, TemplateParameters,
115};
116pub use mlir_backend::{MlirBackend, MlirOptimizer, MlirPass};
117pub use neural_compilation::{
118 CompilationStrategy as NeuralCompilationStrategy, GraphFeatures, NeuralCompiler,
119 NeuralCompilerConfig, OptimizationDecision,
120};
121pub use optimizer::GraphOptimizer;
122pub use partial_evaluation::{
123 ConstantFolder, EvaluationStatistics, FunctionSpecializer, OptimizedGraph, OptimizedIrModule,
124 PartialEvalConfig, PartialEvaluator,
125};
126pub use pgo::{
127 OptimizationRecommendation as PgoRecommendation, OptimizationType as PgoOptimizationType,
128 PgoConfig, ProfileGuidedOptimizer,
129};
130pub use plugin_system::{
131 load_all_plugins, load_plugin, Plugin, PluginCapability, PluginManager, PluginMetadata,
132 PluginRegistry,
133};
134pub use polyhedral_optimization::{
135 AffineExpr, AffineSchedule, LoopNest, PolyhedralConfig, PolyhedralOptimizer, Polyhedron,
136 TransformationMatrix, TransformationType,
137};
138pub use probabilistic_compilation::{
139 BetaDistribution, MonteCarloResult, NormalDistribution, ProbabilisticCompilationResult,
140 ProbabilisticCompiler, ProbabilisticConfig, ProbabilisticPerformance, UncertainDecision,
141};
142pub use profiler::{PerformanceEvent, ProfilerConfig, ProfilerManager, ProfilingSession};
143pub use program_synthesis::{
144 ExampleBuilder, ProgramSynthesizer, SynthesisExample, SynthesisResult, SynthesisStrategy,
145 SynthesisTemplate, SynthesisValue,
146};
147pub use runtime::JitRuntime;
148pub use script::{export_torchscript, import_torchscript, ScriptCompiler};
149pub use specialization::{
150 create_specialized_type, SpecializationConfig, SpecializedType, TypeSpecializer,
151};
152pub use speculative_optimization::{
153 DeoptimizationEvent, SpeculationResult, SpeculativeConfig, SpeculativeOptimizer,
154};
155pub use symbolic_execution::{
156 Constraint, ConstraintSet, ExecutionState, SymbolicExecutionConfig, SymbolicExecutionEngine,
157 SymbolicExecutionResult, SymbolicGraph, SymbolicValue,
158};
159pub use trace_viz::{
160 TraceEvent, TraceVisualizationManager, VisualizationConfig, VisualizationSession,
161};
162
163pub use optimization_advisor::{
165 analyze_computation_graph, analyze_with_benchmarks, analyze_with_profiling, create_advisor,
166 create_advisor_with_config, create_fast_config, create_production_config,
167 create_thorough_config, quick_analyze, AdvisorConfig, AnalysisInput, CostBenefitAnalysis,
168 OptimizationAdvisor, OptimizationRecommendation, OptimizationReport, OptimizationType,
169 PatternAnalysis, PerformanceAnalysis, SystemConstraints, TargetPlatform, UserPreferences,
170};
171
172pub type IrFunction = ir::IrModule; pub type IrInstruction = ir::Instruction; #[derive(Error, Debug)]
178pub enum JitError {
179 #[error("Graph construction failed: {0}")]
180 GraphError(String),
181
182 #[error("Fusion error: {0}")]
183 FusionError(String),
184
185 #[error("Code generation failed: {0}")]
186 CodeGenError(String),
187
188 #[error("Optimization error: {0}")]
189 OptimizationError(String),
190
191 #[error("Runtime error: {0}")]
192 RuntimeError(String),
193
194 #[error("Unsupported operation: {0}")]
195 UnsupportedOp(String),
196
197 #[error("Compilation error: {0}")]
198 CompilationError(String),
199
200 #[error("Analysis error: {0}")]
201 AnalysisError(String),
202
203 #[error("Abstract interpretation error: {0}")]
204 AbstractInterpretationError(String),
205
206 #[error("Backend error: {0}")]
207 BackendError(#[from] TorshError),
208}
209
210impl From<String> for JitError {
211 fn from(msg: String) -> Self {
212 JitError::RuntimeError(msg)
213 }
214}
215
/// Result type used throughout the JIT crate.
pub type JitResult<T> = Result<T, JitError>;
217
/// Configuration controlling JIT compilation behaviour.
///
/// Use [`JitConfig::default`] for sensible defaults, or
/// [`utils::recommend_config`] to derive settings from a specific graph.
#[derive(Debug, Clone)]
pub struct JitConfig {
    /// Strategy used when fusing adjacent kernels.
    pub fusion_strategy: FusionStrategy,

    /// Whether graph-level optimization passes run before lowering.
    pub enable_optimizations: bool,

    /// Maximum number of operations merged into one fused kernel.
    pub max_fusion_size: usize,

    /// Whether profiling instrumentation is enabled.
    pub enable_profiling: bool,

    /// Device the generated code targets.
    pub target_device: DeviceType,

    /// Whether compiled-artifact caching is enabled.
    pub enable_caching: bool,

    /// Whether type specialization is applied.
    pub enable_specialization: bool,

    /// Settings forwarded to the type specializer.
    pub specialization_config: SpecializationConfig,
}
245
246impl Default for JitConfig {
247 fn default() -> Self {
248 Self {
249 fusion_strategy: FusionStrategy::Default,
250 enable_optimizations: true,
251 max_fusion_size: 8,
252 enable_profiling: false,
253 target_device: DeviceType::Cpu,
254 enable_caching: true,
255 enable_specialization: true,
256 specialization_config: SpecializationConfig::default(),
257 }
258 }
259}
260
/// Top-level JIT compiler: owns the configuration plus the managers for each
/// supporting subsystem (runtime, specialization, generics, debug info,
/// profiling, trace visualization, and diagnostics).
pub struct JitCompiler {
    config: JitConfig,
    runtime: JitRuntime,
    specializer: TypeSpecializer,
    generics: GenericFunctionManager,
    debug_symbols: DebugSymbolManager,
    profiler: ProfilerManager,
    trace_viz: TraceVisualizationManager,
    error_diagnostics: ErrorDiagnosticsManager,
}
272
impl JitCompiler {
    /// Builds a compiler from `config`, initialising every subsystem manager
    /// with its defaults.
    pub fn new(config: JitConfig) -> Self {
        Self {
            // The runtime gets its own copy of the full config.
            runtime: JitRuntime::new(config.clone()),
            specializer: TypeSpecializer::new(config.specialization_config.clone()),
            generics: GenericFunctionManager::with_defaults(),
            debug_symbols: DebugSymbolManager::with_defaults(),
            profiler: ProfilerManager::with_defaults(),
            trace_viz: TraceVisualizationManager::with_defaults(),
            error_diagnostics: ErrorDiagnosticsManager::with_defaults(),
            config,
        }
    }

    /// Runs the full compilation pipeline on `graph`:
    /// validate -> type/shape inference -> (optional) graph optimization ->
    /// kernel fusion -> IR lowering -> IR optimization -> code generation.
    ///
    /// # Errors
    /// Returns a [`JitError`] from whichever stage fails first.
    pub fn compile(&mut self, graph: ComputationGraph) -> JitResult<CompiledModule> {
        // Reject malformed graphs before doing any work.
        graph
            .validate()
            .map_err(|e| JitError::GraphError(format!("{:?}", e)))?;

        let inferred_graph = self.apply_type_shape_inference(graph)?;

        // Graph-level optimization is optional; fusion always runs.
        let optimized_graph = if self.config.enable_optimizations {
            let optimizer = GraphOptimizer::new();
            optimizer.optimize(inferred_graph)?
        } else {
            inferred_graph
        };

        let fusion = KernelFusion::new(self.config.fusion_strategy.clone());
        let fused_graph = fusion.apply(optimized_graph)?;

        let ir_module = crate::lowering::lower_graph_to_ir(&fused_graph, "jit_module".to_string())?;

        let optimized_ir = self.apply_ir_optimizations(ir_module)?;

        let compiled_kernels = self.generate_code(&optimized_ir)?;

        // The module keeps the *fused* graph (post-optimization) so execution
        // matches the kernels generated from it.
        Ok(CompiledModule {
            graph: fused_graph,
            kernels: compiled_kernels,
            runtime: self.runtime.clone(),
        })
    }

    /// Runs type and shape inference, then writes the inferred dtype and
    /// output shape back onto each node of the graph.
    fn apply_type_shape_inference(
        &self,
        mut graph: ComputationGraph,
    ) -> JitResult<ComputationGraph> {
        use crate::type_inference::{ShapeInference, TypeInference};

        let mut type_inf = TypeInference::new();
        type_inf.infer_types(&graph)?;

        let mut shape_inf = ShapeInference::new();
        shape_inf.infer_shapes(&graph)?;

        // Collect ids first so we can mutate nodes while no iterator borrows
        // the graph.
        let node_ids: Vec<_> = graph.nodes().map(|(id, _)| id).collect();
        for node_id in node_ids {
            if let Some(inferred_type) = type_inf.get_type(node_id) {
                if let Some(node_mut) = graph.node_mut(node_id) {
                    node_mut.dtype = inferred_type;
                }
            }
            if let Some(inferred_shape) = shape_inf.get_shape(node_id) {
                if let Some(node_mut) = graph.node_mut(node_id) {
                    node_mut.output_shape = inferred_shape.clone();
                }
            }
        }

        Ok(graph)
    }

    /// Applies IR-level passes (dead-code elimination, then constant folding)
    /// and validates the resulting module.
    fn apply_ir_optimizations(
        &self,
        mut ir_module: crate::ir::IrModule,
    ) -> JitResult<crate::ir::IrModule> {
        use crate::lowering::{IrConstantFolding, IrDeadCodeElimination, IrPass};

        // DCE first so constant folding doesn't waste time on dead code.
        let dce = IrDeadCodeElimination;
        dce.run(&mut ir_module)?;

        let cf = IrConstantFolding;
        cf.run(&mut ir_module)?;

        ir_module.validate().map_err(JitError::GraphError)?;

        Ok(ir_module)
    }

    /// Lowers the optimized IR to executable kernels for the configured
    /// target device.
    ///
    /// On CPU, the Cranelift backend is used when the `cranelift-backend`
    /// feature is enabled; otherwise an interpreter fallback is generated.
    /// All other devices go through the generic code generator.
    fn generate_code(&self, ir_module: &crate::ir::IrModule) -> JitResult<Vec<CompiledKernel>> {
        match self.config.target_device {
            DeviceType::Cpu => {
                #[cfg(feature = "cranelift-backend")]
                {
                    let mut codegen = crate::cranelift_backend::CraneliftCodeGen::new()?;
                    codegen.generate(ir_module)
                }
                #[cfg(not(feature = "cranelift-backend"))]
                {
                    // NOTE(review): `.clone()` here vs plain move below —
                    // presumably `DeviceType` is `Copy`/`Clone`; the two arms
                    // could be made consistent.
                    let codegen = CodeGenerator::new(self.config.target_device.clone());
                    codegen.generate_interpreter(ir_module)
                }
            }
            _ => {
                let codegen = CodeGenerator::new(self.config.target_device);
                codegen.generate_from_ir(ir_module)
            }
        }
    }
}
405
/// A fully compiled graph: the fused graph, its generated kernels, and a
/// runtime handle capable of executing them.
pub struct CompiledModule {
    graph: ComputationGraph,
    kernels: Vec<CompiledKernel>,
    runtime: JitRuntime,
}
412
impl CompiledModule {
    /// Executes the compiled module on `inputs`, returning the output tensors.
    ///
    /// # Errors
    /// Propagates any [`JitError`] raised by the runtime.
    pub fn execute(&self, inputs: &[TensorRef]) -> JitResult<Vec<TensorRef>> {
        self.runtime.execute(&self.graph, &self.kernels, inputs)
    }

    /// Returns execution statistics accumulated by the runtime.
    pub fn stats(&self) -> ExecutionStats {
        self.runtime.stats()
    }
}
424
/// One generated kernel together with its provenance and launch metadata.
pub struct CompiledKernel {
    /// Unique identifier for this kernel.
    pub id: String,

    /// Graph nodes this kernel was generated from (e.g. a fused group).
    pub source_nodes: Vec<NodeId>,

    /// Generated code as raw bytes (format depends on the backend).
    pub code: Vec<u8>,

    /// Launch/layout metadata for executing the kernel.
    pub metadata: KernelMetadata,
}
439
/// Launch metadata describing a kernel's tensor interface and execution
/// geometry.
#[derive(Debug, Clone)]
pub struct KernelMetadata {
    /// Descriptors for the kernel's input tensors.
    pub inputs: Vec<TensorDesc>,

    /// Descriptors for the kernel's output tensors.
    pub outputs: Vec<TensorDesc>,

    /// Shared-memory requirement in bytes (GPU-style backends).
    pub shared_memory: usize,

    /// Thread-block dimensions (x, y, z).
    pub block_size: (usize, usize, usize),

    /// Grid dimensions (x, y, z).
    pub grid_size: (usize, usize, usize),
}
458
/// Memory-layout description of one tensor argument: element type, logical
/// shape, per-dimension strides, and offset into the underlying buffer.
#[derive(Debug, Clone)]
pub struct TensorDesc {
    pub dtype: DType,
    pub shape: Vec<usize>,
    pub strides: Vec<usize>,
    pub offset: usize,
}
467
/// Aggregate statistics collected over a module's executions.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Total execution time in microseconds.
    pub total_time_us: u64,

    /// Number of kernel launches performed.
    pub kernel_launches: usize,

    /// Bytes of memory transferred.
    pub memory_transferred: usize,

    /// Fraction of lookups served from cache (0.0 to 1.0).
    pub cache_hit_rate: f32,
}
483
/// Minimal tensor handle used at the JIT execution boundary.
// NOTE(review): currently an owned f32 buffer with no shape/dtype info —
// presumably a placeholder until real tensor handles are wired in.
#[derive(Clone, Debug)]
pub struct TensorRef {
    pub data: Vec<f32>,
}
490
491pub fn trace<F>(_func: F, _example_inputs: &[TensorRef]) -> JitResult<CompiledModule>
521where
522 F: Fn(&[TensorRef]) -> Vec<TensorRef>,
523{
524 Ok(CompiledModule {
534 graph: ComputationGraph::new(),
535 kernels: Vec::new(),
536 runtime: JitRuntime::new(JitConfig::default()),
537 })
538}
539
/// Compiles a scriptable module into a [`CompiledModule`].
///
/// Thin wrapper delegating to [`script::script`].
///
/// # Errors
/// Propagates any [`JitError`] from script compilation.
pub fn script<M>(module: M) -> JitResult<CompiledModule>
where
    M: ScriptableModule,
{
    script::script(module)
}
547
/// Implemented by modules that can be lowered to a computation graph for
/// script-mode compilation.
pub trait ScriptableModule {
    /// Converts the module into a [`ComputationGraph`].
    fn to_graph(&self) -> JitResult<ComputationGraph>;
}
553
554pub mod utils {
556 use super::{graph, ComputationGraph, DType, FusionStrategy, JitConfig};
557
558 #[must_use]
566 pub fn estimate_compilation_time(graph: &ComputationGraph) -> u64 {
567 let node_count = graph.nodes().count();
568 let edge_count = graph.edges().count();
569
570 let base_overhead = 10; let node_time = (node_count as f64 * 0.5) as u64;
573 let edge_time = (edge_count as f64 * 0.1) as u64;
574
575 base_overhead + node_time + edge_time
576 }
577
578 #[must_use]
585 pub fn estimate_memory_usage(graph: &ComputationGraph) -> usize {
586 let mut total_bytes = 0;
587
588 for (_, node) in graph.nodes() {
589 let elements: usize = node.output_shape.dims().iter().product();
590 let dtype_size = match node.dtype {
591 DType::F32 | DType::I32 | DType::U32 | DType::QInt32 => 4,
592 DType::F64 | DType::I64 | DType::U64 | DType::C64 => 8,
593 DType::F16 | DType::BF16 | DType::I16 => 2,
594 DType::I8 | DType::U8 | DType::Bool | DType::QInt8 | DType::QUInt8 => 1,
595 DType::C128 => 16,
596 };
597
598 total_bytes += elements * dtype_size;
599 }
600
601 let overhead = graph.nodes().count() * 256; total_bytes + overhead
604 }
605
606 #[must_use]
613 pub fn should_jit_compile(graph: &ComputationGraph) -> bool {
614 let node_count = graph.nodes().count();
615
616 if node_count < 5 {
618 return false;
619 }
620
621 let fusion_opportunities = count_fusion_opportunities(graph);
623 if fusion_opportunities > 3 {
624 return true; }
626
627 node_count >= 10
630 }
631
632 fn count_fusion_opportunities(graph: &ComputationGraph) -> usize {
634 let mut opportunities = 0;
635
636 for (node_id, node) in graph.nodes() {
637 let predecessors = graph.predecessors(node_id).count();
639
640 if predecessors > 0 && is_fusible_op(&node.op) {
641 opportunities += 1;
642 }
643 }
644
645 opportunities
646 }
647
648 fn is_fusible_op(op: &graph::Operation) -> bool {
650 matches!(
651 op,
652 graph::Operation::Add
653 | graph::Operation::Sub
654 | graph::Operation::Mul
655 | graph::Operation::Div
656 | graph::Operation::Relu
657 | graph::Operation::Sigmoid
658 | graph::Operation::Tanh
659 | graph::Operation::Gelu
660 | graph::Operation::Exp
661 | graph::Operation::Log
662 | graph::Operation::Sqrt
663 | graph::Operation::Neg
664 | graph::Operation::Abs
665 )
666 }
667
668 #[must_use]
672 pub fn recommend_config(graph: &ComputationGraph) -> JitConfig {
673 let node_count = graph.nodes().count();
674 let fusion_ops = count_fusion_opportunities(graph);
675
676 let mut config = JitConfig::default();
677
678 if fusion_ops > 10 {
680 config.fusion_strategy = FusionStrategy::Aggressive;
681 config.max_fusion_size = 16;
682 } else if fusion_ops > 5 {
683 config.fusion_strategy = FusionStrategy::Default;
684 config.max_fusion_size = 8;
685 } else {
686 config.fusion_strategy = FusionStrategy::Conservative;
687 config.max_fusion_size = 4;
688 }
689
690 config.enable_optimizations = node_count >= 10;
692
693 config.enable_profiling = node_count >= 50;
695
696 config
697 }
698
699 #[must_use]
706 pub fn estimate_flops(graph: &ComputationGraph) -> u64 {
707 let mut total_flops = 0u64;
708
709 for (_, node) in graph.nodes() {
710 let elements: u64 = node.output_shape.dims().iter().product::<usize>() as u64;
711
712 let op_flops = match &node.op {
713 graph::Operation::MatMul => {
714 let dim = (elements as f64).sqrt() as u64;
717 2 * dim * dim * dim
718 }
719 graph::Operation::Conv2d { .. } => {
720 elements * 9 }
723 graph::Operation::Add
724 | graph::Operation::Sub
725 | graph::Operation::Mul
726 | graph::Operation::Div => elements,
727 graph::Operation::Relu | graph::Operation::Abs | graph::Operation::Neg => {
728 elements / 2 }
730 graph::Operation::Exp
731 | graph::Operation::Log
732 | graph::Operation::Sqrt
733 | graph::Operation::Sin
734 | graph::Operation::Cos => {
735 elements * 10 }
737 graph::Operation::Sigmoid | graph::Operation::Tanh | graph::Operation::Gelu => {
738 elements * 5 }
740 _ => elements, };
742
743 total_flops += op_flops;
744 }
745
746 total_flops
747 }
748
749 #[must_use]
757 pub fn estimate_arithmetic_intensity(graph: &ComputationGraph) -> f64 {
758 let flops = estimate_flops(graph) as f64;
759 let bytes = estimate_memory_usage(graph) as f64;
760
761 if bytes > 0.0 {
762 flops / bytes
763 } else {
764 0.0
765 }
766 }
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{Node, Operation};

    #[test]
    fn test_jit_config_default() {
        let config = JitConfig::default();
        assert!(config.enable_optimizations);
        assert_eq!(config.max_fusion_size, 8);
        assert!(!config.enable_profiling);
    }

    #[test]
    fn test_jit_compiler_creation() {
        // Constructing a compiler from the default config must not panic;
        // successful construction is the assertion here.
        let config = JitConfig::default();
        let _compiler = JitCompiler::new(config);
    }

    #[test]
    fn test_utils_estimate_compilation_time() {
        let mut graph = ComputationGraph::new();

        // Even an empty graph pays the fixed base overhead.
        let time = utils::estimate_compilation_time(&graph);
        assert!(time >= 10);

        for i in 0..10 {
            let node = Node::new(Operation::Add, format!("node_{}", i))
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[100]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu);
            graph.add_node(node);
        }

        // Adding nodes must strictly increase the estimate.
        let time_with_nodes = utils::estimate_compilation_time(&graph);
        assert!(time_with_nodes > time);
    }

    #[test]
    fn test_utils_estimate_memory_usage() {
        let mut graph = ComputationGraph::new();

        let node = Node::new(Operation::Add, "test".to_string())
            .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[100, 100]))])
            .with_dtypes(vec![DType::F32])
            .with_device(DeviceType::Cpu);
        graph.add_node(node);

        let memory = utils::estimate_memory_usage(&graph);

        // 100 * 100 f32 elements at 4 bytes each, ignoring overhead.
        let expected_min = 100 * 100 * 4;
        assert!(memory >= expected_min);
    }

    #[test]
    fn test_utils_should_jit_compile() {
        let mut graph = ComputationGraph::new();

        // Fewer than 5 nodes: never worth compiling.
        for i in 0..3 {
            let node = Node::new(Operation::Add, format!("node_{}", i))
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu);
            graph.add_node(node);
        }

        assert!(!utils::should_jit_compile(&graph));

        // Grow to 15 nodes: past the node-count threshold.
        for i in 3..15 {
            let node = Node::new(Operation::Add, format!("node_{}", i))
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu);
            graph.add_node(node);
        }

        assert!(utils::should_jit_compile(&graph));
    }

    #[test]
    fn test_utils_recommend_config() {
        let mut graph = ComputationGraph::new();

        // Build a chain of fusible ops so fusion opportunities are counted.
        let mut prev_nodes = Vec::new();

        for i in 0..15 {
            let op = if i % 3 == 0 {
                Operation::Add
            } else if i % 3 == 1 {
                Operation::Mul
            } else {
                Operation::Relu
            };

            let node = Node::new(op, format!("node_{}", i))
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[100]))])
                .with_dtypes(vec![DType::F32])
                .with_device(DeviceType::Cpu);
            let node_id = graph.add_node(node);

            // Chain each node to its predecessor.
            if let Some(&prev) = prev_nodes.last() {
                graph.add_edge(
                    prev,
                    node_id,
                    crate::graph::Edge {
                        src_output: 0,
                        dst_input: 0,
                    },
                );
            }

            prev_nodes.push(node_id);
        }

        let config = utils::recommend_config(&graph);

        // 15 nodes: optimizations on, fusion size at least conservative.
        assert!(config.enable_optimizations);
        assert!(config.max_fusion_size >= 4);
    }

    #[test]
    fn test_utils_estimate_flops() {
        let mut graph = ComputationGraph::new();

        let node = Node::new(Operation::MatMul, "matmul".to_string())
            .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[64, 64]))])
            .with_dtypes(vec![DType::F32])
            .with_device(DeviceType::Cpu);
        graph.add_node(node);

        let flops = utils::estimate_flops(&graph);

        // A 64x64 matmul should estimate well over 100k FLOPs (2 * 64^3).
        assert!(flops > 100_000);
    }

    #[test]
    fn test_utils_estimate_arithmetic_intensity() {
        let mut graph = ComputationGraph::new();

        let node = Node::new(Operation::Add, "add".to_string())
            .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[1000]))])
            .with_dtypes(vec![DType::F32])
            .with_device(DeviceType::Cpu);
        graph.add_node(node);

        let intensity = utils::estimate_arithmetic_intensity(&graph);

        assert!(intensity > 0.0);
        assert!(intensity.is_finite());
    }

    #[test]
    fn test_trace_placeholder() {
        // `trace` is a stub: it must succeed and yield an empty module.
        let result = trace(|_inputs| vec![], &[]);
        assert!(result.is_ok());

        let module = result.unwrap();
        assert_eq!(module.kernels.len(), 0);
    }

    #[test]
    fn test_jit_error_display() {
        let error = JitError::GraphError("test error".to_string());
        let display = format!("{}", error);
        assert!(display.contains("test error"));
    }

    #[test]
    fn test_jit_config_builder_pattern() {
        let config = JitConfig {
            fusion_strategy: FusionStrategy::Aggressive,
            enable_optimizations: true,
            max_fusion_size: 16,
            enable_profiling: true,
            target_device: DeviceType::Cpu,
            enable_caching: true,
            enable_specialization: false,
            specialization_config: SpecializationConfig::default(),
        };

        assert_eq!(config.fusion_strategy, FusionStrategy::Aggressive);
        assert!(config.enable_optimizations);
        assert_eq!(config.max_fusion_size, 16);
        assert!(config.enable_caching);
    }
}
972
/// Crate version string, taken from Cargo.toml at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
// NOTE(review): the numeric components below are hard-coded and must be kept
// in sync with Cargo.toml by hand — consider deriving them from the
// CARGO_PKG_VERSION_MAJOR / _MINOR / _PATCH environment variables instead.
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;
978
/// Convenience prelude: glob-re-exports the most commonly used submodules so
/// downstream code can `use torsh_jit::prelude::*;`.
// Glob re-exports from multiple modules can shadow each other; the allow
// acknowledges that ambiguity deliberately.
#[allow(ambiguous_glob_reexports)]
pub mod prelude {
    pub use crate::{
        abstract_interpretation::*, adaptive_compilation::*, codegen::*, const_eval::*,
        custom_ops::*, debug_symbols::*, debugger::*, differentiable_compilation::*,
        error_diagnostics::*, fusion::*, graph::*, optimizer::*, runtime::*, script::*, tracing::*,
    };
}