// Crate-wide allowance for deprecated items; presumably kept while older
// APIs are phased out — NOTE(review): revisit whether this blanket allow
// is still needed, it can hide new deprecation warnings.
#![allow(deprecated)]

/// Full crate version string, taken from Cargo.toml at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
// Numeric version components for programmatic checks.
// NOTE(review): these are hardcoded and can drift from `VERSION` above
// (which tracks Cargo.toml automatically) — confirm they are kept in sync
// on release bumps.
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;

use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
19
20pub mod api_patterns;
22
23pub mod activation_lookup;
25pub mod activations;
26pub mod advanced_nn;
27pub mod attention;
28pub mod autograd;
29pub mod conv;
30pub mod dropout;
31pub mod loss;
32pub mod normalization;
33pub mod pooling;
34pub mod regularization;
35
36pub mod advanced_manipulation;
38pub mod broadcast;
39pub mod data_ops;
40pub mod fusion;
41pub mod image;
42pub mod interpolation;
43pub mod lazy;
44pub mod linalg;
45pub mod manipulation;
46pub mod math;
47pub mod numerical;
48pub mod optimization;
49pub mod parallel;
50pub mod profiling;
51pub mod quantization;
52pub mod random_ops;
53pub mod reduction;
54pub mod signal;
55pub mod sparse;
56pub mod special;
57pub mod spectral;
58pub mod spectral_advanced;
59pub mod spectral_analysis;
60pub mod spectral_stft;
61pub mod tensor_decomposition;
62pub mod tensor_ops;
63pub mod transformations;
64pub mod type_promotion;
65pub mod utils;
66pub mod wavelet;
67
68#[cfg(test)]
69pub mod testing;
70
71#[cfg(test)]
72pub mod pytorch_correctness;
73
74#[cfg(test)]
75pub mod numerical_correctness;
76
77#[cfg(test)]
78pub mod property_based_tests;
79
80#[cfg(test)]
81pub mod edge_case_tests;
82
83#[cfg(test)]
84pub mod platform_tests;
85
86pub use activations::{
90 celu, elu, gelu, gumbel_softmax, hardshrink, hardsigmoid, hardsigmoid_v2, hardswish, hardtanh,
91 leaky_relu, log_sigmoid, log_softmax, mish, prelu, relu, relu6, rrelu, selu, sigmoid, silu,
92 softmax, softmin, softplus, softshrink, softsign, tanh, tanhshrink, threshold,
93};
94
95pub use loss::{
97 binary_cross_entropy, binary_cross_entropy_with_logits, contrastive_loss,
98 cosine_embedding_loss, cross_entropy, cross_entropy_with_label_smoothing, ctc_loss, focal_loss,
99 gaussian_nll_loss, hinge_embedding_loss, kl_div, l1_loss, margin_ranking_loss, mse_loss,
100 multi_margin_loss, nll_loss, poisson_nll_loss, smooth_l1_loss, triplet_margin_loss,
101 triplet_margin_with_distance_loss, ReductionType,
102};
103
104pub use conv::{
106 conv1d, conv2d, conv3d, conv_output_size, conv_transpose1d, conv_transpose2d, conv_transpose3d,
107 conv_transpose_output_size, depthwise_conv2d, fold, separable_conv2d, unfold,
108};
109
110pub use pooling::{
112 adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d, adaptive_max_pool1d,
113 adaptive_max_pool2d, adaptive_max_pool3d, avg_pool1d, avg_pool2d, avg_pool3d,
114 fractional_max_pool2d, global_avg_pool, global_max_pool, learnable_pool2d, lp_pool1d,
115 lp_pool2d, max_pool1d, max_pool2d, max_pool3d, max_unpool1d, max_unpool2d, max_unpool3d,
116 spatial_pyramid_pool2d, stochastic_pool2d,
117};
118
119pub use normalization::{
121 batch_norm, group_norm, instance_norm, layer_norm, local_response_norm, normalize, weight_norm,
122};
123
124pub use dropout::{
126 alpha_dropout, dropout, dropout1d, dropout2d, dropout3d, feature_alpha_dropout,
127 gaussian_dropout,
128};
129
130pub use attention::{
132 cross_attention, flash_attention, multi_head_attention, scaled_dot_product_attention,
133 self_attention,
134};
135
136pub use regularization::{
138 consistency_penalty, gradient_penalty, r1_gradient_penalty, r2_gradient_penalty,
139 spectral_gradient_penalty,
140};
141
142pub use advanced_nn::{
144 cutmix,
146 darts_operation,
148 decode_architecture,
149 differentiable_augment,
150 encode_architecture,
151 knowledge_distillation_loss,
153 label_smoothing,
154 mixup,
155 mutate_architecture,
156 predict_architecture_performance,
157 spectral_norm,
159 temperature_scale,
160 weight_standardization,
161};
162
163pub use broadcast::{broadcast_shapes, broadcast_tensors};
165pub use linalg::{
166 baddbmm, bmm, chain_matmul, cholesky, cond, det, eig, inv, lstsq, lu, matrix_rank, norm,
167 pca_lowrank, pinv, qr, solve, svd, svd_lowrank, triangular_solve, NormOrd,
168};
169pub use manipulation::{
170 atleast_1d, atleast_2d, atleast_3d, block_diag, cartesian_prod, chunk, dsplit, hsplit,
171 meshgrid, split, tensor_split, tensordot, unravel_index, vsplit, SplitArg, TensorSplitArg,
172};
173pub use math::{cdist, einsum};
174pub use reduction::{unique, unique_consecutive, UniqueResult};
175pub use spectral::{
176 cepstrum, create_mel_filterbank, fftn, generate_window, hfft, hz_to_mel, ifftn, ihfft, irfft,
177 istft, istft_complete, mel_spectrogram, mel_to_hz, rfft2, rfftn, spectral_centroid,
178 spectral_rolloff, spectrogram, stft, stft_complete, SpectrogramType, WindowFunction,
179};
180pub use tensor_ops::{
181 cosine_similarity, embedding, linear, one_hot, pairwise_distance, pixel_shuffle,
182 pixel_unshuffle,
183};
184
185pub use image::{
187 affine_transform, closing, dilation, erosion, gaussian_blur, hsv_to_rgb, laplacian_filter,
188 opening, resize, rgb_to_hsv, sobel_filter, InterpolationMode, SobelDirection,
189};
190
191pub use signal::{
193 correlate, filtfilt, frame, lfilter, overlap_add, periodogram, welch, window, CorrelationMode,
194 PsdScaling, WindowType,
195};
196
197pub use data_ops::{
199 bincount,
200 histogram,
201 histogram_with_edges,
202 unique as unique_values, value_counts,
204};
205
206pub use random_ops::{
208 bernoulli, bernoulli_, exponential_, multinomial, normal_, rand, randint, randint_, randn,
209 randperm, uniform_,
210};
211
212pub use type_promotion::{
214 can_cast_safely, common_dtype_for_operation, ensure_compatible_types, get_type_category,
215 get_type_precision, promote_multiple_types, promote_scalar_type, promote_tensors,
216 promote_types, reduction_result_type, result_type, TypeCategory,
217};
218
219pub use fusion::{
221 analyze_fusion_opportunities, detect_fusible_patterns, fused_add_mul, fused_add_relu_mul,
222 fused_batch_norm, fused_mul_add, fused_relu_add, fused_sigmoid_mul, fused_silu,
223 fused_tanh_scale, AdaptiveFusionEngine, FusedOp, FusionOpportunity, FusionPerformance,
224 OpFusionEngine, OpSequence,
225};
226
227pub use special::{
229 acosh,
230 airy_ai,
231 asinh,
232 atanh,
233
234 bessel_i0,
235 bessel_i1,
236 bessel_iv,
237 bessel_j0,
239 bessel_j1,
240 bessel_jn,
241 bessel_k0,
242 bessel_k1,
243
244 bessel_y0,
245 bessel_y1,
246 bessel_yn,
247 beta,
248 betainc,
250 dawson,
251 digamma,
252 erf,
254 erfc,
255 erfcinv,
256
257 erfcx,
258 erfinv,
259 expint,
260 expm1,
261 fresnel,
262 fresnel_c,
263 fresnel_s,
264 gamma,
266 hypergeometric_1f1,
267 kelvin_ber,
268 lgamma,
269 log1p,
270 logsumexp,
272 multigammaln,
273
274 normal_cdf,
275 normal_icdf,
276
277 polygamma,
278 sinc,
280 spherical_j0,
282 spherical_j1,
283 spherical_jn,
284 spherical_y0,
285 spherical_y1,
286 spherical_yn,
287
288 voigt_profile,
289};
290
291pub use wavelet::{
293 cwt, dwt_1d, dwt_2d, idwt_1d, idwt_2d, wavedec, waverec, WaveletMode, WaveletType,
294};
295
296pub use interpolation::{
298 barycentric_interp, grid_sample, interp1d, interp2d, lanczos_interp1d, spline1d,
299 InterpolationMode as InterpMode,
300};
301
302pub use numerical::{
304 adaptive_quad, bisection, cumtrapz, gaussian_quad, gradient, newton_raphson,
305 partial_derivative, second_derivative, simps, trapz, DifferentiationMethod, IntegrationMethod,
306};
307
308pub use optimization::{
310 adam_optimizer,
311 analyze_optimization_problem,
312 auto_configure_optimization,
313 backtracking_line_search,
314 gradient_descent,
315 lbfgs_optimizer,
316 momentum_gradient_descent,
317 wolfe_line_search,
318 AdamParams,
319 AdaptiveAlgorithmSelector,
320 BFGSParams,
321 BacktrackingParams,
322 GradientDescentParams,
323 LineSearchMethod,
324 MomentumParams,
325 OptimizationAlgorithm,
326 TensorCharacteristics,
328 WolfeParams,
329};
330
331pub use lazy::{
333 lazy_ops::{execute, lazy, with_optimization},
334 LazyBuilder, LazyContext, LazyOp, LazyTensor,
335};
336
337pub use advanced_manipulation::{
339 boolean_index, cat, masked_fill, pad, reshape, slice_with_step, squeeze, unsqueeze,
340 where_tensor, PaddingMode,
341};
342
343pub use quantization::{
345 dynamic_quantize, fake_quantize, gradual_magnitude_prune, lottery_ticket_prune,
346 magnitude_prune, quantization_error_analysis, uniform_dequantize, uniform_quantize,
347 weight_clustering, QuantizationScheme, QuantizationType,
348};
349
350pub use sparse::{
352 sparse_add, sparse_conv1d, sparse_conv2d, sparse_coo_tensor, sparse_eye, sparse_max,
353 sparse_mean, sparse_min, sparse_mm, sparse_mul, sparse_sum, sparse_to_csr, sparse_transpose,
354 SparseTensor,
355};
356
357pub use autograd::{
359 apply_custom_function, apply_custom_function_with_context, apply_registered_function,
360 get_global_registry, register_custom_function, AutogradContext, AutogradRegistry,
361 CustomAutogradFunction, CustomAutogradFunctionWithContext, ExpFunction, ScaledAddFunction,
362 SquareFunction,
363};
364
365pub use profiling::{
367 benchmark,
368 global_profiler,
369 profile_operation,
370 run_performance_regression_test,
371 BaselineSummary,
372 BenchmarkConfig,
373 BenchmarkResults,
374 OperationMetrics,
375 OperationSummary,
376 PerformanceBaseline,
378 PerformanceRegressionTester,
379 Profiler,
380 RegressionTestConfig,
381 RegressionTestResult,
382 SystemInfo,
383};
384
385pub use utils::{
387 apply_binary_elementwise, apply_conditional_elementwise, apply_elementwise_operation,
388 calculate_pooling_output_size, calculate_pooling_output_size_2d,
389 calculate_pooling_output_size_3d, create_tensor_like, function_context, safe_for_log, safe_log,
390 safe_log_prob, validate_broadcastable_shapes, validate_dimension, validate_elementwise_shapes,
391 validate_loss_params, validate_non_empty, validate_pooling_params, validate_positive,
392 validate_range, validate_tensor_dims,
393};
394
395pub use transformations::{
397 einsum_optimized, tensor_contract, tensor_fold, tensor_map, tensor_outer, tensor_reduce,
398 tensor_scan, tensor_zip,
399};
400
401pub use tensor_decomposition::{cp_decomposition, tucker_decomposition};
403
404pub fn align_tensors(tensors: &[Tensor]) -> TorshResult<Vec<Tensor>> {
406 if tensors.is_empty() {
407 return Ok(vec![]);
408 }
409
410 let max_dims = tensors.iter().map(|t| t.shape().ndim()).max().unwrap_or(0);
412
413 let aligned: TorshResult<Vec<_>> = tensors
415 .iter()
416 .map(|t| {
417 let current_dims = t.shape().ndim();
418 if current_dims < max_dims {
419 let mut new_shape = vec![1; max_dims - current_dims];
421 new_shape.extend(t.shape().dims());
422 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&x| x as i32).collect();
424 t.view(&new_shape_i32)
425 } else {
426 Ok(t.clone())
427 }
428 })
429 .collect();
430
431 aligned
432}
433
#[cfg(test)]
mod tests {

    /// `align_tensors` should promote every input to the maximum rank by
    /// left-padding its shape with singleton dimensions.
    #[test]
    fn test_align_tensors() {
        use crate::align_tensors;
        use torsh_tensor::creation::ones;

        let matrix = ones(&[3, 4]).unwrap();
        let vector = ones(&[4]).unwrap();
        let cube = ones(&[2, 3, 4]).unwrap();

        let aligned = align_tensors(&[matrix, vector, cube]).unwrap();

        // Every output is promoted to the maximum rank among the inputs (3).
        for tensor in &aligned {
            assert_eq!(tensor.shape().ndim(), 3);
        }

        // Lower-rank inputs gain leading size-1 dimensions; the already
        // rank-3 input keeps its shape.
        assert_eq!(aligned[0].shape().dims(), &[1, 3, 4]);
        assert_eq!(aligned[1].shape().dims(), &[1, 1, 4]);
        assert_eq!(aligned[2].shape().dims(), &[2, 3, 4]);
    }
}