// Clippy lints silenced for everything compiled under this file: these are
// inner attributes, so they apply to the enclosing crate/module and to every
// item nested inside it.
// NOTE(review): no justification is recorded here for any of these lints.
// Consider adding a short reason per lint, or narrowing each `allow` to the
// specific item that actually needs it.
1#![allow(clippy::len_zero)]
8#![allow(clippy::field_reassign_with_default)]
9#![allow(clippy::manual_range_contains)]
10#![allow(clippy::collapsible_if)]
11#![allow(clippy::only_used_in_recursion)]
12#![allow(clippy::needless_range_loop)]
13#![allow(clippy::or_fun_call)]
14#![allow(clippy::derivable_impls)]
15#![allow(clippy::manual_is_multiple_of)]
16#![allow(clippy::overly_complex_bool_expr)]
17#![allow(clippy::unwrap_or_default)]
18pub mod causal;
202pub mod cost_model;
203pub mod critical_path;
204pub mod execution_plan;
205pub mod higher_order;
206pub mod low_rank;
207pub mod memo_cache;
208pub mod partitioned;
209pub mod step_executor;
210
211pub use higher_order::{
212 FiniteDiffMethod, HessianComputer, HessianStats, JacobianComputer, JacobianConfig,
213};
214pub use low_rank::{
215 LowRankApproximation, LowRankCandidate, LowRankConfig, LowRankError, LowRankInferencePass,
216 LowRankPassStats, SvdResult, TruncatedSvd,
217};
218pub use partitioned::{
219 AccumulationStrategy, PartitionConfig, PartitionedError, PartitionedReducer, PartitionedStats,
220};
221pub use step_executor::{BreakpointCondition, IntermediateValue, StepExecutor};
222
223pub mod async_exec;
224pub mod auto_parallel;
225pub mod autodiff;
226pub mod backend_kind;
227pub mod backend_tests;
228pub mod batch;
229pub mod beam_search;
230pub mod cache;
231pub mod cache_optimizer;
232pub mod capabilities;
233pub mod compilation;
234pub mod constraint_propagation;
235pub mod context;
236pub mod debug;
237pub mod diagnostics;
238pub mod distributed;
// Private modules (no `pub`): their contents are reachable from outside only
// through the `pub use dummy_executor::DummyExecutor` and
// `pub use dummy_tensor::DummyTensor` re-exports later in this file.
239mod dummy_executor;
240mod dummy_tensor;
241pub mod dynamic_batching;
242pub mod eager;
243mod error;
244pub mod fusion;
245pub mod gradcheck;
246pub mod jit;
247pub mod join_order;
248pub mod learned_opt;
249pub mod mcmc;
250pub mod memory;
251pub mod mixed_precision;
252pub mod multimodel;
253mod ops;
254pub mod optimization;
255pub mod parallel;
256pub mod perfregression;
257pub mod placement;
258pub mod profiling;
259pub mod profiling_optimizer;
260pub mod pruning;
261pub mod quantization;
262pub mod recovery;
263pub mod rewrite;
264pub mod sampling;
265pub mod scheduling;
266pub mod shape;
267pub mod simd;
268pub mod sparse;
269pub mod speculative;
270pub mod strategy;
271pub mod streaming;
272pub mod symbolic_shape;
273pub mod tensor_stats;
274pub mod tensor_view;
275pub mod trace_recording;
276mod traits;
277pub mod typesafe;
278pub mod uncertainty;
279pub mod validation;
280pub mod visualization;
281pub mod windowed_aggregation;
282pub mod workspace;
283
// Test-only modules: each is gated by `#[cfg(test)]`, so none of them is
// compiled into non-test builds.
284#[cfg(test)]
285mod tests;
286
287#[cfg(test)]
288mod validation_tests;
289
290#[cfg(test)]
291mod memory_tests;
292
// Flattened re-export of the async execution API, available only when the
// `async` feature is enabled.
// NOTE(review): the `pub mod async_exec;` declaration earlier in this file has
// no matching `cfg` gate, so the module is always declared while these
// shortcuts are not — confirm that asymmetry is intentional.
293#[cfg(feature = "async")]
294pub use async_exec::{
295 AsyncConfig, AsyncExecutionError, AsyncExecutionHandle, AsyncExecutorPool, AsyncStats,
296 AsyncStreamResults, BoxFuture, TlAsyncBatchExecutor, TlAsyncExecutor, TlAsyncStreamExecutor,
297};
// Re-exports the auto-parallelization API at this level.
// `CostModel` and `NodeId` are renamed on the way out (`AutoParallelCostModel`,
// `AutoParallelNodeId`): this file also re-exports `cost_model::CostModel`
// unrenamed and aliases `NodeId` from `critical_path`, `learned_opt`,
// `rewrite` and `speculative`, so the prefixes keep the names distinct.
298pub use auto_parallel::{
299 AutoParallelError, AutoParallelizer, CostModel as AutoParallelCostModel, DependencyType,
300 NodeId as AutoParallelNodeId, NodeInfo, ParallelExecutionPlan, ParallelStage,
301 ParallelizationAnalysis, ParallelizationStrategy, WorkPartition,
302};
303pub use autodiff::{
304 AccumulationConfig, ClippingStrategy, CustomGradientRegistry, GradientAccumulationStrategy,
305 GradientAccumulator, GradientClipper, GradientConfig, GradientScaler, GradientScaling,
306 GradientStats, TlEnhancedAutodiff,
307};
308pub use backend_kind::{BackendKind, BackendKindError};
309pub use backend_tests::{
310 assert_vec_close, print_test_summary, run_all_basic_tests, run_all_performance_tests,
311 test_backend_edge_cases, test_backend_einsum, test_backend_elem_binary,
312 test_backend_elem_unary, test_backend_forward, test_backend_large_tensors,
313 test_backend_memory_efficiency, test_backend_reduce, test_backend_shapes, BackendTestAdapter,
314 TestResult, DEFAULT_TOLERANCE,
315};
316pub use batch::{BatchResult, TlBatchExecutor};
317pub use beam_search::{
318 BeamHypothesis, BeamSearchConfig, BeamSearchDecoder, BeamSearchError, BeamSearchResult,
319 BeamSearchStats, BeamState, BeamStepInput,
320};
321pub use cache::{CacheKey, CacheStats, EvictionPolicy, MemoryPool, PoolStats, TensorCache};
322pub use cache_optimizer::{
323 AccessPattern, CacheConfig, CacheLevel, CacheMetrics, CacheOptimizer, CacheOptimizerError,
324 DataLayout, OptimizationStats as CacheOptimizationStats, TilingParams,
325};
326pub use capabilities::{BackendCapabilities, DType, DeviceType, Feature, TlCapabilities};
327pub use causal::{
328 ate_backdoor, ate_instrumental_variable, backdoor_criterion, do_intervention,
329 find_backdoor_adjustment, frontdoor_criterion, propensity_score, BackdoorAdjustment,
330 CausalError, CausalGraph, Intervention, ObservationalData, TreatmentEffect,
331};
332pub use compilation::{
333 CacheStats as CompilationCacheStats, CompilationCache, CompilationConfig, CompilationKey,
334 CompilationStats, CompiledGraph, GraphCompiler, OptimizationLevel, TlCompilableExecutor,
335};
336pub use constraint_propagation::{
337 propagate_arc_consistency, solve, BinaryConstraint, ConstraintNetwork, ConstraintRelation,
338 CspConfig, Domain, PropagationResult, SolveStats, VarOrdering,
339};
340pub use context::{ExecutionContext, ExecutionHook, ExecutionPhase, ExecutionState, LoggingHook};
341pub use cost_model::{
342 CostAwareSchedule, CostModel, CostModelConfig, FlopEstimate, GraphCostSummary,
343 MemoryCostEstimate, NodeCostEstimate,
344};
345pub use critical_path::{
346 critical_path, CriticalPathError, CriticalPathReport, CriticalPathResult, InferenceGraph,
347 MissingCostWarning, NodeId as CriticalPathNodeId, NodeLatency,
348};
349pub use debug::{
350 Breakpoint, BreakpointHit, BreakpointManager, ExecutionRecorder, ExecutionReport,
351 ExecutionTrace, ExecutionTracer, OperationHandle, TensorInspector, TensorStats,
352 TraceEntry as DebugTraceEntry, TraceSummary,
353};
354pub use diagnostics::{
355 Diagnostic, DiagnosticCollector, MemoryDiagnostic, NodeExecutionDiagnostic,
356 PerformanceDiagnostic, Severity, ShapeMismatchDiagnostic, SourceLocation,
357 TypeMismatchDiagnostic,
358};
359pub use distributed::{
360 CommunicationBackend, CommunicationOp, DataParallelCoordinator, DistributedConfig,
361 DistributedExecutor, DistributedPlacementPlan, DistributedStats, DummyCommunicationBackend,
362 ModelParallelCoordinator, ParallelismStrategy as DistributedParallelismStrategy,
363 PipelineParallelCoordinator, ReductionOp, ShardingSpec, TlDistributedExecutor,
364};
365pub use dummy_executor::DummyExecutor;
366pub use dummy_tensor::DummyTensor;
367pub use dynamic_batching::{
368 AdaptiveBatcher, BatchRequest, BatchingError, BatchingStats, DynamicBatchConfig,
369 DynamicBatcher, Priority, RequestMetadata, RequestQueue,
370};
371pub use eager::{EagerOp, EagerOps, EagerTape, TlEagerAutodiff, Variable, VariableGrad};
372pub use error::ExecutorError;
373pub use execution_plan::{
374 compute_memory_timeline, ExecutionPlan, MemoryTimelineEntry, PlanFormatter, PlanStep,
375};
376pub use fusion::{
377 FusionCandidate, FusionConfig, FusionCostModel, FusionError, FusionOptimizer, FusionPattern,
378 FusionStats, FusionStrategy,
379};
380pub use gradcheck::{
381 compare_gradients, numerical_gradient_central, numerical_gradient_forward, quick_check,
382 GradCheckConfig, GradCheckResult, GradientChecker, GradientError,
383};
384pub use jit::{
385 AdaptiveOptimizationPlan, AdaptiveOptimizer, HotPathDetector, JitCache, JitCacheEntry,
386 JitCacheStats, JitCompiler, JitConfig, JitEntryStats, JitKey, JitStats, SpecializationContext,
387 TlJitExecutor,
388};
389pub use join_order::{
390 JoinCondition, JoinOptimizerConfig, JoinOrderError, JoinOrderOptimizer, JoinPlan, JoinPlanNode,
391 JoinStats, Relation as JoinRelation,
392};
393pub use learned_opt::{
394 CostPrediction, FeatureVector, FusionRecommendation, LearnedOptError, LearnedOptimizer,
395 LearningStats, LearningStrategy, ModelType, NodeId as LearnedOptNodeId, OptimizationAction,
396 RewardSignal, ScheduleRecommendation, TrainingExample,
397};
398pub use mcmc::{
399 autocorrelation, compute_diagnostics, effective_sample_size, gelman_rubin, ChainDiagnostics,
400 GaussianProposal, HamiltonianMonteCarlo, IndependentGaussianProposal, LogProb, LogProbFn,
401 McmcConfig, McmcError, McmcResult, McmcRng, MetropolisHastings, Proposal,
402};
403pub use memo_cache::{
404 ExprMemoCache, MemoCacheBuilder, MemoConfig, MemoEvictionPolicy, MemoKey, MemoLookupResult,
405 MemoStats,
406};
407pub use memory::{MemoryEstimate, MemoryEstimator, TensorMemory};
408pub use mixed_precision::{
409 GradientCheckpoint, LossScaler, LossScalerStats, LossScalingStrategy, MixedPrecisionConfig,
410 MixedPrecisionError, MixedPrecisionState, MixedPrecisionStats, PrecisionMode,
411};
412pub use multimodel::{
413 CascadeConfig, CoordinationStats, EnsembleConfig, EnsembleStrategy, ModelMetadata,
414 MultiModelCoordinator, MultiModelError, ResourceRequirements, RoutingStrategy,
415 TlEnsembleExecutor, TlModelRouter,
416};
417pub use ops::{ElemOp, ReduceOp};
418pub use optimization::{
419 FusionOpportunity, FusionPlanner, FusionType, GraphOptimizer, OptimizationResult,
420};
421pub use parallel::{
422 LoadBalanceStats, NumaNode, NumaStrategy, ParallelConfig, ParallelError, SchedulerStats,
423 StealStrategy, Task, TaskId, TaskPriority, WorkStealingScheduler,
424};
425pub use perfregression::{
426 BenchmarkBaseline, BenchmarkComparison, BenchmarkConfig, BenchmarkStats, PerfRegression,
427 RegressionReport,
428};
429pub use placement::{Device, PlacementOptimizer, PlacementPlan, PlacementStrategy};
430pub use profiling::{
431 Bottleneck, BottleneckAnalyzer, BottleneckReport, PerformanceBaseline, PerformanceComparison,
432 ProfileData, ProfileStatistics, Profiler, ProfilerHook, TimelineProfiler, TlProfiledExecutor,
433 TraceEntry,
434};
435pub use profiling_optimizer::{
436 ExecutionProfile, Hotspot, OptimizationGoal, OptimizationReport, OptimizationStrategy,
437 ProfilingOptimizer, ProfilingOptimizerError, TuningConfig,
438};
439pub use pruning::{
440 compute_sparsity, row_norms, MagnitudePruner, PruningConfig, PruningError, SparsityPattern,
441 SparsityStats,
442};
443pub use quantization::{
444 CalibrationStats, CalibrationStrategy, FakeQuantize, QuantizationConfig, QuantizationError,
445 QuantizationGranularity, QuantizationMode, QuantizationParams, QuantizationSummary,
446 QuantizationSymmetry, QuantizationType, Quantizer,
447};
448pub use recovery::{
449 Checkpoint, CheckpointManager, DegradationPolicy, FailureInfo, FallbackStrategy,
450 RecoveryConfig, RecoveryMetadata, RecoveryResult, RecoveryStats, RecoveryStrategy, RetryPolicy,
451 TlRecoverableExecutor,
452};
453pub use rewrite::{
454 CommonRules, Match, NodeId as RewriteNodeId, Pattern, ReplacementFn, RewriteEngine,
455 RewriteError, RewriteRule, RewriteStats, RewriteStrategy,
456};
457pub use sampling::{
458 entropy, log_softmax, perplexity, softmax, ConfigurableSampler, GreedyDecoder, SampledToken,
459 SamplingConfig, SamplingError, TemperatureSampler, TopKSampler, TopPSampler,
460};
461pub use scheduling::{ExecutionSchedule, NodeCost, Scheduler, SchedulingStrategy};
462pub use shape::{DimSize, ShapeInferenceContext, TensorShape};
463pub use simd::{
464 AlignedBuffer, CpuArchitecture, SimdCapabilities, SimdError, SimdInstructionSet,
465 SimdOptimizationHints,
466};
467pub use sparse::{
468 detect_sparsity, to_sparse_if_beneficial, SparseCOO, SparseCSC, SparseCSR, SparseError,
469 SparseFormat, SparseTensor, SparseTensorBuilder,
470};
471pub use speculative::{
472 BranchOutcome, NodeId as SpeculativeNodeId, PredictionStrategy, RollbackPolicy,
473 SpeculationStats, SpeculativeError, SpeculativeExecutor, SpeculativeTask,
474};
475pub use strategy::{
476 ExecutionMode, ExecutionStrategy, GradientStrategy, MemoryStrategy, ParallelismStrategy,
477 StrategyOptimizer,
478};
479pub use streaming::{
480 BackpressureConfig, BackpressureStrategy, ChunkIterator, ChunkMetadata, StreamProcessor,
481 StreamResult, StreamingConfig, StreamingConfigV2, StreamingMode, StreamingStats,
482 TlStreamingExecutor, WatermarkConfig,
483};
484pub use symbolic_shape::{
485 propagate_chain, propagate_einsum_shapes, ShapeError, SymbolicDim, SymbolicShape,
486 SymbolicShapeConstraint, SymbolicShapeEnv,
487};
488pub use tensor_stats::{
489 ActivationStatistics, AnomalyDetector, AnomalyKind, AnomalyReport, StatsError,
490 TensorStats as TensorStatsSummary,
491};
492pub use tensor_view::{
493 InPlaceMode, InPlaceOps, SliceSpec, TensorView, TensorViewable, ViewBuilder,
494};
495pub use trace_recording::{
496 CommunicationBottleneck, DeviceSummary, LoadBalanceMetrics, OpSummary, RecordedExecutionTrace,
497 RecordedTraceEntry, TraceAnalyzer, TraceRecorder,
498};
499pub use traits::{TlAutodiff, TlExecutor};
500pub use typesafe::{
501 BroadcastShape, Dim, DimMul, DimOp, DimSize as TypesafeDimSize, Dyn, EinsumSpec, FixedShape,
502 Matrix, MatrixOps, Nat, Scalar, ShapeConstraint, ShapedTensor, Static, Tensor3D, Tensor4D,
503 TensorBuilder, TypedBatch, TypedInputs, TypedOutputs, TypedTensor, TypedTensorOps, Vector, D1,
504 D2, D3, D4, D5, D6, S, Z,
505};
506pub use uncertainty::{
507 find_optimal_temperature, temperature_scale, CalibrationBin, CalibrationMetrics,
508 ConfidenceInterval, IntervalMethod, MonteCarloEstimator, PredictionInterval, UncertaintyError,
509 UncertaintyEstimate,
510};
511pub use validation::{GraphValidator, ValidationResult};
512pub use visualization::{
513 ExportFormat, GraphConfig, GraphVisualizer, TensorStatsVisualizer, TimelineConfig,
514 TimelineVisualizer, VisualizationFormat,
515};
516pub use windowed_aggregation::{
517 WindowAggregation, WindowConfig, WindowError, WindowResult, WindowType, WindowedAggregation,
518};
519pub use workspace::{
520 AllocationStrategy, DefragmentationResult, SharedWorkspacePool, Workspace, WorkspaceConfig,
521 WorkspaceError, WorkspacePool, WorkspaceStats,
522};