Skip to main content

vector_ta/indicators/moving_averages/
alma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2pub use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
3
4#[cfg(feature = "python")]
5use numpy::{IntoPyArray, PyArray1};
6#[cfg(feature = "python")]
7use pyo3::exceptions::PyValueError;
8#[cfg(feature = "python")]
9use pyo3::prelude::*;
10#[cfg(feature = "python")]
11use pyo3::types::{PyDict, PyList};
12
13#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
14use serde::{Deserialize, Serialize};
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use wasm_bindgen::prelude::*;
17
18use crate::utilities::data_loader::{source_type, Candles};
19use crate::utilities::enums::Kernel;
20use crate::utilities::helpers::{
21    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
22    make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26use aligned_vec::{AVec, CACHELINE_ALIGN};
27#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
28use core::arch::x86_64::*;
29#[cfg(not(target_arch = "wasm32"))]
30use rayon::prelude::*;
31use std::alloc::{alloc, dealloc, Layout};
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for AlmaInput<'a> {
38    #[inline(always)]
39    fn as_ref(&self) -> &[f64] {
40        match &self.data {
41            AlmaData::Slice(slice) => slice,
42            AlmaData::Candles { candles, source } => source_type(candles, source),
43        }
44    }
45}
46
/// Where the values fed to ALMA come from.
#[derive(Debug, Clone)]
pub enum AlmaData<'a> {
    /// A column of `candles`, selected by name (`source`, e.g. `"close"`) via `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed slice used as-is.
    Slice(&'a [f64]),
}
55
/// Result of an ALMA computation.
#[derive(Debug, Clone)]
pub struct AlmaOutput {
    /// One entry per input sample (allocated with `data.len()` slots by `alma_with_kernel`);
    /// positions before the first full window are presumably NaN via
    /// `alloc_with_nan_prefix` — confirm in that helper.
    pub values: Vec<f64>,
}
60
/// Tunable ALMA parameters. A `None` field falls back to the default applied by the
/// `AlmaInput::get_*` accessors.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AlmaParams {
    /// Window length in samples.
    pub period: Option<usize>,
    /// Position of the Gaussian peak inside the window, as a fraction validated to `[0, 1]`.
    pub offset: Option<f64>,
    /// Gaussian width control; the width scale used for the weights is `period / sigma`.
    pub sigma: Option<f64>,
}

impl Default for AlmaParams {
    /// `period = 9`, `offset = 0.85`, `sigma = 6.0`.
    fn default() -> Self {
        let period = 9;
        let offset = 0.85;
        let sigma = 6.0;
        Self {
            period: Some(period),
            offset: Some(offset),
            sigma: Some(sigma),
        }
    }
}
81
/// Input to an ALMA computation: a data source plus (possibly partial) parameters.
#[derive(Debug, Clone)]
pub struct AlmaInput<'a> {
    /// Candle column or raw slice to smooth.
    pub data: AlmaData<'a>,
    /// Parameters; unset fields are resolved by the `get_*` accessors.
    pub params: AlmaParams,
}
87
88impl<'a> AlmaInput<'a> {
89    #[inline]
90    pub fn from_candles(c: &'a Candles, s: &'a str, p: AlmaParams) -> Self {
91        Self {
92            data: AlmaData::Candles {
93                candles: c,
94                source: s,
95            },
96            params: p,
97        }
98    }
99    #[inline]
100    pub fn from_slice(sl: &'a [f64], p: AlmaParams) -> Self {
101        Self {
102            data: AlmaData::Slice(sl),
103            params: p,
104        }
105    }
106    #[inline]
107    pub fn with_default_candles(c: &'a Candles) -> Self {
108        Self::from_candles(c, "close", AlmaParams::default())
109    }
110    #[inline]
111    pub fn get_period(&self) -> usize {
112        self.params.period.unwrap_or(9)
113    }
114    #[inline]
115    pub fn get_offset(&self) -> f64 {
116        self.params.offset.unwrap_or(0.85)
117    }
118    #[inline]
119    pub fn get_sigma(&self) -> f64 {
120        self.params.sigma.unwrap_or(6.0)
121    }
122}
123
124#[derive(Copy, Clone, Debug)]
125pub struct AlmaBuilder {
126    period: Option<usize>,
127    offset: Option<f64>,
128    sigma: Option<f64>,
129    kernel: Kernel,
130}
131
132impl Default for AlmaBuilder {
133    fn default() -> Self {
134        Self {
135            period: None,
136            offset: None,
137            sigma: None,
138            kernel: Kernel::Auto,
139        }
140    }
141}
142
143impl AlmaBuilder {
144    #[inline(always)]
145    pub fn new() -> Self {
146        Self::default()
147    }
148    #[inline(always)]
149    pub fn period(mut self, n: usize) -> Self {
150        self.period = Some(n);
151        self
152    }
153    #[inline(always)]
154    pub fn offset(mut self, x: f64) -> Self {
155        self.offset = Some(x);
156        self
157    }
158    #[inline(always)]
159    pub fn sigma(mut self, s: f64) -> Self {
160        self.sigma = Some(s);
161        self
162    }
163    #[inline(always)]
164    pub fn kernel(mut self, k: Kernel) -> Self {
165        self.kernel = k;
166        self
167    }
168
169    #[inline(always)]
170    pub fn apply(self, c: &Candles) -> Result<AlmaOutput, AlmaError> {
171        let p = AlmaParams {
172            period: self.period,
173            offset: self.offset,
174            sigma: self.sigma,
175        };
176        let i = AlmaInput::from_candles(c, "close", p);
177        alma_with_kernel(&i, self.kernel)
178    }
179
180    #[inline(always)]
181    pub fn apply_slice(self, d: &[f64]) -> Result<AlmaOutput, AlmaError> {
182        let p = AlmaParams {
183            period: self.period,
184            offset: self.offset,
185            sigma: self.sigma,
186        };
187        let i = AlmaInput::from_slice(d, p);
188        alma_with_kernel(&i, self.kernel)
189    }
190
191    #[inline(always)]
192    pub fn into_stream(self) -> Result<AlmaStream, AlmaError> {
193        let p = AlmaParams {
194            period: self.period,
195            offset: self.offset,
196            sigma: self.sigma,
197        };
198        AlmaStream::try_new(p)
199    }
200}
201
/// Errors reported by the ALMA indicator.
#[derive(Debug, Error)]
pub enum AlmaError {
    /// The input series has no elements.
    #[error("alma: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN, so there is no first valid sample.
    #[error("alma: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or longer than the data. `AlmaStream::try_new` has no data yet and
    /// reports `data_len: 0`.
    #[error("alma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` values remain after the leading NaNs.
    #[error("alma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// `sigma` was rejected by validation.
    #[error("alma: Invalid sigma: {sigma}")]
    InvalidSigma { sigma: f64 },

    /// `offset` is NaN or outside `[0, 1]`.
    #[error("alma: Invalid offset: {offset}")]
    InvalidOffset { offset: f64 },

    /// A caller-supplied output buffer does not have the input's length.
    #[error("alma: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A parameter sweep range is malformed.
    /// NOTE(review): not raised in the visible code; presumably used by batch APIs.
    #[error("alma: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    /// Presumably raised when a non-batch kernel reaches a batch entry point — not raised
    /// in the visible code; confirm at the call sites.
    #[error("alma: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
234
235#[inline]
236pub fn alma(input: &AlmaInput) -> Result<AlmaOutput, AlmaError> {
237    alma_with_kernel(input, Kernel::Auto)
238}
239
/// Runs the chosen ALMA kernel, writing `out[i]` for every `i` whose `period`-long window
/// starts at or after `first`. Earlier entries of `out` are not written here.
///
/// `kernel` must already be resolved: `Kernel::Auto` (and any variant not matched below)
/// reaches `unreachable!()`. `alma_prepare` performs that resolution.
///
/// Dispatch:
/// * wasm32 with `simd128`: scalar kernels are routed to `alma_simd128`.
/// * x86_64 with the `nightly-avx` feature: AVX2 / AVX-512 kernels.
/// * otherwise AVX2 / AVX-512 requests fall back to `alma_scalar`.
#[inline(always)]
fn alma_compute_into(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first: usize,
    inv_n: f64,
    kernel: Kernel,
    out: &mut [f64],
) {
    // One `unsafe` block covers every arm: `alma_simd128` is an `unsafe fn` and the x86
    // kernels carry `#[target_feature]`, so calling them requires `unsafe`.
    unsafe {
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                alma_simd128(data, weights, period, first, inv_n, out);
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => {
                alma_scalar(data, weights, period, first, inv_n, out)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => alma_avx2(data, weights, period, first, inv_n, out),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_avx512(data, weights, period, first, inv_n, out)
            }
            // Without the SIMD kernels compiled in, SIMD requests use the portable path.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_scalar(data, weights, period, first, inv_n, out)
            }
            _ => unreachable!(),
        }
    }
}
277
278#[inline(always)]
279fn alma_prepare<'a>(
280    input: &'a AlmaInput,
281    kernel: Kernel,
282) -> Result<(&'a [f64], AVec<f64>, usize, usize, f64, Kernel), AlmaError> {
283    let data: &[f64] = input.as_ref();
284    let len = data.len();
285    if len == 0 {
286        return Err(AlmaError::EmptyInputData);
287    }
288    let first = data
289        .iter()
290        .position(|x| !x.is_nan())
291        .ok_or(AlmaError::AllValuesNaN)?;
292    let period = input.get_period();
293    let offset = input.get_offset();
294    let sigma = input.get_sigma();
295
296    if period == 0 || period > len {
297        return Err(AlmaError::InvalidPeriod {
298            period,
299            data_len: len,
300        });
301    }
302    if len - first < period {
303        return Err(AlmaError::NotEnoughValidData {
304            needed: period,
305            valid: len - first,
306        });
307    }
308    if sigma <= 0.0 {
309        return Err(AlmaError::InvalidSigma { sigma });
310    }
311    if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
312        return Err(AlmaError::InvalidOffset { offset });
313    }
314
315    let m = offset * (period - 1) as f64;
316    let s = period as f64 / sigma;
317    let s2 = 2.0 * s * s;
318
319    let aligned_period = ((period + 7) / 8) * 8;
320    let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, aligned_period);
321    weights.resize(aligned_period, 0.0);
322
323    let inv_s2 = 1.0 / s2;
324    let mut norm = 0.0;
325
326    for i in 0..period {
327        let diff = i as f64 - m;
328        let w = (-diff * diff * inv_s2).exp();
329        weights[i] = w;
330        norm += w;
331    }
332    let inv_norm = 1.0 / norm;
333
334    let chosen = match kernel {
335        Kernel::Auto => detect_best_kernel(),
336        k => k,
337    };
338
339    Ok((data, weights, period, first, inv_norm, chosen))
340}
341
342pub fn alma_with_kernel(input: &AlmaInput, kernel: Kernel) -> Result<AlmaOutput, AlmaError> {
343    let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, kernel)?;
344
345    let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
346
347    alma_compute_into(data, &weights, period, first, inv_n, chosen, &mut out);
348
349    Ok(AlmaOutput { values: out })
350}
351
352#[inline]
353pub fn alma_into_slice(dst: &mut [f64], input: &AlmaInput, kern: Kernel) -> Result<(), AlmaError> {
354    let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, kern)?;
355
356    if dst.len() != data.len() {
357        return Err(AlmaError::OutputLengthMismatch {
358            expected: data.len(),
359            got: dst.len(),
360        });
361    }
362
363    alma_compute_into(data, &weights, period, first, inv_n, chosen, dst);
364
365    let warmup_end = first + period - 1;
366    for v in &mut dst[..warmup_end] {
367        *v = f64::NAN;
368    }
369
370    Ok(())
371}
372
373#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
374#[inline]
375pub fn alma_into(input: &AlmaInput, out: &mut [f64]) -> Result<(), AlmaError> {
376    let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, Kernel::Auto)?;
377
378    if out.len() != data.len() {
379        return Err(AlmaError::OutputLengthMismatch {
380            expected: data.len(),
381            got: out.len(),
382        });
383    }
384
385    let warmup_end = first + period - 1;
386    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
387    let warm = warmup_end.min(out.len());
388    for v in &mut out[..warm] {
389        *v = qnan;
390    }
391
392    alma_compute_into(data, &weights, period, first, inv_n, chosen, out);
393
394    Ok(())
395}
396
/// Horizontal sum of the eight `f64` lanes of `v` (`_mm512_reduce_add_pd`).
///
/// # Safety
/// The executing CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn hsum_pd_zmm(v: __m512d) -> f64 {
    // NOTE(review): this block has no `unsafe` keyword, so the `unused_unsafe` allow
    // appears to be a leftover.
    #[allow(unused_unsafe)]
    {
        _mm512_reduce_add_pd(v)
    }
}
406
/// AVX-512 ALMA entry point.
///
/// Short windows (`period <= 32`) use `alma_avx512_short`; longer ones use the 8-way
/// unrolled `alma_avx512_long`. Writes `out[i]` for `i` in
/// `first_valid + period - 1 .. data.len()`; earlier entries are left untouched.
///
/// # Panics
/// Panics if `weights.len() < period` or `out.len() < data.len()`. These are the same
/// preconditions `alma_scalar` enforces. The SIMD loops below read and write through raw
/// pointers without bounds checks, so the preconditions are checked once here.
///
/// NOTE: when `period >= 8`, the full-chunk weight loads use `_mm512_load_pd`, which
/// requires `weights` to be 64-byte aligned. That requirement is not checked here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub fn alma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    if period <= 32 {
        // SAFETY: slice lengths are checked above; AVX-512F is enabled for this function.
        unsafe { alma_avx512_short(data, weights, period, first_valid, inv_norm, out) }
    } else {
        // SAFETY: as above.
        unsafe { alma_avx512_long(data, weights, period, first_valid, inv_norm, out) }
    }
}
424
/// Portable ALMA kernel.
///
/// For every `i` in `first_val + period - 1 .. data.len()`, computes the dot product of
/// `data[i + 1 - period ..= i]` with `weights[..period]`, scales it by `inv_norm` and
/// stores it in `out[i]`. Entries of `out` before that range are left untouched.
///
/// # Panics
/// Panics if `weights.len() < period` or `out.len() < data.len()`.
#[inline(always)]
pub fn alma_scalar(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    // Split the weights once into a multiple-of-four prefix and the 0..=3 leftovers.
    let quad_len = period & !3;
    let (w_quads, w_rest) = weights[..period].split_at(quad_len);
    let warmup = first_val + period - 1;

    for idx in warmup..data.len() {
        let window = &data[idx + 1 - period..=idx];
        let (d_quads, d_rest) = window.split_at(quad_len);

        let mut acc = 0.0;
        for (d, w) in d_quads.chunks_exact(4).zip(w_quads.chunks_exact(4)) {
            // The four products are summed left to right first, then added to `acc`.
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
        }
        for (d, w) in d_rest.iter().zip(w_rest) {
            acc += d * w;
        }

        out[idx] = acc * inv_norm;
    }
}
464
/// WebAssembly `simd128` ALMA kernel, processing two `f64` lanes per step.
///
/// For each `i` in `first_val + period - 1 .. data.len()`:
/// * multiplies `data[i + 1 - period ..= i]` by `weights[..period]`, two elements at a time;
/// * adds the odd trailing element (if any) in scalar;
/// * stores `sum * inv_norm` into `out[i]`.
///
/// # Panics
/// Panics if `weights.len() < period` or `out.len() < data.len()`.
///
/// # Safety
/// Vector loads go through raw pointers. The asserts plus the loop bounds keep every load
/// within `weights[..period]` and `data[..data.len()]`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
unsafe fn alma_simd128(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    const STEP: usize = 2;
    let chunks = period / STEP;
    let tail = period % STEP;

    for i in (first_val + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = f64x2_splat(0.0);

        // Full 2-lane blocks of the window.
        for blk in 0..chunks {
            let idx = blk * STEP;
            let w = v128_load(weights.as_ptr().add(idx) as *const v128);
            let d = v128_load(data.as_ptr().add(start + idx) as *const v128);
            acc = f64x2_add(acc, f64x2_mul(d, w));
        }

        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);

        // An odd `period` leaves exactly one element for the scalar tail.
        if tail != 0 {
            sum += data[start + chunks * STEP] * weights[chunks * STEP];
        }

        out[i] = sum * inv_norm;
    }
}
510
/// AVX2/FMA ALMA kernel for short windows (`alma_avx2` uses it when `period <= 32`).
///
/// For each `i >= first_valid + period - 1`:
/// * accumulates `data[i + 1 - period ..= i] * weights[..period]` four lanes at a time;
/// * handles the last `period % 4` elements with masked loads;
/// * horizontally sums the accumulator and scales it by `inv_norm`.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `weights.len() >= period` and `out.len() >= data.len()` must hold. All loads use raw
///   pointers and the store uses `get_unchecked_mut`, so nothing is bounds-checked here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_avx2_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let chunks = period / STEP;
    let tail = period % STEP;

    // Lane mask selecting the first `tail` elements. `_mm256_maskload_pd` only reads
    // lanes whose mask sign bit is set, so tail loads never touch memory past `period`.
    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        for blk in 0..chunks {
            let idx = blk * STEP;
            let w = _mm256_loadu_pd(weights.as_ptr().add(idx));
            let d = _mm256_loadu_pd(data.as_ptr().add(start + idx));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        if tail != 0 {
            let w_tail = _mm256_maskload_pd(weights.as_ptr().add(chunks * STEP), tail_mask);
            let d_tail = _mm256_maskload_pd(data.as_ptr().add(start + chunks * STEP), tail_mask);
            acc = _mm256_fmadd_pd(d_tail, w_tail, acc);
        }

        // Horizontal sum: fold the high 128-bit half onto the low half, then the two lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(hi, lo);
        let sum1 = _mm_add_pd(sum2, _mm_unpackhi_pd(sum2, sum2));
        let sum = _mm_cvtsd_f64(sum1);

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
560
/// AVX2/FMA ALMA kernel for long windows (`alma_avx2` uses it when `period > 32`).
///
/// Like `alma_avx2_short`, but with these differences:
/// * full 4-lane chunks are consumed in pairs into two independent accumulators
///   (`acc0`, `acc1`);
/// * an odd leftover chunk goes into `acc0`;
/// * the `period % 4` tail is handled with masked loads;
/// * the final horizontal sum uses `_mm_hadd_pd`.
///
/// # Safety
/// * The CPU must support AVX2 and FMA.
/// * `weights.len() >= period` and `out.len() >= data.len()` must hold. Loads and the
///   store are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_avx2_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let chunks = period / STEP;
    let tail = period % STEP;

    // Chunks processed two at a time, plus at most one leftover chunk.
    let paired_chunks = chunks / 2;
    let odd_chunk = chunks % 2;

    // Mask for the first `tail` lanes; masked-off lanes are not read.
    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc0 = _mm256_setzero_pd();
        let mut acc1 = _mm256_setzero_pd();

        // Two independent FMA chains per iteration.
        for blk in 0..paired_chunks {
            let idx0 = (blk * 2) * STEP;
            let idx1 = (blk * 2 + 1) * STEP;

            let w0 = _mm256_loadu_pd(weights.as_ptr().add(idx0));
            let w1 = _mm256_loadu_pd(weights.as_ptr().add(idx1));
            let d0 = _mm256_loadu_pd(data.as_ptr().add(start + idx0));
            let d1 = _mm256_loadu_pd(data.as_ptr().add(start + idx1));

            acc0 = _mm256_fmadd_pd(d0, w0, acc0);
            acc1 = _mm256_fmadd_pd(d1, w1, acc1);
        }

        if odd_chunk != 0 {
            let idx = (paired_chunks * 2) * STEP;
            let w = _mm256_loadu_pd(weights.as_ptr().add(idx));
            let d = _mm256_loadu_pd(data.as_ptr().add(start + idx));
            acc0 = _mm256_fmadd_pd(d, w, acc0);
        }

        let acc = _mm256_add_pd(acc0, acc1);

        let final_acc = if tail != 0 {
            let w_tail = _mm256_maskload_pd(weights.as_ptr().add(chunks * STEP), tail_mask);
            let d_tail = _mm256_maskload_pd(data.as_ptr().add(start + chunks * STEP), tail_mask);
            _mm256_fmadd_pd(d_tail, w_tail, acc)
        } else {
            acc
        };

        // Horizontal sum: add the two 128-bit halves, then the two remaining lanes.
        let sum128 = _mm_add_pd(
            _mm256_castpd256_pd128(final_acc),
            _mm256_extractf128_pd(final_acc, 1),
        );
        let sum = _mm_cvtsd_f64(_mm_hadd_pd(sum128, sum128));

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
631
/// AVX2/FMA ALMA entry point.
///
/// Short windows (`period <= 32`) use a single-accumulator loop; longer ones use a
/// dual-accumulator loop. Writes `out[i]` for `i` in
/// `first_valid + period - 1 .. data.len()`; earlier entries are left untouched.
///
/// # Panics
/// Panics if `weights.len() < period` or `out.len() < data.len()`. These are the same
/// preconditions `alma_scalar` enforces. The SIMD loops read and write through raw
/// pointers without bounds checks, so the preconditions are checked once here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub fn alma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    if period <= 32 {
        // SAFETY: slice lengths are checked above; AVX2+FMA are enabled for this function.
        unsafe { alma_avx2_short(data, weights, period, first_valid, inv_norm, out) }
    } else {
        // SAFETY: as above.
        unsafe { alma_avx2_long(data, weights, period, first_valid, inv_norm, out) }
    }
}
649
/// AVX-512 ALMA kernel for short windows (`alma_avx512` uses it when `period <= 32`).
///
/// When `period < 8` there are no full chunks: the masked weight vector is loaded once and
/// each output is a single masked multiply plus horizontal sum. Otherwise, full 8-lane
/// chunks are FMA-accumulated and the `period % 8` tail is added with masked loads.
///
/// # Safety
/// * The CPU must support AVX-512F and FMA.
/// * `weights.len() >= period` and `out.len() >= data.len()` must hold. Loads and the
///   store are unchecked.
/// * `_mm512_load_pd` is an aligned load, so `weights` must be 64-byte aligned whenever
///   `period >= 8`. NOTE(review): this presumably relies on the `AVec` weights being
///   allocated with `CACHELINE_ALIGN` — confirm that constant is at least 64.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert!(data.len() == out.len());
    debug_assert!(weights.len() >= period);

    const STEP: usize = 8;
    let chunks = period / STEP;
    let tail_len = period % STEP;
    // Low `tail_len` bits set; masked-off lanes are neither read nor faulted on.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    if chunks == 0 {
        // The whole window fits in one masked vector; hoist the weight load.
        let w_vec = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr());
        for i in (first_valid + period - 1)..data.len() {
            let start = i + 1 - period;
            let d_vec = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start));
            let sum = hsum_pd_zmm(_mm512_mul_pd(d_vec, w_vec)) * inv_norm;
            *out.get_unchecked_mut(i) = sum;
        }
        return;
    }

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm512_setzero_pd();

        for blk in 0..chunks {
            let w = _mm512_load_pd(weights.as_ptr().add(blk * STEP));
            let d = _mm512_loadu_pd(data.as_ptr().add(start + blk * STEP));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        if tail_len != 0 {
            let w_tail = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr().add(chunks * STEP));
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start + chunks * STEP));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        *out.get_unchecked_mut(i) = hsum_pd_zmm(acc) * inv_norm;
    }
}
700
/// AVX-512 ALMA kernel for long windows (`alma_avx512` uses it when `period > 32`).
///
/// Before the main loop:
/// * the full-chunk weight vectors are loaded once into `wregs`, on the stack for up to 256
///   chunks and on the heap beyond that;
/// * the masked tail weights are loaded once.
///
/// Each output then uses eight independent FMA accumulators (8-way unrolled over chunks),
/// a remainder-chunk loop, and an optional masked tail, reduced by a pairwise sum tree.
///
/// The main loops run a fixed number of iterations, one per output. The previous
/// `data_ptr.add(period) <= stop_ptr` check formed a pointer two past the end of `data`
/// on its final evaluation, which `pointer::add` forbids. Now every pointer formed stays
/// within `data`/`out` or one past their end.
///
/// # Safety
/// * The CPU must support AVX-512F and FMA.
/// * `period >= 8` (at least one full chunk), `weights.len() >= period` and
///   `out.len() >= data.len()` must hold. Loads and stores are unchecked.
/// * `_mm512_load_pd` is an aligned load, so `weights` must be 64-byte aligned.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let n_chunks = period / STEP;
    let tail_len = period % STEP;

    // Largest multiple of 8 chunks, handled by the 8-accumulator unrolled loop.
    let unroll8 = n_chunks & !7;
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    debug_assert!(period >= 1 && n_chunks > 0);
    debug_assert_eq!(data.len(), out.len());
    debug_assert!(weights.len() >= period);

    // Index of the first output. If no full window fits there is nothing to do, and
    // returning early avoids forming out-of-bounds pointers below.
    let first_out = first_valid + period - 1;
    if first_out >= data.len() {
        return;
    }
    // One iteration per output index in `first_out..data.len()`.
    let n_out = data.len() - first_out;

    const MAX_STACK_CHUNKS: usize = 256;
    let mut stack_storage = MaybeUninit::<[__m512d; MAX_STACK_CHUNKS]>::uninit();
    let mut heap_storage: Option<Vec<__m512d>> = None;

    let wregs: &[__m512d] = if n_chunks <= MAX_STACK_CHUNKS {
        let base = stack_storage.as_mut_ptr().cast::<__m512d>();
        for blk in 0..n_chunks {
            // SAFETY: `blk < n_chunks <= MAX_STACK_CHUNKS`, so the write stays inside
            // `stack_storage`; the weight read stays inside `weights[..period]`.
            unsafe {
                base.add(blk)
                    .write(_mm512_load_pd(weights.as_ptr().add(blk * STEP)));
            }
        }
        // SAFETY: exactly the first `n_chunks` elements were initialised above.
        unsafe { core::slice::from_raw_parts(base, n_chunks) }
    } else {
        let mut regs = Vec::with_capacity(n_chunks);
        for blk in 0..n_chunks {
            regs.push(_mm512_load_pd(weights.as_ptr().add(blk * STEP)));
        }
        heap_storage = Some(regs);
        heap_storage.as_ref().unwrap().as_slice()
    };
    let w_tail = if tail_len != 0 {
        Some(_mm512_maskz_loadu_pd(
            tail_mask,
            weights.as_ptr().add(n_chunks * STEP),
        ))
    } else {
        None
    };

    // `data_ptr` points at the start of the current window, `dst_ptr` at its output slot.
    // After the last iteration they reach at most `data + len - period + 1` and
    // `out + len`, both within or one past their allocation.
    let mut data_ptr = data.as_ptr().add(first_valid);
    let mut dst_ptr = out.as_mut_ptr().add(first_out);

    if tail_len == 0 {
        for _ in 0..n_out {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();
            let mut s4 = _mm512_setzero_pd();
            let mut s5 = _mm512_setzero_pd();
            let mut s6 = _mm512_setzero_pd();
            let mut s7 = _mm512_setzero_pd();

            for blk in (0..unroll8).step_by(8) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));
                let d4 = _mm512_loadu_pd(data_ptr.add((blk + 4) * STEP));
                let d5 = _mm512_loadu_pd(data_ptr.add((blk + 5) * STEP));
                let d6 = _mm512_loadu_pd(data_ptr.add((blk + 6) * STEP));
                let d7 = _mm512_loadu_pd(data_ptr.add((blk + 7) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
                s4 = _mm512_fmadd_pd(d4, *wregs.get_unchecked(blk + 4), s4);
                s5 = _mm512_fmadd_pd(d5, *wregs.get_unchecked(blk + 5), s5);
                s6 = _mm512_fmadd_pd(d6, *wregs.get_unchecked(blk + 6), s6);
                s7 = _mm512_fmadd_pd(d7, *wregs.get_unchecked(blk + 7), s7);
            }

            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let sum45 = _mm512_add_pd(s4, s5);
            let sum67 = _mm512_add_pd(s6, s7);
            let sum0123 = _mm512_add_pd(sum01, sum23);
            let sum4567 = _mm512_add_pd(sum45, sum67);
            let tot = _mm512_add_pd(sum0123, sum4567);

            *dst_ptr = hsum_pd_zmm(tot) * inv_norm;

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    } else {
        let wt = w_tail.expect("tail_len != 0 but w_tail missing");

        for _ in 0..n_out {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();
            let mut s4 = _mm512_setzero_pd();
            let mut s5 = _mm512_setzero_pd();
            let mut s6 = _mm512_setzero_pd();
            let mut s7 = _mm512_setzero_pd();

            for blk in (0..unroll8).step_by(8) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));
                let d4 = _mm512_loadu_pd(data_ptr.add((blk + 4) * STEP));
                let d5 = _mm512_loadu_pd(data_ptr.add((blk + 5) * STEP));
                let d6 = _mm512_loadu_pd(data_ptr.add((blk + 6) * STEP));
                let d7 = _mm512_loadu_pd(data_ptr.add((blk + 7) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
                s4 = _mm512_fmadd_pd(d4, *wregs.get_unchecked(blk + 4), s4);
                s5 = _mm512_fmadd_pd(d5, *wregs.get_unchecked(blk + 5), s5);
                s6 = _mm512_fmadd_pd(d6, *wregs.get_unchecked(blk + 6), s6);
                s7 = _mm512_fmadd_pd(d7, *wregs.get_unchecked(blk + 7), s7);
            }

            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            // Masked tail: only the last `tail_len` window elements are read.
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data_ptr.add(n_chunks * STEP));
            s0 = _mm512_fmadd_pd(d_tail, wt, s0);

            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let sum45 = _mm512_add_pd(s4, s5);
            let sum67 = _mm512_add_pd(s6, s7);
            let sum0123 = _mm512_add_pd(sum01, sum23);
            let sum4567 = _mm512_add_pd(sum45, sum67);
            let tot = _mm512_add_pd(sum0123, sum4567);

            *dst_ptr = hsum_pd_zmm(tot) * inv_norm;

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    }
}
863
/// Incremental ALMA: feed one sample at a time through `update`.
#[derive(Debug, Clone)]
pub struct AlmaStream {
    /// Window length.
    period: usize,

    /// Gaussian weights, one per window slot; `weights[0]` multiplies the oldest sample.
    weights: AVec<f64>,
    /// Reciprocal of the weight sum, applied after the dot product.
    inv_norm: f64,

    /// Ring buffer of the last `period` samples.
    /// NOTE(review): `update` writes it, but no visible code reads it — confirm before
    /// removing.
    buffer: Vec<f64>,

    /// Mirrored ring buffer of length `2 * period`. Each sample is written at both `h` and
    /// `h + period`, so the current window is always the contiguous slice
    /// `buf2[head..head + period]`, oldest first.
    buf2: Vec<f64>,

    /// Slot the next sample will overwrite, i.e. where the oldest sample of the window is.
    head: usize,
    /// Number of samples seen so far, saturating at `period`.
    filled: usize,
    /// Kernel passed to `dot_contiguous`, fixed by `detect_best_kernel` at construction.
    kernel: Kernel,
}
879
880impl AlmaStream {
881    pub fn try_new(params: AlmaParams) -> Result<Self, AlmaError> {
882        let period = params.period.unwrap_or(9);
883        if period == 0 {
884            return Err(AlmaError::InvalidPeriod {
885                period,
886                data_len: 0,
887            });
888        }
889        let offset = params.offset.unwrap_or(0.85);
890        if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
891            return Err(AlmaError::InvalidOffset { offset });
892        }
893        let sigma = params.sigma.unwrap_or(6.0);
894        if sigma <= 0.0 {
895            return Err(AlmaError::InvalidSigma { sigma });
896        }
897
898        let m = offset * (period - 1) as f64;
899        let s = period as f64 / sigma;
900        let s2 = 2.0 * s * s;
901
902        let mut weights = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
903        weights.resize(period, 0.0);
904
905        let mut norm = 0.0;
906        for i in 0..period {
907            let diff = i as f64 - m;
908            let w = (-(diff * diff) / s2).exp();
909            weights[i] = w;
910            norm += w;
911        }
912        let inv_norm = 1.0 / norm;
913
914        let buffer = vec![f64::NAN; period];
915        let buf2 = vec![f64::NAN; period * 2];
916        let kernel = detect_best_kernel();
917
918        Ok(Self {
919            period,
920            weights,
921            inv_norm,
922            buffer,
923            buf2,
924            head: 0,
925            filled: 0,
926            kernel,
927        })
928    }
929
930    #[inline(always)]
931    pub fn update(&mut self, value: f64) -> Option<f64> {
932        let h = self.head;
933
934        self.buffer[h] = value;
935
936        self.buf2[h] = value;
937        self.buf2[h + self.period] = value;
938
939        let mut new_h = h + 1;
940        if new_h == self.period {
941            new_h = 0;
942        }
943        self.head = new_h;
944
945        if self.filled < self.period {
946            self.filled += 1;
947            if self.filled < self.period {
948                return None;
949            }
950        }
951
952        Some(self.dot_at_head())
953    }
954
955    #[inline(always)]
956    fn dot_at_head(&self) -> f64 {
957        let start = self.head;
958        let end = start + self.period;
959        let x = &self.buf2[start..end];
960        let w = &self.weights[..self.period];
961        let acc = dot_contiguous(self.kernel, x, w);
962        acc * self.inv_norm
963    }
964}
965
/// Bounds-checked dot product of `x` and `w[..x.len()]`.
///
/// Uses four independent partial sums, one per lane of each 4-element group. They are
/// combined as `(l0 + l1) + (l2 + l3)`, and the remaining 0..=3 elements are added
/// sequentially.
///
/// # Panics
/// Panics if `w` is shorter than `x` (and, in debug builds, if the lengths differ).
#[inline(always)]
fn dot_scalar_unrolled_safe(x: &[f64], w: &[f64]) -> f64 {
    debug_assert_eq!(x.len(), w.len());
    let w = &w[..x.len()];

    let body_len = x.len() & !3;
    let (x_body, x_tail) = x.split_at(body_len);
    let (w_body, w_tail) = w.split_at(body_len);

    let mut lanes = [0.0f64; 4];
    for (xs, ws) in x_body.chunks_exact(4).zip(w_body.chunks_exact(4)) {
        for ((lane, a), b) in lanes.iter_mut().zip(xs).zip(ws) {
            *lane += a * b;
        }
    }

    let body = (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]);
    x_tail
        .iter()
        .zip(w_tail)
        .fold(body, |acc, (a, b)| acc + a * b)
}
991
/// Horizontal sum of the four `f64` lanes of `v`, computed as `(v0 + v2) + (v1 + v3)`.
///
/// # Safety
/// The CPU must support AVX.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256(v: __m256d) -> f64 {
    // Fold the upper 128-bit half onto the lower half: [v0 + v2, v1 + v3].
    let pair = _mm_add_pd(_mm256_extractf128_pd(v, 1), _mm256_castpd256_pd128(v));
    // Add the two remaining lanes; only the low lane of the result is read.
    let total = _mm_add_sd(pair, _mm_unpackhi_pd(pair, pair));
    _mm_cvtsd_f64(total)
}
1001
/// Sums the eight `f64` lanes of `v` with `_mm512_reduce_add_pd`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512(v: __m512d) -> f64 {
    _mm512_reduce_add_pd(v)
}
1007
/// Dot product of `n` contiguous `f64`s at `x` and `w` using AVX2 + FMA.
///
/// Four lanes are fused per iteration into one accumulator. That accumulator is reduced
/// with `hsum256`, and the remaining `n % 4` products are then added sequentially.
///
/// The function is compiled with `avx2,fma` enabled. Without that, the intrinsics cannot
/// be inlined and each one becomes an out-of-line call. `#[inline(always)]` is not
/// allowed together with `#[target_feature]`, so this uses `#[inline]`.
///
/// # Safety
/// - The CPU must support AVX2 and FMA.
/// - `x` and `w` must each be valid for reads of `n` consecutive `f64`s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_avx2(x: *const f64, w: *const f64, n: usize) -> f64 {
    let mut i = 0usize;
    let n4 = n & !3;
    let mut acc = _mm256_setzero_pd();
    while i < n4 {
        let xv = _mm256_loadu_pd(x.add(i));
        let wv = _mm256_loadu_pd(w.add(i));
        acc = _mm256_fmadd_pd(xv, wv, acc);
        i += 4;
    }
    let mut sum = hsum256(acc);
    while i < n {
        sum += *x.add(i) * *w.add(i);
        i += 1;
    }
    sum
}
1027
/// Dot product of `n` contiguous `f64`s at `x` and `w` using AVX-512F + FMA.
///
/// Eight lanes are fused per iteration into one accumulator. That accumulator is reduced
/// with `hsum512`, and the remaining `n % 8` products are then added sequentially.
///
/// The function is compiled with `avx512f,fma` enabled so the intrinsics inline instead of
/// becoming out-of-line calls. `#[inline(always)]` is not allowed together with
/// `#[target_feature]`, so this uses `#[inline]`.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `x` and `w` must each be valid for reads of `n` consecutive `f64`s.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
#[inline]
unsafe fn dot_avx512(x: *const f64, w: *const f64, n: usize) -> f64 {
    let mut i = 0usize;
    let n8 = n & !7;
    let mut acc = _mm512_setzero_pd();
    while i < n8 {
        let xv = _mm512_loadu_pd(x.add(i));
        let wv = _mm512_loadu_pd(w.add(i));
        acc = _mm512_fmadd_pd(xv, wv, acc);
        i += 8;
    }
    let mut sum = hsum512(acc);
    while i < n {
        sum += *x.add(i) * *w.add(i);
        i += 1;
    }
    sum
}
1047
1048#[inline(always)]
1049fn dot_contiguous(kernel: Kernel, x: &[f64], w: &[f64]) -> f64 {
1050    debug_assert_eq!(x.len(), w.len());
1051    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1052    {
1053        match kernel {
1054            Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
1055                return dot_avx512(x.as_ptr(), w.as_ptr(), x.len());
1056            },
1057            Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
1058                return dot_avx2(x.as_ptr(), w.as_ptr(), x.len());
1059            },
1060            _ => {}
1061        }
1062    }
1063
1064    dot_scalar_unrolled_safe(x, w)
1065}
1066
/// Parameter sweep for batch ALMA.
///
/// Each field is a `(start, end, step)` triple that `expand_grid` expands into a list of
/// values. A zero step, or `start == end`, selects the single value `start`.
#[derive(Clone, Debug)]
pub struct AlmaBatchRange {
    /// Window lengths to evaluate.
    pub period: (usize, usize, usize),
    /// Gaussian centre position as a fraction of the window (`m = offset * (period - 1)`).
    /// Batch evaluation rejects values outside `[0, 1]`.
    pub offset: (f64, f64, f64),
    /// Divides the period to give the Gaussian width (`s = period / sigma`).
    /// Batch evaluation rejects values `<= 0`.
    pub sigma: (f64, f64, f64),
}
1073
1074impl Default for AlmaBatchRange {
1075    fn default() -> Self {
1076        Self {
1077            period: (9, 258, 1),
1078            offset: (0.85, 0.85, 0.0),
1079            sigma: (6.0, 6.0, 0.0),
1080        }
1081    }
1082}
1083
/// Fluent builder for batch ALMA runs.
///
/// It collects a parameter sweep and the kernel that `apply_slice` / `apply_candles`
/// pass to `alma_batch_with_kernel`.
#[derive(Clone, Debug, Default)]
pub struct AlmaBatchBuilder {
    /// Period / offset / sigma sweep to expand.
    range: AlmaBatchRange,
    /// Requested compute kernel (validated by `alma_batch_with_kernel`).
    kernel: Kernel,
}
1089
1090impl AlmaBatchBuilder {
1091    pub fn new() -> Self {
1092        Self::default()
1093    }
1094
1095    pub fn kernel(mut self, k: Kernel) -> Self {
1096        self.kernel = k;
1097        self
1098    }
1099
1100    #[inline]
1101    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1102        self.range.period = (start, end, step);
1103        self
1104    }
1105    #[inline]
1106    pub fn period_static(mut self, p: usize) -> Self {
1107        self.range.period = (p, p, 0);
1108        self
1109    }
1110
1111    #[inline]
1112    pub fn offset_range(mut self, start: f64, end: f64, step: f64) -> Self {
1113        self.range.offset = (start, end, step);
1114        self
1115    }
1116    #[inline]
1117    pub fn offset_static(mut self, x: f64) -> Self {
1118        self.range.offset = (x, x, 0.0);
1119        self
1120    }
1121
1122    #[inline]
1123    pub fn sigma_range(mut self, start: f64, end: f64, step: f64) -> Self {
1124        self.range.sigma = (start, end, step);
1125        self
1126    }
1127    #[inline]
1128    pub fn sigma_static(mut self, s: f64) -> Self {
1129        self.range.sigma = (s, s, 0.0);
1130        self
1131    }
1132
1133    pub fn apply_slice(self, data: &[f64]) -> Result<AlmaBatchOutput, AlmaError> {
1134        alma_batch_with_kernel(data, &self.range, self.kernel)
1135    }
1136
1137    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<AlmaBatchOutput, AlmaError> {
1138        AlmaBatchBuilder::new().kernel(k).apply_slice(data)
1139    }
1140
1141    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<AlmaBatchOutput, AlmaError> {
1142        let slice = source_type(c, src);
1143        self.apply_slice(slice)
1144    }
1145
1146    pub fn with_default_candles(c: &Candles) -> Result<AlmaBatchOutput, AlmaError> {
1147        AlmaBatchBuilder::new()
1148            .kernel(Kernel::Auto)
1149            .apply_candles(c, "close")
1150    }
1151}
1152
1153pub fn alma_batch_with_kernel(
1154    data: &[f64],
1155    sweep: &AlmaBatchRange,
1156    k: Kernel,
1157) -> Result<AlmaBatchOutput, AlmaError> {
1158    let kernel = match k {
1159        Kernel::Auto => detect_best_batch_kernel(),
1160        other if other.is_batch() => other,
1161        _ => return Err(AlmaError::InvalidKernelForBatch(k)),
1162    };
1163
1164    let simd = match kernel {
1165        Kernel::Avx512Batch => Kernel::Avx512,
1166        Kernel::Avx2Batch => Kernel::Avx2,
1167        Kernel::ScalarBatch => Kernel::Scalar,
1168        _ => unreachable!(),
1169    };
1170    alma_batch_par_slice(data, sweep, simd)
1171}
1172
/// Result of a batch ALMA run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct AlmaBatchOutput {
    /// Row-major `rows x cols` matrix; row `r` holds the series for `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<AlmaParams>,
    /// Number of parameter combinations (rows of `values`).
    pub rows: usize,
    /// Length of the input series (columns of `values`).
    pub cols: usize,
}
1180impl AlmaBatchOutput {
1181    pub fn row_for_params(&self, p: &AlmaParams) -> Option<usize> {
1182        self.combos.iter().position(|c| {
1183            c.period.unwrap_or(9) == p.period.unwrap_or(9)
1184                && (c.offset.unwrap_or(0.85) - p.offset.unwrap_or(0.85)).abs() < 1e-12
1185                && (c.sigma.unwrap_or(6.0) - p.sigma.unwrap_or(6.0)).abs() < 1e-12
1186        })
1187    }
1188
1189    pub fn values_for(&self, p: &AlmaParams) -> Option<&[f64]> {
1190        self.row_for_params(p).map(|row| {
1191            let start = row * self.cols;
1192            &self.values[start..start + self.cols]
1193        })
1194    }
1195}
1196
1197#[inline(always)]
1198fn expand_grid(r: &AlmaBatchRange) -> Result<Vec<AlmaParams>, AlmaError> {
1199    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, AlmaError> {
1200        if step == 0 || start == end {
1201            return Ok(vec![start]);
1202        }
1203        if start < end {
1204            return Ok((start..=end).step_by(step.max(1)).collect());
1205        }
1206
1207        let mut v = Vec::new();
1208        let mut x = start as isize;
1209        let end_i = end as isize;
1210        let st = (step as isize).max(1);
1211        while x >= end_i {
1212            v.push(x as usize);
1213            x -= st;
1214        }
1215        if v.is_empty() {
1216            return Err(AlmaError::InvalidRange {
1217                start: start.to_string(),
1218                end: end.to_string(),
1219                step: step.to_string(),
1220            });
1221        }
1222        Ok(v)
1223    }
1224    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, AlmaError> {
1225        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1226            return Ok(vec![start]);
1227        }
1228        if start < end {
1229            let mut v = Vec::new();
1230            let mut x = start;
1231            let st = step.abs();
1232            while x <= end + 1e-12 {
1233                v.push(x);
1234                x += st;
1235            }
1236            if v.is_empty() {
1237                return Err(AlmaError::InvalidRange {
1238                    start: start.to_string(),
1239                    end: end.to_string(),
1240                    step: step.to_string(),
1241                });
1242            }
1243            return Ok(v);
1244        }
1245        let mut v = Vec::new();
1246        let mut x = start;
1247        let st = step.abs();
1248        while x + 1e-12 >= end {
1249            v.push(x);
1250            x -= st;
1251        }
1252        if v.is_empty() {
1253            return Err(AlmaError::InvalidRange {
1254                start: start.to_string(),
1255                end: end.to_string(),
1256                step: step.to_string(),
1257            });
1258        }
1259        Ok(v)
1260    }
1261
1262    let periods = axis_usize(r.period)?;
1263    let offsets = axis_f64(r.offset)?;
1264    let sigmas = axis_f64(r.sigma)?;
1265
1266    let cap = periods
1267        .len()
1268        .checked_mul(offsets.len())
1269        .and_then(|x| x.checked_mul(sigmas.len()))
1270        .ok_or_else(|| AlmaError::InvalidRange {
1271            start: "cap".into(),
1272            end: "overflow".into(),
1273            step: "mul".into(),
1274        })?;
1275
1276    let mut out = Vec::with_capacity(cap);
1277    for &p in &periods {
1278        for &o in &offsets {
1279            for &s in &sigmas {
1280                out.push(AlmaParams {
1281                    period: Some(p),
1282                    offset: Some(o),
1283                    sigma: Some(s),
1284                });
1285            }
1286        }
1287    }
1288    Ok(out)
1289}
1290
1291#[inline(always)]
1292pub fn alma_batch_slice(
1293    data: &[f64],
1294    sweep: &AlmaBatchRange,
1295    kern: Kernel,
1296) -> Result<AlmaBatchOutput, AlmaError> {
1297    alma_batch_inner(data, sweep, kern, false)
1298}
1299
1300#[inline(always)]
1301pub fn alma_batch_par_slice(
1302    data: &[f64],
1303    sweep: &AlmaBatchRange,
1304    kern: Kernel,
1305) -> Result<AlmaBatchOutput, AlmaError> {
1306    alma_batch_inner(data, sweep, kern, true)
1307}
1308
/// Rounds `x` up to the next multiple of 8 (`0` stays `0`).
#[inline]
fn round_up8(x: usize) -> usize {
    (x + 7) / 8 * 8
}
1313
1314#[inline(always)]
1315fn alma_batch_inner(
1316    data: &[f64],
1317    sweep: &AlmaBatchRange,
1318    kern: Kernel,
1319    parallel: bool,
1320) -> Result<AlmaBatchOutput, AlmaError> {
1321    let combos = expand_grid(sweep)?;
1322    let cols = data.len();
1323    let rows = combos.len();
1324
1325    if cols == 0 {
1326        return Err(AlmaError::AllValuesNaN);
1327    }
1328
1329    let _ = rows
1330        .checked_mul(cols)
1331        .ok_or_else(|| AlmaError::InvalidRange {
1332            start: rows.to_string(),
1333            end: cols.to_string(),
1334            step: "rows*cols".into(),
1335        })?;
1336    let mut buf_mu = make_uninit_matrix(rows, cols);
1337
1338    let warm: Vec<usize> = combos
1339        .iter()
1340        .map(|c| data.iter().position(|x| !x.is_nan()).unwrap_or(0) + c.period.unwrap() - 1)
1341        .collect();
1342    init_matrix_prefixes(&mut buf_mu, cols, &warm);
1343
1344    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1345    let out: &mut [f64] = unsafe {
1346        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1347    };
1348
1349    alma_batch_inner_into(data, sweep, kern, parallel, out)?;
1350
1351    let values = unsafe {
1352        Vec::from_raw_parts(
1353            buf_guard.as_mut_ptr() as *mut f64,
1354            buf_guard.len(),
1355            buf_guard.capacity(),
1356        )
1357    };
1358
1359    Ok(AlmaBatchOutput {
1360        values,
1361        combos,
1362        rows,
1363        cols,
1364    })
1365}
1366
1367#[inline(always)]
1368fn alma_batch_inner_into(
1369    data: &[f64],
1370    sweep: &AlmaBatchRange,
1371    kern: Kernel,
1372    parallel: bool,
1373    out: &mut [f64],
1374) -> Result<Vec<AlmaParams>, AlmaError> {
1375    let combos = expand_grid(sweep)?;
1376    if combos.is_empty() {
1377        return Err(AlmaError::InvalidRange {
1378            start: "range".into(),
1379            end: "range".into(),
1380            step: "empty".into(),
1381        });
1382    }
1383
1384    let first = data
1385        .iter()
1386        .position(|x| !x.is_nan())
1387        .ok_or(AlmaError::AllValuesNaN)?;
1388    let max_p = combos
1389        .iter()
1390        .map(|c| round_up8(c.period.unwrap()))
1391        .max()
1392        .unwrap();
1393    if data.len() - first < max_p {
1394        return Err(AlmaError::NotEnoughValidData {
1395            needed: max_p,
1396            valid: data.len() - first,
1397        });
1398    }
1399
1400    let rows = combos.len();
1401    let cols = data.len();
1402    let mut inv_norms = vec![0.0; rows];
1403
1404    let cap = rows
1405        .checked_mul(max_p)
1406        .ok_or_else(|| AlmaError::InvalidRange {
1407            start: rows.to_string(),
1408            end: max_p.to_string(),
1409            step: "rows*max_p".into(),
1410        })?;
1411    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
1412    flat_w.resize(cap, 0.0);
1413
1414    for (row, prm) in combos.iter().enumerate() {
1415        let period = prm.period.unwrap();
1416        let offset = prm.offset.unwrap();
1417        let sigma = prm.sigma.unwrap();
1418
1419        if sigma <= 0.0 {
1420            return Err(AlmaError::InvalidSigma { sigma });
1421        }
1422        if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
1423            return Err(AlmaError::InvalidOffset { offset });
1424        }
1425
1426        let m = offset * (period - 1) as f64;
1427        let s = period as f64 / sigma;
1428        let s2 = 2.0 * s * s;
1429
1430        let mut norm = 0.0;
1431        for i in 0..period {
1432            let w = (-(i as f64 - m).powi(2) / s2).exp();
1433            flat_w[row * max_p + i] = w;
1434            norm += w;
1435        }
1436        inv_norms[row] = 1.0 / norm;
1437    }
1438    let out_uninit = unsafe {
1439        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1440    };
1441
1442    let warm: Vec<usize> = combos
1443        .iter()
1444        .map(|c| first + c.period.unwrap() - 1)
1445        .collect();
1446    init_matrix_prefixes(out_uninit, cols, &warm);
1447
1448    let actual_kern = match kern {
1449        Kernel::Auto => detect_best_batch_kernel(),
1450        k => k,
1451    };
1452
1453    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1454        let period = combos[row].period.unwrap();
1455        let w_ptr = flat_w.as_ptr().add(row * max_p);
1456        let inv_n = *inv_norms.get_unchecked(row);
1457
1458        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1459
1460        match actual_kern {
1461            Kernel::Scalar | Kernel::ScalarBatch => {
1462                alma_row_scalar(data, first, period, w_ptr, inv_n, dst)
1463            }
1464            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1465            Kernel::Avx2 | Kernel::Avx2Batch => {
1466                alma_row_avx2(data, first, period, w_ptr, inv_n, dst)
1467            }
1468            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1469            Kernel::Avx512 | Kernel::Avx512Batch => {
1470                alma_row_avx512(data, first, period, w_ptr, inv_n, dst)
1471            }
1472            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
1473            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
1474                alma_row_scalar(data, first, period, w_ptr, inv_n, dst)
1475            }
1476            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
1477        }
1478    };
1479
1480    if parallel {
1481        #[cfg(not(target_arch = "wasm32"))]
1482        {
1483            out_uninit
1484                .par_chunks_mut(cols)
1485                .enumerate()
1486                .for_each(|(row, slice)| do_row(row, slice));
1487        }
1488
1489        #[cfg(target_arch = "wasm32")]
1490        {
1491            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1492                do_row(row, slice);
1493            }
1494        }
1495    } else {
1496        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1497            do_row(row, slice);
1498        }
1499    }
1500
1501    Ok(combos)
1502}
1503
/// Portable row kernel.
///
/// For each `i` in `first + period - 1 .. data.len()`, it writes
/// `dot(data[i + 1 - period ..= i], weights) * inv_n` to `out[i]`. Full groups of four
/// products are summed together and then added to the running sum; the leftover
/// `period % 4` products are added one at a time.
///
/// # Safety
/// `w_ptr` must be non-null, aligned, and valid for reads of `period` `f64`s.
/// Out-of-range `data`/`out` accesses panic.
#[inline(always)]
unsafe fn alma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    let weights = std::slice::from_raw_parts(w_ptr, period);
    let split = period & !3;
    let (w_body, w_rest) = weights.split_at(split);

    for i in (first + period - 1)..data.len() {
        let window = &data[i + 1 - period..i + 1];
        let (d_body, d_rest) = window.split_at(split);

        let mut acc = 0.0;
        for (d, w) in d_body.chunks_exact(4).zip(w_body.chunks_exact(4)) {
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
        }
        for (d, w) in d_rest.iter().zip(w_rest) {
            acc += d * w;
        }
        out[i] = acc * inv_n;
    }
}
1528
/// AVX2 + FMA row kernel.
///
/// For each `i` in `first + period - 1 .. data.len()`, it writes
/// `dot(data[i + 1 - period ..= i], weights) * inv_n` to `out[i]`.
///
/// # Safety
/// - The CPU must support AVX2 and FMA.
/// - `w_ptr` must be valid for reads of `period` `f64`s.
/// - `first + period <= data.len()`. `out[i]` is bounds-checked and panics if `out` is
///   shorter than `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let vec_blocks = period / STEP;
    let tail = period % STEP;
    // Lane mask selecting the `tail` leftover elements. `_mm256_maskload_pd` returns 0.0 in
    // masked-out lanes and does not touch their memory.
    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        for blk in 0..vec_blocks {
            let d = _mm256_loadu_pd(data.as_ptr().add(start + blk * STEP));
            let w = _mm256_loadu_pd(w_ptr.add(blk * STEP));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        if tail != 0 {
            let d = _mm256_maskload_pd(data.as_ptr().add(start + vec_blocks * STEP), tail_mask);
            let w = _mm256_maskload_pd(w_ptr.add(vec_blocks * STEP), tail_mask);
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Horizontal sum of the four accumulator lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let s2 = _mm_add_pd(hi, lo);
        let s1 = _mm_add_pd(s2, _mm_unpackhi_pd(s2, s2));
        out[i] = _mm_cvtsd_f64(s1) * inv_n;
    }
}
1573
/// AVX-512 row kernel.
///
/// For each `i` in `first + period - 1 .. data.len()`, it writes
/// `dot(data[i + 1 - period ..= i], weights) * inv_n` to `out[i]`. Dispatch by period:
/// - `period <= 32`: `alma_row_avx512_short`.
/// - `period <= 4096`: `alma_row_avx512_long`. It caches the weight vectors in a fixed
///   512-entry array (`MAX_CHUNKS`), covering up to 512 * 8 lanes.
/// - Larger periods: `alma_row_avx2`. The long kernel would index its weight cache past
///   the end and panic.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA. The large-period fallback also needs AVX2,
///   which AVX-512F-capable x86_64 CPUs provide.
/// - `w_ptr` must be valid for `period` reads and 64-byte aligned, because the short/long
///   kernels use the aligned `_mm512_load_pd`.
/// - `first + period <= data.len()` and `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn alma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    // Must match `MAX_CHUNKS * 8` in `alma_row_avx512_long`.
    const LONG_KERNEL_MAX_PERIOD: usize = 512 * 8;

    if period <= 32 {
        alma_row_avx512_short(data, first, period, w_ptr, inv_n, out);
    } else if period <= LONG_KERNEL_MAX_PERIOD {
        alma_row_avx512_long(data, first, period, w_ptr, inv_n, out);
    } else {
        alma_row_avx2(data, first, period, w_ptr, inv_n, out);
    }
}
1590
/// AVX-512 row kernel for short periods (`period <= 32`, debug-asserted).
///
/// For each `i` in `first + period - 1 .. data.len()`, it writes
/// `dot(data[i + 1 - period ..= i], weights) * inv_n` to `out[i]`. Lanes are reduced with
/// `hsum_pd_zmm` (defined elsewhere; presumably a horizontal lane sum — confirm).
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `w_ptr` must be valid for `period` reads and 64-byte aligned, because full chunks use
///   the aligned `_mm512_load_pd`.
/// - `first + period <= data.len()` and `out.len() >= data.len()` (writes are unchecked).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    debug_assert!(period <= 32);
    const STEP: usize = 8;

    let chunks = period / STEP;
    let tail_len = period % STEP;
    // Low `tail_len` bits set; masked-out lanes load as 0.0 without touching memory.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    // period < 8: all weights fit in one masked vector, so load them once outside the loop.
    if chunks == 0 {
        let w_tail = _mm512_maskz_loadu_pd(tail_mask, w_ptr);
        for i in (first + period - 1)..data.len() {
            let start = i + 1 - period;
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start));
            let res = hsum_pd_zmm(_mm512_mul_pd(d_tail, w_tail)) * inv_n;
            *out.get_unchecked_mut(i) = res;
        }
        return;
    }

    for i in (first + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm512_setzero_pd();

        for blk in 0..chunks {
            let w = _mm512_load_pd(w_ptr.add(blk * STEP));
            let d = _mm512_loadu_pd(data.as_ptr().add(start + blk * STEP));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        if tail_len != 0 {
            let w_tail = _mm512_maskz_loadu_pd(tail_mask, w_ptr.add(chunks * STEP));
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start + chunks * STEP));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        let res = hsum_pd_zmm(acc) * inv_n;
        *out.get_unchecked_mut(i) = res;
    }
}
1639
/// AVX-512 row kernel for long periods.
///
/// It loads every weight vector once into a local cache. The sliding-window work is then
/// handed to `long_kernel_no_tail` (period a multiple of 8) or `long_kernel_with_tail`.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `w_ptr` must be valid for `period` reads and 64-byte aligned, because full chunks use
///   the aligned `_mm512_load_pd`.
/// - `first + period <= data.len()` and `out.len() >= data.len()`.
/// - `period <= 512 * 8`. See `MAX_CHUNKS`: larger periods trip the `debug_assert!`, or
///   panic on the bounds-checked `wregs[..]` index in release builds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let n_chunks = period / STEP;
    let tail_len = period % STEP;
    // Low `tail_len` bits set: selects the leftover lanes of the final partial chunk.
    let tmask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    // Capacity of the local weight cache (512 vectors of 8 lanes).
    const MAX_CHUNKS: usize = 512;
    debug_assert!(n_chunks + (tail_len != 0) as usize <= MAX_CHUNKS);

    // An array of `MaybeUninit` needs no initialization, so `assume_init` is sound here.
    let mut wregs: [core::mem::MaybeUninit<__m512d>; MAX_CHUNKS] =
        core::mem::MaybeUninit::uninit().assume_init();

    for blk in 0..n_chunks {
        wregs[blk]
            .as_mut_ptr()
            .write(_mm512_load_pd(w_ptr.add(blk * STEP)));
    }
    if tail_len != 0 {
        // Masked-out lanes are zero, so the tail vector contributes nothing beyond `period`.
        wregs[n_chunks]
            .as_mut_ptr()
            .write(_mm512_maskz_loadu_pd(tmask, w_ptr.add(n_chunks * STEP)));
    }

    // View exactly the initialized prefix of the cache as `__m512d`s.
    let wregs: &[__m512d] = core::slice::from_raw_parts(
        wregs.as_ptr() as *const __m512d,
        n_chunks + (tail_len != 0) as usize,
    );

    if tail_len == 0 {
        long_kernel_no_tail(data, first, n_chunks, wregs, inv_n, out);
    } else {
        long_kernel_with_tail(data, first, n_chunks, tail_len, tmask, wregs, inv_n, out);
    }
}
1683
/// Sliding-window AVX-512 dot product for periods that are an exact multiple of 8.
///
/// The window is `n_chunks * 8` wide and starts at `data[first]`. For each position it
/// writes `dot(window, weights) * inv_n` to the output index of the window's last element.
/// Groups of four chunks accumulate into four separate registers (`s0`..`s3`); leftover
/// chunks go into `s0`.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `wregs.len() >= n_chunks`.
/// - `n_chunks >= 1`, `data.len() - first >= n_chunks * 8` and `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn long_kernel_no_tail(
    data: &[f64],
    first: usize,
    n_chunks: usize,
    wregs: &[__m512d],
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let paired = n_chunks & !3;

    let mut data_ptr = data.as_ptr().add(first);
    let stop_ptr = data.as_ptr().add(data.len());
    let mut dst_ptr = out.as_mut_ptr().add(first + n_chunks * STEP - 1);

    while data_ptr < stop_ptr {
        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk < paired {
            let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
            let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
            s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
            s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
            s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);

            blk += 4;
        }

        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(data_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(r), s0);
        }

        let sum = _mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3));
        let res = hsum_pd_zmm(sum) * inv_n;

        *dst_ptr = res;

        data_ptr = data_ptr.add(1);
        dst_ptr = dst_ptr.add(1);
        // After the last full window, `data_ptr + period` is one element beyond the slice's
        // one-past-the-end pointer, and `add` forbids that. `wrapping_add` makes the probe
        // well-defined.
        if data_ptr.wrapping_add(n_chunks * STEP) > stop_ptr {
            break;
        }
    }
}
1739
/// Sliding-window AVX-512 dot product for periods with a partial final chunk
/// (`period = n_chunks * 8 + tail_len`, `tail_len` in `1..8`).
///
/// It works like `long_kernel_no_tail`. In addition, each window's last `tail_len`
/// elements are loaded with mask `tmask` and weighted by `wregs[n_chunks]`.
///
/// # Safety
/// - The CPU must support AVX-512F and FMA.
/// - `wregs.len() >= n_chunks + 1`.
/// - `tmask` has exactly the low `tail_len` bits set.
/// - `data.len() - first >= n_chunks * 8 + tail_len` and `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn long_kernel_with_tail(
    data: &[f64],
    first: usize,
    n_chunks: usize,
    tail_len: usize,
    tmask: __mmask8,
    wregs: &[__m512d],
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let paired = n_chunks & !3;

    let w_tail = *wregs.get_unchecked(n_chunks);

    let mut data_ptr = data.as_ptr().add(first);
    let stop_ptr = data.as_ptr().add(data.len());
    let mut dst_ptr = out.as_mut_ptr().add(first + n_chunks * STEP + tail_len - 1);

    while data_ptr < stop_ptr {
        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk < paired {
            let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
            let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
            s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
            s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
            s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);

            blk += 4;
        }

        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(data_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(r), s0);
        }

        // Masked-out lanes load as 0.0 and do not touch memory past the window.
        let d_tail = _mm512_maskz_loadu_pd(tmask, data_ptr.add(n_chunks * STEP));
        s0 = _mm512_fmadd_pd(d_tail, w_tail, s0);

        let sum = _mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3));
        let res = hsum_pd_zmm(sum) * inv_n;

        *dst_ptr = res;

        data_ptr = data_ptr.add(1);
        dst_ptr = dst_ptr.add(1);
        // After the last full window, `data_ptr + period` is one element beyond the slice's
        // one-past-the-end pointer, and `add` forbids that. `wrapping_add` makes the probe
        // well-defined.
        if data_ptr.wrapping_add(n_chunks * STEP + tail_len) > stop_ptr {
            break;
        }
    }
}
1802
1803#[cfg(test)]
1804mod tests {
1805    use super::*;
1806    use crate::skip_if_unsupported;
1807    use crate::utilities::data_loader::read_candles_from_csv;
1808    #[cfg(feature = "proptest")]
1809    use proptest::prelude::*;
1810
1811    fn check_alma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1812        skip_if_unsupported!(kernel, test_name);
1813        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1814        let candles = read_candles_from_csv(file_path)?;
1815
1816        let default_params = AlmaParams {
1817            period: None,
1818            offset: None,
1819            sigma: None,
1820        };
1821        let input = AlmaInput::from_candles(&candles, "close", default_params);
1822        let output = alma_with_kernel(&input, kernel)?;
1823        assert_eq!(output.values.len(), candles.close.len());
1824
1825        Ok(())
1826    }
1827
1828    fn check_alma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1829        skip_if_unsupported!(kernel, test_name);
1830        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1831        let candles = read_candles_from_csv(file_path)?;
1832
1833        let input = AlmaInput::from_candles(&candles, "close", AlmaParams::default());
1834        let result = alma_with_kernel(&input, kernel)?;
1835        let expected_last_five = [
1836            59286.72216704,
1837            59273.53428138,
1838            59204.37290721,
1839            59155.93381742,
1840            59026.92526112,
1841        ];
1842        let start = result.values.len().saturating_sub(5);
1843        for (i, &val) in result.values[start..].iter().enumerate() {
1844            let diff = (val - expected_last_five[i]).abs();
1845            assert!(
1846                diff < 1e-8,
1847                "[{}] ALMA {:?} mismatch at idx {}: got {}, expected {}",
1848                test_name,
1849                kernel,
1850                i,
1851                val,
1852                expected_last_five[i]
1853            );
1854        }
1855        Ok(())
1856    }
1857
1858    fn check_alma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1859        skip_if_unsupported!(kernel, test_name);
1860        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1861        let candles = read_candles_from_csv(file_path)?;
1862
1863        let input = AlmaInput::with_default_candles(&candles);
1864        match input.data {
1865            AlmaData::Candles { source, .. } => assert_eq!(source, "close"),
1866            _ => panic!("Expected AlmaData::Candles"),
1867        }
1868        let output = alma_with_kernel(&input, kernel)?;
1869        assert_eq!(output.values.len(), candles.close.len());
1870
1871        Ok(())
1872    }
1873
1874    fn check_alma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1875        skip_if_unsupported!(kernel, test_name);
1876        let input_data = [10.0, 20.0, 30.0];
1877        let params = AlmaParams {
1878            period: Some(0),
1879            offset: None,
1880            sigma: None,
1881        };
1882        let input = AlmaInput::from_slice(&input_data, params);
1883        let res = alma_with_kernel(&input, kernel);
1884        assert!(
1885            res.is_err(),
1886            "[{}] ALMA should fail with zero period",
1887            test_name
1888        );
1889        Ok(())
1890    }
1891
1892    fn check_alma_period_exceeds_length(
1893        test_name: &str,
1894        kernel: Kernel,
1895    ) -> Result<(), Box<dyn Error>> {
1896        skip_if_unsupported!(kernel, test_name);
1897        let data_small = [10.0, 20.0, 30.0];
1898        let params = AlmaParams {
1899            period: Some(10),
1900            offset: None,
1901            sigma: None,
1902        };
1903        let input = AlmaInput::from_slice(&data_small, params);
1904        let res = alma_with_kernel(&input, kernel);
1905        assert!(
1906            res.is_err(),
1907            "[{}] ALMA should fail with period exceeding length",
1908            test_name
1909        );
1910        Ok(())
1911    }
1912
1913    fn check_alma_very_small_dataset(
1914        test_name: &str,
1915        kernel: Kernel,
1916    ) -> Result<(), Box<dyn Error>> {
1917        skip_if_unsupported!(kernel, test_name);
1918        let single_point = [42.0];
1919        let params = AlmaParams {
1920            period: Some(9),
1921            offset: None,
1922            sigma: None,
1923        };
1924        let input = AlmaInput::from_slice(&single_point, params);
1925        let res = alma_with_kernel(&input, kernel);
1926        assert!(
1927            res.is_err(),
1928            "[{}] ALMA should fail with insufficient data",
1929            test_name
1930        );
1931        Ok(())
1932    }
1933
1934    fn check_alma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1935        skip_if_unsupported!(kernel, test_name);
1936        let empty: [f64; 0] = [];
1937        let input = AlmaInput::from_slice(&empty, AlmaParams::default());
1938        let res = alma_with_kernel(&input, kernel);
1939        assert!(
1940            matches!(res, Err(AlmaError::EmptyInputData)),
1941            "[{}] ALMA should fail with empty input",
1942            test_name
1943        );
1944        Ok(())
1945    }
1946
1947    fn check_alma_invalid_sigma(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1948        skip_if_unsupported!(kernel, test_name);
1949        let data = [1.0, 2.0, 3.0];
1950        let params = AlmaParams {
1951            period: Some(2),
1952            offset: None,
1953            sigma: Some(0.0),
1954        };
1955        let input = AlmaInput::from_slice(&data, params);
1956        let res = alma_with_kernel(&input, kernel);
1957        assert!(
1958            matches!(res, Err(AlmaError::InvalidSigma { .. })),
1959            "[{}] ALMA should fail with invalid sigma",
1960            test_name
1961        );
1962        Ok(())
1963    }
1964
1965    fn check_alma_invalid_offset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1966        skip_if_unsupported!(kernel, test_name);
1967        let data = [1.0, 2.0, 3.0];
1968        let params = AlmaParams {
1969            period: Some(2),
1970            offset: Some(f64::NAN),
1971            sigma: None,
1972        };
1973        let input = AlmaInput::from_slice(&data, params);
1974        let res = alma_with_kernel(&input, kernel);
1975        assert!(
1976            matches!(res, Err(AlmaError::InvalidOffset { .. })),
1977            "[{}] ALMA should fail with invalid offset",
1978            test_name
1979        );
1980        Ok(())
1981    }
1982
1983    fn check_alma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1984        skip_if_unsupported!(kernel, test_name);
1985        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1986        let candles = read_candles_from_csv(file_path)?;
1987
1988        let first_params = AlmaParams {
1989            period: Some(9),
1990            offset: None,
1991            sigma: None,
1992        };
1993        let first_input = AlmaInput::from_candles(&candles, "close", first_params);
1994        let first_result = alma_with_kernel(&first_input, kernel)?;
1995
1996        let second_params = AlmaParams {
1997            period: Some(9),
1998            offset: None,
1999            sigma: None,
2000        };
2001        let second_input = AlmaInput::from_slice(&first_result.values, second_params);
2002        let second_result = alma_with_kernel(&second_input, kernel)?;
2003
2004        assert_eq!(second_result.values.len(), first_result.values.len());
2005        let expected_last_five = [
2006            59140.73195170,
2007            59211.58090986,
2008            59238.16030697,
2009            59222.63528822,
2010            59165.14427332,
2011        ];
2012        let start = second_result.values.len().saturating_sub(5);
2013        for (i, &val) in second_result.values[start..].iter().enumerate() {
2014            let diff = (val - expected_last_five[i]).abs();
2015            assert!(
2016                diff < 1e-8,
2017                "[{}] ALMA Slice Reinput {:?} mismatch at idx {}: got {}, expected {}",
2018                test_name,
2019                kernel,
2020                i,
2021                val,
2022                expected_last_five[i]
2023            );
2024        }
2025        Ok(())
2026    }
2027
2028    fn check_alma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2029        skip_if_unsupported!(kernel, test_name);
2030        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2031        let candles = read_candles_from_csv(file_path)?;
2032
2033        let input = AlmaInput::from_candles(
2034            &candles,
2035            "close",
2036            AlmaParams {
2037                period: Some(9),
2038                offset: None,
2039                sigma: None,
2040            },
2041        );
2042        let res = alma_with_kernel(&input, kernel)?;
2043        assert_eq!(res.values.len(), candles.close.len());
2044        if res.values.len() > 240 {
2045            for (i, &val) in res.values[240..].iter().enumerate() {
2046                assert!(
2047                    !val.is_nan(),
2048                    "[{}] Found unexpected NaN at out-index {}",
2049                    test_name,
2050                    240 + i
2051                );
2052            }
2053        }
2054        Ok(())
2055    }
2056
2057    fn check_alma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2058        skip_if_unsupported!(kernel, test_name);
2059
2060        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2061        let candles = read_candles_from_csv(file_path)?;
2062
2063        let period = 9;
2064        let offset = 0.85;
2065        let sigma = 6.0;
2066
2067        let input = AlmaInput::from_candles(
2068            &candles,
2069            "close",
2070            AlmaParams {
2071                period: Some(period),
2072                offset: Some(offset),
2073                sigma: Some(sigma),
2074            },
2075        );
2076        let batch_output = alma_with_kernel(&input, kernel)?.values;
2077
2078        let mut stream = AlmaStream::try_new(AlmaParams {
2079            period: Some(period),
2080            offset: Some(offset),
2081            sigma: Some(sigma),
2082        })?;
2083
2084        let mut stream_values = Vec::with_capacity(candles.close.len());
2085        for &price in &candles.close {
2086            match stream.update(price) {
2087                Some(alma_val) => stream_values.push(alma_val),
2088                None => stream_values.push(f64::NAN),
2089            }
2090        }
2091
2092        assert_eq!(batch_output.len(), stream_values.len());
2093        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2094            if b.is_nan() && s.is_nan() {
2095                continue;
2096            }
2097            let diff = (b - s).abs();
2098            assert!(
2099                diff < 1e-9,
2100                "[{}] ALMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2101                test_name,
2102                i,
2103                b,
2104                s,
2105                diff
2106            );
2107        }
2108        Ok(())
2109    }
2110
2111    #[cfg(debug_assertions)]
2112    fn check_alma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2113        skip_if_unsupported!(kernel, test_name);
2114
2115        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2116        let candles = read_candles_from_csv(file_path)?;
2117
2118        let test_params = vec![
2119            AlmaParams::default(),
2120            AlmaParams {
2121                period: Some(5),
2122                offset: Some(0.5),
2123                sigma: Some(3.0),
2124            },
2125            AlmaParams {
2126                period: Some(5),
2127                offset: Some(0.85),
2128                sigma: Some(6.0),
2129            },
2130            AlmaParams {
2131                period: Some(5),
2132                offset: Some(1.0),
2133                sigma: Some(10.0),
2134            },
2135            AlmaParams {
2136                period: Some(9),
2137                offset: Some(0.2),
2138                sigma: Some(4.0),
2139            },
2140            AlmaParams {
2141                period: Some(9),
2142                offset: Some(0.85),
2143                sigma: Some(6.0),
2144            },
2145            AlmaParams {
2146                period: Some(9),
2147                offset: Some(0.95),
2148                sigma: Some(8.0),
2149            },
2150            AlmaParams {
2151                period: Some(20),
2152                offset: Some(0.0),
2153                sigma: Some(2.0),
2154            },
2155            AlmaParams {
2156                period: Some(20),
2157                offset: Some(0.5),
2158                sigma: Some(5.0),
2159            },
2160            AlmaParams {
2161                period: Some(20),
2162                offset: Some(0.85),
2163                sigma: Some(6.0),
2164            },
2165            AlmaParams {
2166                period: Some(20),
2167                offset: Some(1.0),
2168                sigma: Some(10.0),
2169            },
2170            AlmaParams {
2171                period: Some(2),
2172                offset: Some(0.0),
2173                sigma: Some(0.1),
2174            },
2175            AlmaParams {
2176                period: Some(50),
2177                offset: Some(0.5),
2178                sigma: Some(15.0),
2179            },
2180            AlmaParams {
2181                period: Some(100),
2182                offset: Some(0.85),
2183                sigma: Some(20.0),
2184            },
2185        ];
2186
2187        for (param_idx, params) in test_params.iter().enumerate() {
2188            let input = AlmaInput::from_candles(&candles, "close", params.clone());
2189            let output = alma_with_kernel(&input, kernel)?;
2190
2191            for (i, &val) in output.values.iter().enumerate() {
2192                if val.is_nan() {
2193                    continue;
2194                }
2195
2196                let bits = val.to_bits();
2197
2198                if bits == 0x11111111_11111111 {
2199                    panic!(
2200                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
2201                        with params: period={}, offset={}, sigma={}",
2202                        test_name,
2203                        val,
2204                        bits,
2205                        i,
2206                        params.period.unwrap_or(9),
2207                        params.offset.unwrap_or(0.85),
2208                        params.sigma.unwrap_or(6.0)
2209                    );
2210                }
2211
2212                if bits == 0x22222222_22222222 {
2213                    panic!(
2214                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
2215                        with params: period={}, offset={}, sigma={}",
2216                        test_name,
2217                        val,
2218                        bits,
2219                        i,
2220                        params.period.unwrap_or(9),
2221                        params.offset.unwrap_or(0.85),
2222                        params.sigma.unwrap_or(6.0)
2223                    );
2224                }
2225
2226                if bits == 0x33333333_33333333 {
2227                    panic!(
2228                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
2229                        with params: period={}, offset={}, sigma={}",
2230                        test_name,
2231                        val,
2232                        bits,
2233                        i,
2234                        params.period.unwrap_or(9),
2235                        params.offset.unwrap_or(0.85),
2236                        params.sigma.unwrap_or(6.0)
2237                    );
2238                }
2239            }
2240        }
2241
2242        Ok(())
2243    }
2244
2245    #[cfg(not(debug_assertions))]
2246    fn check_alma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2247        Ok(())
2248    }
2249    #[cfg(feature = "proptest")]
2250    #[allow(clippy::float_cmp)]
2251    fn check_alma_property(
2252        test_name: &str,
2253        kernel: Kernel,
2254    ) -> Result<(), Box<dyn std::error::Error>> {
2255        use proptest::prelude::*;
2256        skip_if_unsupported!(kernel, test_name);
2257
2258        let strat = (1usize..=64).prop_flat_map(|period| {
2259            (
2260                prop::collection::vec(
2261                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
2262                    period..400,
2263                ),
2264                Just(period),
2265                0f64..1f64,
2266                0.1f64..10.0f64,
2267            )
2268        });
2269
2270        proptest::test_runner::TestRunner::default()
2271            .run(&strat, |(data, period, offset, sigma)| {
2272                let params = AlmaParams {
2273                    period: Some(period),
2274                    offset: Some(offset),
2275                    sigma: Some(sigma),
2276                };
2277                let input = AlmaInput::from_slice(&data, params);
2278
2279                let AlmaOutput { values: out } = alma_with_kernel(&input, kernel).unwrap();
2280                let AlmaOutput { values: ref_out } =
2281                    alma_with_kernel(&input, Kernel::Scalar).unwrap();
2282
2283                for i in (period - 1)..data.len() {
2284                    let window = &data[i + 1 - period..=i];
2285                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
2286                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2287                    let y = out[i];
2288                    let r = ref_out[i];
2289
2290                    prop_assert!(
2291                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
2292                        "idx {i}: {y} ∉ [{lo}, {hi}]"
2293                    );
2294
2295                    if period == 1 {
2296                        prop_assert!((y - data[i]).abs() <= f64::EPSILON);
2297                    }
2298
2299                    if data.windows(2).all(|w| w[0] == w[1]) {
2300                        prop_assert!((y - data[0]).abs() <= 1e-9);
2301                    }
2302
2303                    let y_bits = y.to_bits();
2304                    let r_bits = r.to_bits();
2305
2306                    if !y.is_finite() || !r.is_finite() {
2307                        prop_assert!(
2308                            y.to_bits() == r.to_bits(),
2309                            "finite/NaN mismatch idx {i}: {y} vs {r}"
2310                        );
2311                        continue;
2312                    }
2313
2314                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
2315
2316                    prop_assert!(
2317                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
2318                        "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
2319                    );
2320                }
2321                Ok(())
2322            })
2323            .unwrap();
2324
2325        Ok(())
2326    }
2327
2328    macro_rules! generate_all_alma_tests {
2329        ($($test_fn:ident),*) => {
2330            paste::paste! {
2331                $(
2332                    #[test]
2333                    fn [<$test_fn _scalar_f64>]() {
2334                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2335                    }
2336                )*
2337                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2338                $(
2339                    #[test]
2340                    fn [<$test_fn _avx2_f64>]() {
2341                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2342                    }
2343                    #[test]
2344                    fn [<$test_fn _avx512_f64>]() {
2345                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2346                    }
2347                )*
2348                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2349                $(
2350                    #[test]
2351                    fn [<$test_fn _simd128_f64>]() {
2352                        let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2353                    }
2354                )*
2355            }
2356        }
2357    }
2358
    // Instantiate per-kernel #[test] wrappers for every single-series check above.
    generate_all_alma_tests!(
        check_alma_partial_params,
        check_alma_accuracy,
        check_alma_default_candles,
        check_alma_zero_period,
        check_alma_period_exceeds_length,
        check_alma_very_small_dataset,
        check_alma_empty_input,
        check_alma_invalid_sigma,
        check_alma_invalid_offset,
        check_alma_reinput,
        check_alma_nan_handling,
        check_alma_streaming,
        check_alma_no_poison
    );

    // The property test only exists when the `proptest` feature is enabled.
    #[cfg(feature = "proptest")]
    generate_all_alma_tests!(check_alma_property);
2377
2378    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2379    #[test]
2380    fn test_alma_into_matches_api() -> Result<(), Box<dyn Error>> {
2381        let mut data = vec![f64::NAN; 3];
2382        data.extend((0..256).map(|i| (i as f64).sin() * 100.0 + (i as f64) * 0.1));
2383
2384        let input = AlmaInput::from_slice(&data, AlmaParams::default());
2385
2386        let baseline = alma_with_kernel(&input, Kernel::Auto)?.values;
2387
2388        let mut out = vec![0.0; data.len()];
2389        alma_into(&input, &mut out)?;
2390
2391        assert_eq!(baseline.len(), out.len());
2392
2393        fn eq_or_both_nan(a: f64, b: f64) -> bool {
2394            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
2395        }
2396
2397        for i in 0..out.len() {
2398            assert!(
2399                eq_or_both_nan(baseline[i], out[i]),
2400                "mismatch at {}: baseline={} out={}",
2401                i,
2402                baseline[i],
2403                out[i]
2404            );
2405        }
2406
2407        Ok(())
2408    }
2409
2410    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2411        skip_if_unsupported!(kernel, test);
2412
2413        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2414        let c = read_candles_from_csv(file)?;
2415
2416        let output = AlmaBatchBuilder::new()
2417            .kernel(kernel)
2418            .apply_candles(&c, "close")?;
2419
2420        let def = AlmaParams::default();
2421        let row = output.values_for(&def).expect("default row missing");
2422
2423        assert_eq!(row.len(), c.close.len());
2424
2425        let expected = [
2426            59286.72216704,
2427            59273.53428138,
2428            59204.37290721,
2429            59155.93381742,
2430            59026.92526112,
2431        ];
2432        let start = row.len() - 5;
2433        for (i, &v) in row[start..].iter().enumerate() {
2434            assert!(
2435                (v - expected[i]).abs() < 1e-8,
2436                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2437            );
2438        }
2439        Ok(())
2440    }
2441
2442    macro_rules! gen_batch_tests {
2443        ($fn_name:ident) => {
2444            paste::paste! {
2445                #[test] fn [<$fn_name _scalar>]()      {
2446                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2447                }
2448                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2449                #[test] fn [<$fn_name _avx2>]()        {
2450                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2451                }
2452                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2453                #[test] fn [<$fn_name _avx512>]()      {
2454                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2455                }
2456                #[test] fn [<$fn_name _auto_detect>]() {
2457                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
2458                                     Kernel::Auto);
2459                }
2460            }
2461        };
2462    }
2463
2464    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2465        skip_if_unsupported!(kernel, test);
2466
2467        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2468        let c = read_candles_from_csv(file)?;
2469
2470        let output = AlmaBatchBuilder::new()
2471            .kernel(kernel)
2472            .period_range(9, 20, 1)
2473            .offset_range(0.5, 1.0, 0.1)
2474            .sigma_range(3.0, 9.0, 1.0)
2475            .apply_candles(&c, "close")?;
2476
2477        let expected_combos = 12 * 6 * 7;
2478        assert_eq!(output.combos.len(), expected_combos);
2479        assert_eq!(output.rows, expected_combos);
2480        assert_eq!(output.cols, c.close.len());
2481
2482        Ok(())
2483    }
2484
2485    #[cfg(debug_assertions)]
2486    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2487        skip_if_unsupported!(kernel, test);
2488
2489        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2490        let c = read_candles_from_csv(file)?;
2491
2492        let test_configs = vec![
2493            (2, 10, 2, 0.0, 1.0, 0.2, 1.0, 10.0, 3.0),
2494            (5, 25, 5, 0.85, 0.85, 0.0, 6.0, 6.0, 0.0),
2495            (10, 10, 0, 0.0, 1.0, 0.1, 5.0, 5.0, 0.0),
2496            (2, 5, 1, 0.5, 0.5, 0.0, 3.0, 8.0, 1.0),
2497            (30, 60, 15, 0.85, 0.85, 0.0, 6.0, 6.0, 0.0),
2498            (9, 15, 3, 0.8, 0.9, 0.1, 6.0, 8.0, 2.0),
2499            (8, 12, 1, 0.7, 0.9, 0.05, 4.0, 8.0, 0.5),
2500        ];
2501
2502        for (cfg_idx, &(p_start, p_end, p_step, o_start, o_end, o_step, s_start, s_end, s_step)) in
2503            test_configs.iter().enumerate()
2504        {
2505            let output = AlmaBatchBuilder::new()
2506                .kernel(kernel)
2507                .period_range(p_start, p_end, p_step)
2508                .offset_range(o_start, o_end, o_step)
2509                .sigma_range(s_start, s_end, s_step)
2510                .apply_candles(&c, "close")?;
2511
2512            for (idx, &val) in output.values.iter().enumerate() {
2513                if val.is_nan() {
2514                    continue;
2515                }
2516
2517                let bits = val.to_bits();
2518                let row = idx / output.cols;
2519                let col = idx % output.cols;
2520                let combo = &output.combos[row];
2521
2522                if bits == 0x11111111_11111111 {
2523                    panic!(
2524						"[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2525                        at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2526						test,
2527						cfg_idx,
2528						val,
2529						bits,
2530						row,
2531						col,
2532						idx,
2533						combo.period.unwrap_or(9),
2534						combo.offset.unwrap_or(0.85),
2535						combo.sigma.unwrap_or(6.0)
2536					);
2537                }
2538
2539                if bits == 0x22222222_22222222 {
2540                    panic!(
2541						"[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2542                        at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2543						test,
2544						cfg_idx,
2545						val,
2546						bits,
2547						row,
2548						col,
2549						idx,
2550						combo.period.unwrap_or(9),
2551						combo.offset.unwrap_or(0.85),
2552						combo.sigma.unwrap_or(6.0)
2553					);
2554                }
2555
2556                if bits == 0x33333333_33333333 {
2557                    panic!(
2558						"[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2559                        at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2560						test,
2561						cfg_idx,
2562						val,
2563						bits,
2564						row,
2565						col,
2566						idx,
2567						combo.period.unwrap_or(9),
2568						combo.offset.unwrap_or(0.85),
2569						combo.sigma.unwrap_or(6.0)
2570					);
2571                }
2572            }
2573        }
2574
2575        Ok(())
2576    }
2577
2578    #[cfg(not(debug_assertions))]
2579    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2580        Ok(())
2581    }
2582
    // Instantiate scalar / AVX / auto-detect #[test] wrappers for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
2586
2587    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2588    #[test]
2589    fn test_alma_simd128_correctness() {
2590        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2591        let period = 5;
2592        let offset = 0.85;
2593        let sigma = 6.0;
2594
2595        let params = AlmaParams {
2596            period: Some(period),
2597            offset: Some(offset),
2598            sigma: Some(sigma),
2599        };
2600        let input = AlmaInput::from_slice(&data, params);
2601        let scalar_output = alma_with_kernel(&input, Kernel::Scalar).unwrap();
2602
2603        let simd128_output = alma_with_kernel(&input, Kernel::Scalar).unwrap();
2604
2605        assert_eq!(scalar_output.values.len(), simd128_output.values.len());
2606        for (i, (scalar_val, simd_val)) in scalar_output
2607            .values
2608            .iter()
2609            .zip(simd128_output.values.iter())
2610            .enumerate()
2611        {
2612            assert!(
2613                (scalar_val - simd_val).abs() < 1e-10,
2614                "SIMD128 mismatch at index {}: scalar={}, simd128={}",
2615                i,
2616                scalar_val,
2617                simd_val
2618            );
2619        }
2620    }
2621}
2622
#[cfg(feature = "python")]
#[pyfunction(name = "alma")]
#[pyo3(signature = (data, period, offset, sigma, kernel=None))]
// Python entry point: ALMA over a 1-D float64 array, returning a new 1-D array of the
// same length. Strided (non-contiguous) inputs are first copied into a contiguous
// buffer. The computation runs with the GIL released. Any `AlmaError` is raised as
// ValueError.
// (Plain `//` comments on purpose: `///` would change the Python-visible `__doc__`.)
pub fn alma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    offset: f64,
    sigma: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;
    let params = AlmaParams {
        period: Some(period),
        offset: Some(offset),
        sigma: Some(sigma),
    };

    // Borrow the NumPy buffer directly when possible; otherwise keep an owned copy alive
    // for the duration of the computation.
    let contiguous_copy;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(view) => view,
        Err(_) => {
            contiguous_copy = data.as_array().to_owned();
            contiguous_copy
                .as_slice()
                .expect("owned array should be contiguous")
        }
    };

    let alma_in = AlmaInput::from_slice(slice_in, params);
    let values = py
        .allow_threads(|| alma_with_kernel(&alma_in, kern).map(|o| o.values))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok(values.into_pyarray(py))
}
2660
#[cfg(feature = "python")]
#[pyclass(name = "AlmaStream")]
// Python-visible wrapper (exposed as `AlmaStream`) around the Rust `AlmaStream`
// incremental calculator. Plain `//` comments keep the Python `__doc__` unchanged.
pub struct AlmaStreamPy {
    // The wrapped Rust stream; all state lives here.
    stream: AlmaStream,
}
2666
#[cfg(feature = "python")]
#[pymethods]
impl AlmaStreamPy {
    // Python constructor `AlmaStream(period, offset, sigma)`. Parameter validation
    // errors from `AlmaStream::try_new` surface as ValueError.
    #[new]
    fn new(period: usize, offset: f64, sigma: f64) -> PyResult<Self> {
        AlmaStream::try_new(AlmaParams {
            period: Some(period),
            offset: Some(offset),
            sigma: Some(sigma),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Push one value and return the current ALMA, or `None` (Python `None`) while the
    // underlying stream has not produced a value yet.
    fn update(&mut self, value: f64) -> Option<f64> {
        let current = self.stream.update(value);
        current
    }
}
2686
#[cfg(feature = "python")]
#[pyfunction(name = "alma_batch")]
#[pyo3(signature = (data, period_range, offset_range, sigma_range, kernel=None))]
// Python entry point: evaluate ALMA for every (period, offset, sigma) combination
// expanded from the (start, end, step) ranges. Returns a dict with these keys:
//   "values"  -> 2-D float64 array of shape (n_combos, len(data)), one row per combination
//   "periods" / "offsets" / "sigmas" -> 1-D arrays with each row's parameters.
// Range expansion, size overflow and computation errors are raised as ValueError. An
// unrecognised kernel name raises whatever `validate_kernel` returns.
// (Plain `//` comments on purpose: `///` would change the Python-visible `__doc__`.)
pub fn alma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    offset_range: (f64, f64, f64),
    sigma_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    // Resolve the kernel first, so a bad choice is rejected before the (possibly very
    // large) output array is allocated.
    let kern = validate_kernel(kernel, true)?;
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    // The batch driver is given the per-row kernel, not the `*Batch` selector. An
    // unexpected variant used to hit `unreachable!()` and panic inside `allow_threads`;
    // report it as a Python error instead.
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => {
            return Err(PyValueError::new_err(
                "alma_batch: kernel is not a supported batch kernel",
            ))
        }
    };

    // Accept strided (non-contiguous) arrays like `alma` does, by copying them once.
    let contiguous_copy;
    let slice_in: &[f64] = match data.as_slice() {
        Ok(view) => view,
        Err(_) => {
            contiguous_copy = data.as_array().to_owned();
            contiguous_copy
                .as_slice()
                .expect("owned array should be contiguous")
        }
    };

    let sweep = AlmaBatchRange {
        period: period_range,
        offset: offset_range,
        sigma: sigma_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is created uninitialised and is only handed to Python after
    // `alma_batch_inner_into` returns Ok. NOTE(review): this relies on that function
    // writing every one of the `total` cells — confirm. No other reference to the fresh
    // array exists, so taking a mutable slice of it is sound.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    // Compute with the GIL released. The `true` argument is passed through unchanged.
    let combos = py
        .allow_threads(|| alma_batch_inner_into(slice_in, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "offsets",
        combos
            .iter()
            .map(|p| p.offset.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "sigmas",
        combos
            .iter()
            .map(|p| p.sigma.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2767
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, offset_range, sigma_range, device_id=0))]
// Python entry point: run the ALMA parameter sweep on a CUDA device and return the
// result as a device-resident `DeviceArrayF32Py`. The result keeps the CUDA context
// alive and records the device id. Input must be a contiguous float32 array. CUDA
// unavailability and CUDA/kernel failures are raised as ValueError.
pub fn alma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    offset_range: (f64, f64, f64),
    sigma_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaAlma;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_series: &[f32] = data_f32.as_slice()?;
    let sweep = AlmaBatchRange {
        period: period_range,
        offset: offset_range,
        sigma: sigma_range,
    };
    let dev_id = device_id as u32;

    // Device setup and the kernel launch run with the GIL released.
    let (inner, ctx) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaAlma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let inner = cuda
            .alma_batch_dev(host_series, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((inner, ctx))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2808
2809#[cfg(all(feature = "python", feature = "cuda"))]
2810#[pyfunction(name = "alma_cuda_many_series_one_param_dev")]
2811#[pyo3(signature = (data_tm_f32, period, offset, sigma, device_id=0))]
2812pub fn alma_cuda_many_series_one_param_dev_py(
2813    py: Python<'_>,
2814    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
2815    period: usize,
2816    offset: f64,
2817    sigma: f64,
2818    device_id: usize,
2819) -> PyResult<DeviceArrayF32Py> {
2820    use crate::cuda::cuda_available;
2821    use crate::cuda::moving_averages::CudaAlma;
2822    use numpy::PyUntypedArrayMethods;
2823
2824    if !cuda_available() {
2825        return Err(PyValueError::new_err("CUDA not available"));
2826    }
2827
2828    let flat_in: &[f32] = data_tm_f32.as_slice()?;
2829    let rows = data_tm_f32.shape()[0];
2830    let cols = data_tm_f32.shape()[1];
2831    let params = AlmaParams {
2832        period: Some(period),
2833        offset: Some(offset),
2834        sigma: Some(sigma),
2835    };
2836
2837    let (inner, ctx, dev_id) = py.allow_threads(|| {
2838        let cuda = CudaAlma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
2839        let ctx = cuda.context_arc();
2840        let dev_id = device_id as u32;
2841        cuda.alma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
2842            .map(|inner| (inner, ctx, dev_id))
2843            .map_err(|e| PyValueError::new_err(e.to_string()))
2844    })?;
2845
2846    Ok(DeviceArrayF32Py {
2847        inner,
2848        _ctx: Some(ctx),
2849        device_id: Some(dev_id),
2850    })
2851}
2852
2853#[cfg(feature = "python")]
2854pub fn register_alma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
2855    m.add_function(wrap_pyfunction!(alma_py, m)?)?;
2856    m.add_function(wrap_pyfunction!(alma_batch_py, m)?)?;
2857    m.add_class::<AlmaStreamPy>()?;
2858
2859    #[cfg(feature = "cuda")]
2860    {
2861        m.add_class::<DeviceArrayF32Py>()?;
2862        m.add_function(wrap_pyfunction!(alma_cuda_batch_dev_py, m)?)?;
2863        m.add_function(wrap_pyfunction!(alma_cuda_many_series_one_param_dev_py, m)?)?;
2864    }
2865    Ok(())
2866}
2867
2868#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2869#[wasm_bindgen]
2870pub fn alma_js(data: &[f64], period: usize, offset: f64, sigma: f64) -> Result<Vec<f64>, JsValue> {
2871    let params = AlmaParams {
2872        period: Some(period),
2873        offset: Some(offset),
2874        sigma: Some(sigma),
2875    };
2876    let input = AlmaInput::from_slice(data, params);
2877
2878    let mut output = vec![0.0; data.len()];
2879
2880    alma_into_slice(&mut output, &input, detect_best_kernel())
2881        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2882
2883    Ok(output)
2884}
2885
2886#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2887#[derive(Serialize, Deserialize)]
2888pub struct AlmaBatchConfig {
2889    pub period_range: (usize, usize, usize),
2890    pub offset_range: (f64, f64, f64),
2891    pub sigma_range: (f64, f64, f64),
2892}
2893
2894#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2895#[derive(Serialize, Deserialize)]
2896pub struct AlmaBatchJsOutput {
2897    pub values: Vec<f64>,
2898    pub combos: Vec<AlmaParams>,
2899    pub rows: usize,
2900    pub cols: usize,
2901}
2902
2903#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2904#[wasm_bindgen(js_name = alma_batch)]
2905pub fn alma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
2906    let config: AlmaBatchConfig = serde_wasm_bindgen::from_value(config)
2907        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
2908
2909    let sweep = AlmaBatchRange {
2910        period: config.period_range,
2911        offset: config.offset_range,
2912        sigma: config.sigma_range,
2913    };
2914
2915    let output = alma_batch_inner(data, &sweep, detect_best_kernel(), false)
2916        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2917
2918    let js_output = AlmaBatchJsOutput {
2919        values: output.values,
2920        combos: output.combos,
2921        rows: output.rows,
2922        cols: output.cols,
2923    };
2924
2925    serde_wasm_bindgen::to_value(&js_output)
2926        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
2927}
2928
2929#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2930#[wasm_bindgen]
2931pub fn alma_alloc(len: usize) -> *mut f64 {
2932    let mut vec = Vec::<f64>::with_capacity(len);
2933    let ptr = vec.as_mut_ptr();
2934    std::mem::forget(vec);
2935    ptr
2936}
2937
2938#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2939#[wasm_bindgen]
2940pub fn alma_free(ptr: *mut f64, len: usize) {
2941    unsafe {
2942        let _ = Vec::from_raw_parts(ptr, len, len);
2943    }
2944}
2945
2946#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2947#[wasm_bindgen]
2948pub fn alma_into(
2949    in_ptr: *const f64,
2950    out_ptr: *mut f64,
2951    len: usize,
2952    period: usize,
2953    offset: f64,
2954    sigma: f64,
2955) -> Result<(), JsValue> {
2956    if in_ptr.is_null() || out_ptr.is_null() {
2957        return Err(JsValue::from_str("null pointer passed to alma_into"));
2958    }
2959
2960    unsafe {
2961        let data = std::slice::from_raw_parts(in_ptr, len);
2962
2963        if period == 0 || period > len {
2964            return Err(JsValue::from_str("Invalid period"));
2965        }
2966
2967        let params = AlmaParams {
2968            period: Some(period),
2969            offset: Some(offset),
2970            sigma: Some(sigma),
2971        };
2972        let input = AlmaInput::from_slice(data, params);
2973
2974        if in_ptr == out_ptr {
2975            let mut temp = vec![0.0; len];
2976            alma_into_slice(&mut temp, &input, detect_best_kernel())
2977                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2978            let out = std::slice::from_raw_parts_mut(out_ptr, len);
2979            out.copy_from_slice(&temp);
2980        } else {
2981            let out = std::slice::from_raw_parts_mut(out_ptr, len);
2982            alma_into_slice(out, &input, detect_best_kernel())
2983                .map_err(|e| JsValue::from_str(&e.to_string()))?;
2984        }
2985
2986        Ok(())
2987    }
2988}
2989
2990#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2991#[wasm_bindgen]
2992#[deprecated(
2993    since = "1.0.0",
2994    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
2995)]
2996pub struct AlmaContext {
2997    weights: AVec<f64>,
2998    inv_norm: f64,
2999    period: usize,
3000    first: usize,
3001    kernel: Kernel,
3002}
3003
3004#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3005#[wasm_bindgen]
3006#[allow(deprecated)]
3007impl AlmaContext {
3008    #[wasm_bindgen(constructor)]
3009    #[deprecated(
3010        since = "1.0.0",
3011        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
3012    )]
3013    pub fn new(period: usize, offset: f64, sigma: f64) -> Result<AlmaContext, JsValue> {
3014        if period == 0 {
3015            return Err(JsValue::from_str("Invalid period: 0"));
3016        }
3017        if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
3018            return Err(JsValue::from_str(&format!("Invalid offset: {}", offset)));
3019        }
3020        if sigma <= 0.0 {
3021            return Err(JsValue::from_str(&format!("Invalid sigma: {}", sigma)));
3022        }
3023
3024        let m = offset * (period - 1) as f64;
3025        let s = period as f64 / sigma;
3026        let s2 = 2.0 * s * s;
3027
3028        let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, period);
3029        weights.resize(period, 0.0);
3030        let mut norm = 0.0;
3031
3032        for i in 0..period {
3033            let w = (-(i as f64 - m).powi(2) / s2).exp();
3034            weights[i] = w;
3035            norm += w;
3036        }
3037
3038        let inv_norm = 1.0 / norm;
3039
3040        Ok(AlmaContext {
3041            weights,
3042            inv_norm,
3043            period,
3044            first: 0,
3045            kernel: detect_best_kernel(),
3046        })
3047    }
3048
3049    pub fn update_into(
3050        &self,
3051        in_ptr: *const f64,
3052        out_ptr: *mut f64,
3053        len: usize,
3054    ) -> Result<(), JsValue> {
3055        if len < self.period {
3056            return Err(JsValue::from_str("Data length less than period"));
3057        }
3058
3059        unsafe {
3060            let data = std::slice::from_raw_parts(in_ptr, len);
3061            let out = std::slice::from_raw_parts_mut(out_ptr, len);
3062
3063            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
3064
3065            if in_ptr == out_ptr {
3066                let mut temp = vec![0.0; len];
3067                alma_compute_into(
3068                    data,
3069                    self.weights.as_slice(),
3070                    self.period,
3071                    first,
3072                    self.inv_norm,
3073                    self.kernel,
3074                    &mut temp,
3075                );
3076
3077                out.copy_from_slice(&temp);
3078            } else {
3079                alma_compute_into(
3080                    data,
3081                    self.weights.as_slice(),
3082                    self.period,
3083                    first,
3084                    self.inv_norm,
3085                    self.kernel,
3086                    out,
3087                );
3088            }
3089
3090            for i in 0..(first + self.period - 1) {
3091                out[i] = f64::NAN;
3092            }
3093        }
3094
3095        Ok(())
3096    }
3097
3098    pub fn get_warmup_period(&self) -> usize {
3099        self.period - 1
3100    }
3101}
3102
3103#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
3104#[wasm_bindgen]
3105pub fn alma_batch_into(
3106    in_ptr: *const f64,
3107    out_ptr: *mut f64,
3108    len: usize,
3109    period_start: usize,
3110    period_end: usize,
3111    period_step: usize,
3112    offset_start: f64,
3113    offset_end: f64,
3114    offset_step: f64,
3115    sigma_start: f64,
3116    sigma_end: f64,
3117    sigma_step: f64,
3118) -> Result<usize, JsValue> {
3119    if in_ptr.is_null() || out_ptr.is_null() {
3120        return Err(JsValue::from_str("null pointer passed to alma_batch_into"));
3121    }
3122
3123    unsafe {
3124        let data = std::slice::from_raw_parts(in_ptr, len);
3125
3126        let sweep = AlmaBatchRange {
3127            period: (period_start, period_end, period_step),
3128            offset: (offset_start, offset_end, offset_step),
3129            sigma: (sigma_start, sigma_end, sigma_step),
3130        };
3131
3132        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
3133        let rows = combos.len();
3134        let cols = len;
3135        let total = rows
3136            .checked_mul(cols)
3137            .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
3138
3139        let out = std::slice::from_raw_parts_mut(out_ptr, total);
3140
3141        alma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
3142            .map_err(|e| JsValue::from_str(&e.to_string()))?;
3143
3144        Ok(rows)
3145    }
3146}