// torsh_functional/lib.rs
1//! Functional API for ToRSh
2//!
3//! This module provides functional operations similar to torch.functional,
4//! including tensor manipulation, mathematical operations, and utilities.
5//!
6//! For comprehensive performance optimization guidance, see the separate
7//! Performance Tuning Guide documentation.
8
9#![allow(deprecated)]
10
11// Version information
12pub const VERSION: &str = env!("CARGO_PKG_VERSION");
13pub const VERSION_MAJOR: u32 = 0;
14pub const VERSION_MINOR: u32 = 1;
15pub const VERSION_PATCH: u32 = 0;
16
17use torsh_core::Result as TorshResult;
18use torsh_tensor::Tensor;
19
20// API patterns and conventions
21pub mod api_patterns;
22
23// Neural network functional operations
24pub mod activation_lookup;
25pub mod activations;
26pub mod advanced_nn;
27pub mod attention;
28pub mod autograd;
29pub mod conv;
30pub mod dropout;
31pub mod loss;
32pub mod normalization;
33pub mod pooling;
34pub mod regularization;
35
36// Tensor operations
37pub mod advanced_manipulation;
38pub mod broadcast;
39pub mod data_ops;
40pub mod fusion;
41pub mod image;
42pub mod interpolation;
43pub mod lazy;
44pub mod linalg;
45pub mod manipulation;
46pub mod math;
47pub mod numerical;
48pub mod optimization;
49pub mod parallel;
50pub mod profiling;
51pub mod quantization;
52pub mod random_ops;
53pub mod reduction;
54pub mod signal;
55pub mod sparse;
56pub mod special;
57pub mod spectral;
58pub mod spectral_advanced;
59pub mod spectral_analysis;
60pub mod spectral_stft;
61pub mod tensor_decomposition;
62pub mod tensor_ops;
63pub mod transformations;
64pub mod type_promotion;
65pub mod utils;
66pub mod wavelet;
67
68#[cfg(test)]
69pub mod testing;
70
71#[cfg(test)]
72pub mod pytorch_correctness;
73
74#[cfg(test)]
75pub mod numerical_correctness;
76
77#[cfg(test)]
78pub mod property_based_tests;
79
80#[cfg(test)]
81pub mod edge_case_tests;
82
83#[cfg(test)]
84pub mod platform_tests;
85
86// Re-exports for convenience
87
88// Activation functions
89pub use activations::{
90    celu, elu, gelu, gumbel_softmax, hardshrink, hardsigmoid, hardsigmoid_v2, hardswish, hardtanh,
91    leaky_relu, log_sigmoid, log_softmax, mish, prelu, relu, relu6, rrelu, selu, sigmoid, silu,
92    softmax, softmin, softplus, softshrink, softsign, tanh, tanhshrink, threshold,
93};
94
95// Loss functions
96pub use loss::{
97    binary_cross_entropy, binary_cross_entropy_with_logits, contrastive_loss,
98    cosine_embedding_loss, cross_entropy, cross_entropy_with_label_smoothing, ctc_loss, focal_loss,
99    gaussian_nll_loss, hinge_embedding_loss, kl_div, l1_loss, margin_ranking_loss, mse_loss,
100    multi_margin_loss, nll_loss, poisson_nll_loss, smooth_l1_loss, triplet_margin_loss,
101    triplet_margin_with_distance_loss, ReductionType,
102};
103
104// Convolution operations
105pub use conv::{
106    conv1d, conv2d, conv3d, conv_output_size, conv_transpose1d, conv_transpose2d, conv_transpose3d,
107    conv_transpose_output_size, depthwise_conv2d, fold, separable_conv2d, unfold,
108};
109
110// Pooling operations
111pub use pooling::{
112    adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d, adaptive_max_pool1d,
113    adaptive_max_pool2d, adaptive_max_pool3d, avg_pool1d, avg_pool2d, avg_pool3d,
114    fractional_max_pool2d, global_avg_pool, global_max_pool, learnable_pool2d, lp_pool1d,
115    lp_pool2d, max_pool1d, max_pool2d, max_pool3d, max_unpool1d, max_unpool2d, max_unpool3d,
116    spatial_pyramid_pool2d, stochastic_pool2d,
117};
118
119// Normalization functions
120pub use normalization::{
121    batch_norm, group_norm, instance_norm, layer_norm, local_response_norm, normalize, weight_norm,
122};
123
124// Dropout functions
125pub use dropout::{
126    alpha_dropout, dropout, dropout1d, dropout2d, dropout3d, feature_alpha_dropout,
127    gaussian_dropout,
128};
129
130// Attention functions
131pub use attention::{
132    cross_attention, flash_attention, multi_head_attention, scaled_dot_product_attention,
133    self_attention,
134};
135
136// Regularization functions
137pub use regularization::{
138    consistency_penalty, gradient_penalty, r1_gradient_penalty, r2_gradient_penalty,
139    spectral_gradient_penalty,
140};
141
142// Advanced neural network operations
143pub use advanced_nn::{
144    // Data augmentation
145    cutmix,
146    // Neural Architecture Search operations
147    darts_operation,
148    decode_architecture,
149    differentiable_augment,
150    encode_architecture,
151    // Other functions
152    knowledge_distillation_loss,
153    label_smoothing,
154    mixup,
155    mutate_architecture,
156    predict_architecture_performance,
157    // Normalization
158    spectral_norm,
159    temperature_scale,
160    weight_standardization,
161};
162
163// Tensor operations
164pub use broadcast::{broadcast_shapes, broadcast_tensors};
165pub use linalg::{
166    baddbmm, bmm, chain_matmul, cholesky, cond, det, eig, inv, lstsq, lu, matrix_rank, norm,
167    pca_lowrank, pinv, qr, solve, svd, svd_lowrank, triangular_solve, NormOrd,
168};
169pub use manipulation::{
170    atleast_1d, atleast_2d, atleast_3d, block_diag, cartesian_prod, chunk, dsplit, hsplit,
171    meshgrid, split, tensor_split, tensordot, unravel_index, vsplit, SplitArg, TensorSplitArg,
172};
173pub use math::{cdist, einsum};
174pub use reduction::{unique, unique_consecutive, UniqueResult};
175pub use spectral::{
176    cepstrum, create_mel_filterbank, fftn, generate_window, hfft, hz_to_mel, ifftn, ihfft, irfft,
177    istft, istft_complete, mel_spectrogram, mel_to_hz, rfft2, rfftn, spectral_centroid,
178    spectral_rolloff, spectrogram, stft, stft_complete, SpectrogramType, WindowFunction,
179};
180pub use tensor_ops::{
181    cosine_similarity, embedding, linear, one_hot, pairwise_distance, pixel_shuffle,
182    pixel_unshuffle,
183};
184
185// Image processing operations
186pub use image::{
187    affine_transform, closing, dilation, erosion, gaussian_blur, hsv_to_rgb, laplacian_filter,
188    opening, resize, rgb_to_hsv, sobel_filter, InterpolationMode, SobelDirection,
189};
190
191// Signal processing
192pub use signal::{
193    correlate, filtfilt, frame, lfilter, overlap_add, periodogram, welch, window, CorrelationMode,
194    PsdScaling, WindowType,
195};
196
197// Data operations
198pub use data_ops::{
199    bincount,
200    histogram,
201    histogram_with_edges,
202    unique as unique_values, // Renamed to avoid conflict with reduction::unique
203    value_counts,
204};
205
206// Random operations
207pub use random_ops::{
208    bernoulli, bernoulli_, exponential_, multinomial, normal_, rand, randint, randint_, randn,
209    randperm, uniform_,
210};
211
212// Type promotion
213pub use type_promotion::{
214    can_cast_safely, common_dtype_for_operation, ensure_compatible_types, get_type_category,
215    get_type_precision, promote_multiple_types, promote_scalar_type, promote_tensors,
216    promote_types, reduction_result_type, result_type, TypeCategory,
217};
218
219// Operation fusion
220pub use fusion::{
221    analyze_fusion_opportunities, detect_fusible_patterns, fused_add_mul, fused_add_relu_mul,
222    fused_batch_norm, fused_mul_add, fused_relu_add, fused_sigmoid_mul, fused_silu,
223    fused_tanh_scale, AdaptiveFusionEngine, FusedOp, FusionOpportunity, FusionPerformance,
224    OpFusionEngine, OpSequence,
225};
226
227// Special mathematical functions
228pub use special::{
229    acosh,
230    airy_ai,
231    asinh,
232    atanh,
233
234    bessel_i0,
235    bessel_i1,
236    bessel_iv,
237    // Bessel functions
238    bessel_j0,
239    bessel_j1,
240    bessel_jn,
241    bessel_k0,
242    bessel_k1,
243
244    bessel_y0,
245    bessel_y1,
246    bessel_yn,
247    beta,
248    // Advanced special functions with scirs2-special integration
249    betainc,
250    dawson,
251    digamma,
252    // Error functions
253    erf,
254    erfc,
255    erfcinv,
256
257    erfcx,
258    erfinv,
259    expint,
260    expm1,
261    fresnel,
262    fresnel_c,
263    fresnel_s,
264    // Gamma functions
265    gamma,
266    hypergeometric_1f1,
267    kelvin_ber,
268    lgamma,
269    log1p,
270    // Statistical functions
271    logsumexp,
272    multigammaln,
273
274    normal_cdf,
275    normal_icdf,
276
277    polygamma,
278    // Trigonometric and other special functions
279    sinc,
280    // Spherical Bessel functions
281    spherical_j0,
282    spherical_j1,
283    spherical_jn,
284    spherical_y0,
285    spherical_y1,
286    spherical_yn,
287
288    voigt_profile,
289};
290
291// Wavelet transforms
292pub use wavelet::{
293    cwt, dwt_1d, dwt_2d, idwt_1d, idwt_2d, wavedec, waverec, WaveletMode, WaveletType,
294};
295
296// Interpolation functions
297pub use interpolation::{
298    barycentric_interp, grid_sample, interp1d, interp2d, lanczos_interp1d, spline1d,
299    InterpolationMode as InterpMode,
300};
301
302// Numerical methods
303pub use numerical::{
304    adaptive_quad, bisection, cumtrapz, gaussian_quad, gradient, newton_raphson,
305    partial_derivative, second_derivative, simps, trapz, DifferentiationMethod, IntegrationMethod,
306};
307
308// Optimization utilities
309pub use optimization::{
310    adam_optimizer,
311    analyze_optimization_problem,
312    auto_configure_optimization,
313    backtracking_line_search,
314    gradient_descent,
315    lbfgs_optimizer,
316    momentum_gradient_descent,
317    wolfe_line_search,
318    AdamParams,
319    AdaptiveAlgorithmSelector,
320    BFGSParams,
321    BacktrackingParams,
322    GradientDescentParams,
323    LineSearchMethod,
324    MomentumParams,
325    OptimizationAlgorithm,
326    // Adaptive algorithm selection
327    TensorCharacteristics,
328    WolfeParams,
329};
330
331// Lazy evaluation
332pub use lazy::{
333    lazy_ops::{execute, lazy, with_optimization},
334    LazyBuilder, LazyContext, LazyOp, LazyTensor,
335};
336
337// Advanced tensor manipulation
338pub use advanced_manipulation::{
339    boolean_index, cat, masked_fill, pad, reshape, slice_with_step, squeeze, unsqueeze,
340    where_tensor, PaddingMode,
341};
342
343// Quantization and compression
344pub use quantization::{
345    dynamic_quantize, fake_quantize, gradual_magnitude_prune, lottery_ticket_prune,
346    magnitude_prune, quantization_error_analysis, uniform_dequantize, uniform_quantize,
347    weight_clustering, QuantizationScheme, QuantizationType,
348};
349
350// Sparse operations
351pub use sparse::{
352    sparse_add, sparse_conv1d, sparse_conv2d, sparse_coo_tensor, sparse_eye, sparse_max,
353    sparse_mean, sparse_min, sparse_mm, sparse_mul, sparse_sum, sparse_to_csr, sparse_transpose,
354    SparseTensor,
355};
356
357// Custom autograd utilities
358pub use autograd::{
359    apply_custom_function, apply_custom_function_with_context, apply_registered_function,
360    get_global_registry, register_custom_function, AutogradContext, AutogradRegistry,
361    CustomAutogradFunction, CustomAutogradFunctionWithContext, ExpFunction, ScaledAddFunction,
362    SquareFunction,
363};
364
365// Performance profiling and benchmarking
366pub use profiling::{
367    benchmark,
368    global_profiler,
369    profile_operation,
370    run_performance_regression_test,
371    BaselineSummary,
372    BenchmarkConfig,
373    BenchmarkResults,
374    OperationMetrics,
375    OperationSummary,
376    // Performance regression testing
377    PerformanceBaseline,
378    PerformanceRegressionTester,
379    Profiler,
380    RegressionTestConfig,
381    RegressionTestResult,
382    SystemInfo,
383};
384
385// Utility functions and patterns
386pub use utils::{
387    apply_binary_elementwise, apply_conditional_elementwise, apply_elementwise_operation,
388    calculate_pooling_output_size, calculate_pooling_output_size_2d,
389    calculate_pooling_output_size_3d, create_tensor_like, function_context, safe_for_log, safe_log,
390    safe_log_prob, validate_broadcastable_shapes, validate_dimension, validate_elementwise_shapes,
391    validate_loss_params, validate_non_empty, validate_pooling_params, validate_positive,
392    validate_range, validate_tensor_dims,
393};
394
395// Advanced functional transformations
396pub use transformations::{
397    einsum_optimized, tensor_contract, tensor_fold, tensor_map, tensor_outer, tensor_reduce,
398    tensor_scan, tensor_zip,
399};
400
401// Tensor decomposition operations
402pub use tensor_decomposition::{cp_decomposition, tucker_decomposition};
403
404/// Align tensors to have the same number of dimensions
405pub fn align_tensors(tensors: &[Tensor]) -> TorshResult<Vec<Tensor>> {
406    if tensors.is_empty() {
407        return Ok(vec![]);
408    }
409
410    // Find maximum number of dimensions
411    let max_dims = tensors.iter().map(|t| t.shape().ndim()).max().unwrap_or(0);
412
413    // Align all tensors
414    let aligned: TorshResult<Vec<_>> = tensors
415        .iter()
416        .map(|t| {
417            let current_dims = t.shape().ndim();
418            if current_dims < max_dims {
419                // Add dimensions of size 1 at the beginning
420                let mut new_shape = vec![1; max_dims - current_dims];
421                new_shape.extend(t.shape().dims());
422                // Convert to i32 for view function
423                let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
424                t.view(&new_shape_i32)
425            } else {
426                Ok(t.clone())
427            }
428        })
429        .collect();
430
431    aligned
432}
433
#[cfg(test)]
mod tests {

    #[test]
    fn test_align_tensors() {
        use crate::align_tensors;
        use torsh_tensor::creation::ones;

        // Inputs of rank 2, 1 and 3 — alignment must pad all to rank 3
        // by prepending singleton dimensions.
        let inputs = [
            ones(&[3, 4]).unwrap(),
            ones(&[4]).unwrap(),
            ones(&[2, 3, 4]).unwrap(),
        ];

        let aligned = align_tensors(&inputs).unwrap();

        let expected: [&[usize]; 3] = [&[1, 3, 4], &[1, 1, 4], &[2, 3, 4]];
        for (tensor, want) in aligned.iter().zip(expected.iter()) {
            assert_eq!(tensor.shape().ndim(), 3);
            assert_eq!(tensor.shape().dims(), *want);
        }
    }
}
459}