1#[cfg(all(feature = "python", feature = "cuda"))]
2pub use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
3
4#[cfg(feature = "python")]
5use numpy::{IntoPyArray, PyArray1};
6#[cfg(feature = "python")]
7use pyo3::exceptions::PyValueError;
8#[cfg(feature = "python")]
9use pyo3::prelude::*;
10#[cfg(feature = "python")]
11use pyo3::types::{PyDict, PyList};
12
13#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
14use serde::{Deserialize, Serialize};
15#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
16use wasm_bindgen::prelude::*;
17
18use crate::utilities::data_loader::{source_type, Candles};
19use crate::utilities::enums::Kernel;
20use crate::utilities::helpers::{
21 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
22 make_uninit_matrix,
23};
24#[cfg(feature = "python")]
25use crate::utilities::kernel_validation::validate_kernel;
26use aligned_vec::{AVec, CACHELINE_ALIGN};
27#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
28use core::arch::x86_64::*;
29#[cfg(not(target_arch = "wasm32"))]
30use rayon::prelude::*;
31use std::alloc::{alloc, dealloc, Layout};
32use std::convert::AsRef;
33use std::error::Error;
34use std::mem::MaybeUninit;
35use thiserror::Error;
36
37impl<'a> AsRef<[f64]> for AlmaInput<'a> {
38 #[inline(always)]
39 fn as_ref(&self) -> &[f64] {
40 match &self.data {
41 AlmaData::Slice(slice) => slice,
42 AlmaData::Candles { candles, source } => source_type(candles, source),
43 }
44 }
45}
46
/// Borrowed input source for an ALMA computation: either a candle set plus a
/// source-field name (resolved via `source_type`), or a raw `f64` slice.
#[derive(Debug, Clone)]
pub enum AlmaData<'a> {
    /// Candle data; `source` selects the price series (e.g. "close").
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A pre-extracted price slice.
    Slice(&'a [f64]),
}
55
/// Result of a single ALMA run. `values` is aligned 1:1 with the input; the
/// warm-up prefix (`first_valid + period - 1` entries) is NaN-filled.
#[derive(Debug, Clone)]
pub struct AlmaOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
60
/// ALMA parameters. `None` fields fall back to the defaults
/// (period 9, offset 0.85, sigma 6.0) at the point of use.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AlmaParams {
    /// Window length.
    pub period: Option<usize>,
    /// Gaussian centre offset, expected in `[0, 1]`.
    pub offset: Option<f64>,
    /// Gaussian width divisor, expected `> 0`.
    pub sigma: Option<f64>,
}
71
72impl Default for AlmaParams {
73 fn default() -> Self {
74 Self {
75 period: Some(9),
76 offset: Some(0.85),
77 sigma: Some(6.0),
78 }
79 }
80}
81
/// One ALMA computation request: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct AlmaInput<'a> {
    /// Where the prices come from (candles or a raw slice).
    pub data: AlmaData<'a>,
    /// Parameter overrides; `None` fields use the defaults.
    pub params: AlmaParams,
}
87
88impl<'a> AlmaInput<'a> {
89 #[inline]
90 pub fn from_candles(c: &'a Candles, s: &'a str, p: AlmaParams) -> Self {
91 Self {
92 data: AlmaData::Candles {
93 candles: c,
94 source: s,
95 },
96 params: p,
97 }
98 }
99 #[inline]
100 pub fn from_slice(sl: &'a [f64], p: AlmaParams) -> Self {
101 Self {
102 data: AlmaData::Slice(sl),
103 params: p,
104 }
105 }
106 #[inline]
107 pub fn with_default_candles(c: &'a Candles) -> Self {
108 Self::from_candles(c, "close", AlmaParams::default())
109 }
110 #[inline]
111 pub fn get_period(&self) -> usize {
112 self.params.period.unwrap_or(9)
113 }
114 #[inline]
115 pub fn get_offset(&self) -> f64 {
116 self.params.offset.unwrap_or(0.85)
117 }
118 #[inline]
119 pub fn get_sigma(&self) -> f64 {
120 self.params.sigma.unwrap_or(6.0)
121 }
122}
123
/// Fluent builder for a single ALMA run. Unset fields fall back to the
/// `AlmaParams` defaults; `kernel` starts as `Kernel::Auto`.
#[derive(Copy, Clone, Debug)]
pub struct AlmaBuilder {
    period: Option<usize>,
    offset: Option<f64>,
    sigma: Option<f64>,
    kernel: Kernel,
}
131
132impl Default for AlmaBuilder {
133 fn default() -> Self {
134 Self {
135 period: None,
136 offset: None,
137 sigma: None,
138 kernel: Kernel::Auto,
139 }
140 }
141}
142
143impl AlmaBuilder {
144 #[inline(always)]
145 pub fn new() -> Self {
146 Self::default()
147 }
148 #[inline(always)]
149 pub fn period(mut self, n: usize) -> Self {
150 self.period = Some(n);
151 self
152 }
153 #[inline(always)]
154 pub fn offset(mut self, x: f64) -> Self {
155 self.offset = Some(x);
156 self
157 }
158 #[inline(always)]
159 pub fn sigma(mut self, s: f64) -> Self {
160 self.sigma = Some(s);
161 self
162 }
163 #[inline(always)]
164 pub fn kernel(mut self, k: Kernel) -> Self {
165 self.kernel = k;
166 self
167 }
168
169 #[inline(always)]
170 pub fn apply(self, c: &Candles) -> Result<AlmaOutput, AlmaError> {
171 let p = AlmaParams {
172 period: self.period,
173 offset: self.offset,
174 sigma: self.sigma,
175 };
176 let i = AlmaInput::from_candles(c, "close", p);
177 alma_with_kernel(&i, self.kernel)
178 }
179
180 #[inline(always)]
181 pub fn apply_slice(self, d: &[f64]) -> Result<AlmaOutput, AlmaError> {
182 let p = AlmaParams {
183 period: self.period,
184 offset: self.offset,
185 sigma: self.sigma,
186 };
187 let i = AlmaInput::from_slice(d, p);
188 alma_with_kernel(&i, self.kernel)
189 }
190
191 #[inline(always)]
192 pub fn into_stream(self) -> Result<AlmaStream, AlmaError> {
193 let p = AlmaParams {
194 period: self.period,
195 offset: self.offset,
196 sigma: self.sigma,
197 };
198 AlmaStream::try_new(p)
199 }
200}
201
/// Errors returned by the ALMA single-series, streaming, and batch APIs.
/// The `#[error]` strings carry the user-facing detail; variants map to the
/// validation performed in `alma_prepare`, `AlmaStream::try_new`, and the
/// batch range expansion.
#[derive(Debug, Error)]
pub enum AlmaError {
    #[error("alma: Input data slice is empty.")]
    EmptyInputData,
    #[error("alma: All values are NaN.")]
    AllValuesNaN,

    #[error("alma: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    #[error("alma: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    #[error("alma: Invalid sigma: {sigma}")]
    InvalidSigma { sigma: f64 },

    #[error("alma: Invalid offset: {offset}")]
    InvalidOffset { offset: f64 },

    #[error("alma: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    // Endpoints are kept as strings so one variant serves both the usize
    // (period) and f64 (offset/sigma) sweep axes.
    #[error("alma: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },

    #[error("alma: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(crate::utilities::enums::Kernel),
}
234
/// Compute ALMA over `input` with automatic kernel selection.
///
/// Thin wrapper over [`alma_with_kernel`] passing `Kernel::Auto`.
#[inline]
pub fn alma(input: &AlmaInput) -> Result<AlmaOutput, AlmaError> {
    alma_with_kernel(input, Kernel::Auto)
}
239
/// Dispatch one ALMA pass over `data` into `out` using an already-resolved
/// kernel. `weights` and `inv_n` come from `alma_prepare`; `first` is the
/// index of the first non-NaN input. Only `out[i]` for
/// `i >= first + period - 1` is written; earlier entries are untouched.
#[inline(always)]
fn alma_compute_into(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first: usize,
    inv_n: f64,
    kernel: Kernel,
    out: &mut [f64],
) {
    // SAFETY: the AVX arms call `#[target_feature]` functions; the kernel
    // value is expected to come from `detect_best_kernel`, which presumably
    // only reports CPU-supported features — TODO confirm that contract.
    unsafe {
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            // On wasm builds compiled with simd128, scalar kernel requests
            // are rerouted to the SIMD128 implementation.
            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
                alma_simd128(data, weights, period, first, inv_n, out);
                return;
            }
        }

        match kernel {
            Kernel::Scalar | Kernel::ScalarBatch => {
                alma_scalar(data, weights, period, first, inv_n, out)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => alma_avx2(data, weights, period, first, inv_n, out),
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_avx512(data, weights, period, first, inv_n, out)
            }
            // Without the nightly-avx build, AVX requests degrade to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_scalar(data, weights, period, first, inv_n, out)
            }
            // `Kernel::Auto` is resolved in `alma_prepare`; anything else
            // reaching this point is a programming error.
            _ => unreachable!(),
        }
    }
}
277
/// Shared validation and setup for every ALMA entry point.
///
/// Returns `(data, weights, period, first_valid, inv_norm, kernel)`:
/// * `weights` — Gaussian weights, zero-padded to a multiple of 8 and
///   cache-line aligned so AVX-512 kernels can use aligned loads;
/// * `inv_norm` — `1 / sum(weights)`;
/// * `kernel` — concrete kernel with `Auto` already resolved.
///
/// # Errors
/// `EmptyInputData`, `AllValuesNaN`, `InvalidPeriod`, `NotEnoughValidData`,
/// `InvalidSigma`, or `InvalidOffset` per the checks below.
#[inline(always)]
fn alma_prepare<'a>(
    input: &'a AlmaInput,
    kernel: Kernel,
) -> Result<(&'a [f64], AVec<f64>, usize, usize, f64, Kernel), AlmaError> {
    let data: &[f64] = input.as_ref();
    let len = data.len();
    if len == 0 {
        return Err(AlmaError::EmptyInputData);
    }
    // Index of the first non-NaN sample; everything before it is warm-up.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(AlmaError::AllValuesNaN)?;
    let period = input.get_period();
    let offset = input.get_offset();
    let sigma = input.get_sigma();

    if period == 0 || period > len {
        return Err(AlmaError::InvalidPeriod {
            period,
            data_len: len,
        });
    }
    if len - first < period {
        return Err(AlmaError::NotEnoughValidData {
            needed: period,
            valid: len - first,
        });
    }
    if sigma <= 0.0 {
        return Err(AlmaError::InvalidSigma { sigma });
    }
    // NaN/inf already fail the range test; the extra checks are explicit.
    if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
        return Err(AlmaError::InvalidOffset { offset });
    }

    // m: centre of the Gaussian window; s: its width (period / sigma).
    let m = offset * (period - 1) as f64;
    let s = period as f64 / sigma;
    let s2 = 2.0 * s * s;

    // Pad to the next multiple of 8 so 8-lane AVX-512 weight loads never
    // read past the end; the padding stays zero and adds nothing to sums.
    let aligned_period = ((period + 7) / 8) * 8;
    let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, aligned_period);
    weights.resize(aligned_period, 0.0);

    let inv_s2 = 1.0 / s2;
    let mut norm = 0.0;

    // w[i] = exp(-(i - m)^2 / (2 s^2)); norm accumulates for normalisation.
    for i in 0..period {
        let diff = i as f64 - m;
        let w = (-diff * diff * inv_s2).exp();
        weights[i] = w;
        norm += w;
    }
    let inv_norm = 1.0 / norm;

    let chosen = match kernel {
        Kernel::Auto => detect_best_kernel(),
        k => k,
    };

    Ok((data, weights, period, first, inv_norm, chosen))
}
341
342pub fn alma_with_kernel(input: &AlmaInput, kernel: Kernel) -> Result<AlmaOutput, AlmaError> {
343 let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, kernel)?;
344
345 let mut out = alloc_with_nan_prefix(data.len(), first + period - 1);
346
347 alma_compute_into(data, &weights, period, first, inv_n, chosen, &mut out);
348
349 Ok(AlmaOutput { values: out })
350}
351
352#[inline]
353pub fn alma_into_slice(dst: &mut [f64], input: &AlmaInput, kern: Kernel) -> Result<(), AlmaError> {
354 let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, kern)?;
355
356 if dst.len() != data.len() {
357 return Err(AlmaError::OutputLengthMismatch {
358 expected: data.len(),
359 got: dst.len(),
360 });
361 }
362
363 alma_compute_into(data, &weights, period, first, inv_n, chosen, dst);
364
365 let warmup_end = first + period - 1;
366 for v in &mut dst[..warmup_end] {
367 *v = f64::NAN;
368 }
369
370 Ok(())
371}
372
/// Compute ALMA into a caller-provided buffer using `Kernel::Auto`.
///
/// Unlike `alma_into_slice`, the warm-up prefix is written *before* the
/// kernel runs, using an explicit quiet-NaN bit pattern so the prefix is
/// bit-stable across platforms.
///
/// # Errors
/// Any `alma_prepare` error, plus `OutputLengthMismatch` if `out` does not
/// match the input length.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn alma_into(input: &AlmaInput, out: &mut [f64]) -> Result<(), AlmaError> {
    let (data, weights, period, first, inv_n, chosen) = alma_prepare(input, Kernel::Auto)?;

    if out.len() != data.len() {
        return Err(AlmaError::OutputLengthMismatch {
            expected: data.len(),
            got: out.len(),
        });
    }

    let warmup_end = first + period - 1;
    // 0x7ff8_0000_0000_0000 is the canonical quiet-NaN payload for f64.
    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
    // `min` is defensive; the length check above already guarantees fit.
    let warm = warmup_end.min(out.len());
    for v in &mut out[..warm] {
        *v = qnan;
    }

    alma_compute_into(data, &weights, period, first, inv_n, chosen, out);

    Ok(())
}
396
/// Horizontal sum of all eight lanes of an AVX-512 `f64` vector.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F before calling.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn hsum_pd_zmm(v: __m512d) -> f64 {
    #[allow(unused_unsafe)]
    {
        _mm512_reduce_add_pd(v)
    }
}
406
/// AVX-512 ALMA entry point: dispatches to the short- or long-window
/// implementation based on `period`.
///
/// Marked `#[target_feature(enable = "avx512f")]`; callers must only reach
/// this when AVX-512F is available (see `alma_compute_into`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f")]
pub fn alma_avx512(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    // <= 32 means at most four 8-lane vectors; the simple path handles that.
    if period <= 32 {
        unsafe { alma_avx512_short(data, weights, period, first_valid, inv_norm, out) }
    } else {
        unsafe { alma_avx512_long(data, weights, period, first_valid, inv_norm, out) }
    }
}
424
/// Portable scalar ALMA kernel: for each full window ending at index `i`,
/// writes `out[i] = dot(window, weights) * inv_norm`. Entries before
/// `first_val + period - 1` are left untouched.
///
/// # Panics
/// If `weights` is shorter than `period` or `out` is shorter than `data`.
#[inline(always)]
pub fn alma_scalar(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    // Process four products per step; the remainder runs one at a time.
    // The accumulation order matches the unrolled grouping exactly.
    let quad_len = period & !3;

    let mut idx = first_val + period - 1;
    while idx < data.len() {
        let lo = idx + 1 - period;
        let win = &data[lo..lo + period];

        let mut acc = 0.0;
        let mut j = 0;
        while j < quad_len {
            acc += win[j] * weights[j]
                + win[j + 1] * weights[j + 1]
                + win[j + 2] * weights[j + 2]
                + win[j + 3] * weights[j + 3];
            j += 4;
        }
        while j < period {
            acc += win[j] * weights[j];
            j += 1;
        }

        out[idx] = acc * inv_norm;
        idx += 1;
    }
}
464
/// Wasm SIMD128 ALMA kernel: two `f64` lanes per step, scalar handling for
/// an odd final element. Entries of `out` before `first_val + period - 1`
/// are left untouched.
///
/// # Safety
/// `v128_load` here performs unaligned loads from in-bounds positions; the
/// asserts below guarantee the index arithmetic stays within the slices.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
unsafe fn alma_simd128(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_val: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    use core::arch::wasm32::*;

    assert!(
        weights.len() >= period,
        "weights.len() must be at least `period`"
    );
    assert!(
        out.len() >= data.len(),
        "`out` must be at least as long as `data`"
    );

    const STEP: usize = 2;
    let chunks = period / STEP;
    let tail = period % STEP;

    for i in (first_val + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = f64x2_splat(0.0);

        // Fused multiply-accumulate over two lanes at a time.
        for blk in 0..chunks {
            let idx = blk * STEP;
            let w = v128_load(weights.as_ptr().add(idx) as *const v128);
            let d = v128_load(data.as_ptr().add(start + idx) as *const v128);
            acc = f64x2_add(acc, f64x2_mul(d, w));
        }

        // Horizontal sum of the two lanes.
        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);

        // At most one leftover element when `period` is odd.
        if tail != 0 {
            sum += data[start + chunks * STEP] * weights[chunks * STEP];
        }

        out[i] = sum * inv_norm;
    }
}
510
/// AVX2+FMA ALMA kernel for short windows: one 4-lane accumulator plus a
/// masked load for the 1–3 element tail.
///
/// # Safety
/// Requires AVX2/FMA at runtime. `out.len() >= data.len()` and
/// `weights.len() >= period` must hold so the unchecked store and the loads
/// stay in bounds (`maskload` does not touch masked-out lanes).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_avx2_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let chunks = period / STEP;
    let tail = period % STEP;

    // All-ones (-1) lanes select the first `tail` elements for maskload.
    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        for blk in 0..chunks {
            let idx = blk * STEP;
            let w = _mm256_loadu_pd(weights.as_ptr().add(idx));
            let d = _mm256_loadu_pd(data.as_ptr().add(start + idx));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        if tail != 0 {
            let w_tail = _mm256_maskload_pd(weights.as_ptr().add(chunks * STEP), tail_mask);
            let d_tail = _mm256_maskload_pd(data.as_ptr().add(start + chunks * STEP), tail_mask);
            acc = _mm256_fmadd_pd(d_tail, w_tail, acc);
        }

        // Horizontal reduction of the 4 accumulator lanes to one scalar.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(hi, lo);
        let sum1 = _mm_add_pd(sum2, _mm_unpackhi_pd(sum2, sum2));
        let sum = _mm_cvtsd_f64(sum1);

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
560
/// AVX2+FMA ALMA kernel for longer windows: two independent 4-lane
/// accumulators (to hide FMA latency), an optional odd chunk, then a masked
/// tail, reduced horizontally per output element.
///
/// # Safety
/// Same contract as `alma_avx2_short`: AVX2/FMA available,
/// `out.len() >= data.len()`, `weights.len() >= period`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_avx2_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let chunks = period / STEP;
    let tail = period % STEP;

    // Chunks are consumed in pairs; a single leftover chunk goes to acc0.
    let paired_chunks = chunks / 2;
    let odd_chunk = chunks % 2;

    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc0 = _mm256_setzero_pd();
        let mut acc1 = _mm256_setzero_pd();

        for blk in 0..paired_chunks {
            let idx0 = (blk * 2) * STEP;
            let idx1 = (blk * 2 + 1) * STEP;

            let w0 = _mm256_loadu_pd(weights.as_ptr().add(idx0));
            let w1 = _mm256_loadu_pd(weights.as_ptr().add(idx1));
            let d0 = _mm256_loadu_pd(data.as_ptr().add(start + idx0));
            let d1 = _mm256_loadu_pd(data.as_ptr().add(start + idx1));

            acc0 = _mm256_fmadd_pd(d0, w0, acc0);
            acc1 = _mm256_fmadd_pd(d1, w1, acc1);
        }

        if odd_chunk != 0 {
            let idx = (paired_chunks * 2) * STEP;
            let w = _mm256_loadu_pd(weights.as_ptr().add(idx));
            let d = _mm256_loadu_pd(data.as_ptr().add(start + idx));
            acc0 = _mm256_fmadd_pd(d, w, acc0);
        }

        let acc = _mm256_add_pd(acc0, acc1);

        let final_acc = if tail != 0 {
            let w_tail = _mm256_maskload_pd(weights.as_ptr().add(chunks * STEP), tail_mask);
            let d_tail = _mm256_maskload_pd(data.as_ptr().add(start + chunks * STEP), tail_mask);
            _mm256_fmadd_pd(d_tail, w_tail, acc)
        } else {
            acc
        };

        // Horizontal reduction: 256 -> 128 add, then hadd collapses lanes.
        let sum128 = _mm_add_pd(
            _mm256_castpd256_pd128(final_acc),
            _mm256_extractf128_pd(final_acc, 1),
        );
        let sum = _mm_cvtsd_f64(_mm_hadd_pd(sum128, sum128));

        *out.get_unchecked_mut(i) = sum * inv_norm;
    }
}
631
/// AVX2 ALMA entry point: dispatches to the short- or long-window
/// implementation based on `period`.
///
/// Marked `#[target_feature(enable = "avx2,fma")]`; callers must only reach
/// this when AVX2 and FMA are available (see `alma_compute_into`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx2,fma")]
pub fn alma_avx2(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    // Same threshold as the AVX-512 dispatcher: short windows use the
    // single-accumulator path.
    if period <= 32 {
        unsafe { alma_avx2_short(data, weights, period, first_valid, inv_norm, out) }
    } else {
        unsafe { alma_avx2_long(data, weights, period, first_valid, inv_norm, out) }
    }
}
649
/// AVX-512F+FMA ALMA kernel for windows of at most a few 8-lane vectors.
///
/// Uses aligned `_mm512_load_pd` on `weights`, which relies on the
/// cache-line-aligned, 8-padded buffer built in `alma_prepare`. When the
/// whole window fits in one masked vector (`chunks == 0`), the weight
/// vector is hoisted out of the loop.
///
/// # Safety
/// Requires AVX-512F/FMA at runtime; `out.len() >= data.len()` and the
/// padded `weights` buffer must hold so unchecked accesses stay in bounds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_avx512_short(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    debug_assert!(period >= 1);
    debug_assert!(data.len() == out.len());
    debug_assert!(weights.len() >= period);

    const STEP: usize = 8;
    let chunks = period / STEP;
    let tail_len = period % STEP;
    // Low `tail_len` bits set; zero when the window is a multiple of 8.
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    if chunks == 0 {
        // Whole window fits in a single masked vector; load weights once.
        let w_vec = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr());
        for i in (first_valid + period - 1)..data.len() {
            let start = i + 1 - period;
            let d_vec = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start));
            let sum = hsum_pd_zmm(_mm512_mul_pd(d_vec, w_vec)) * inv_norm;
            *out.get_unchecked_mut(i) = sum;
        }
        return;
    }

    for i in (first_valid + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm512_setzero_pd();

        for blk in 0..chunks {
            let w = _mm512_load_pd(weights.as_ptr().add(blk * STEP));
            let d = _mm512_loadu_pd(data.as_ptr().add(start + blk * STEP));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        if tail_len != 0 {
            let w_tail = _mm512_maskz_loadu_pd(tail_mask, weights.as_ptr().add(chunks * STEP));
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start + chunks * STEP));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        *out.get_unchecked_mut(i) = hsum_pd_zmm(acc) * inv_norm;
    }
}
700
/// AVX-512F+FMA ALMA kernel for long windows: the weight vectors are
/// preloaded once (stack for up to 256 chunks, heap beyond), and the main
/// loop runs eight independent accumulators to hide FMA latency.
///
/// Fix over the previous revision: the sliding loop used
/// `data_ptr.add(period) <= stop_ptr`, which on the final exit computed a
/// pointer one element past the one-past-the-end pointer — undefined
/// behaviour per `pointer::add`'s in-bounds rule. The loop now compares
/// against a precomputed, always in-bounds "last window start" pointer,
/// which visits exactly the same iterations.
///
/// # Safety
/// Requires AVX-512F/FMA at runtime; callers must guarantee
/// `data.len() == out.len()`, `period <= data.len()`, and the padded,
/// cache-line-aligned `weights` from `alma_prepare` (aligned loads).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_avx512_long(
    data: &[f64],
    weights: &[f64],
    period: usize,
    first_valid: usize,
    inv_norm: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let n_chunks = period / STEP;
    let tail_len = period % STEP;

    let unroll8 = n_chunks & !7;
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    debug_assert!(period >= 1 && n_chunks > 0);
    debug_assert_eq!(data.len(), out.len());
    debug_assert!(weights.len() >= period);

    // Preload all weight vectors once per call instead of per window.
    const MAX_STACK_CHUNKS: usize = 256;
    let mut stack_storage = MaybeUninit::<[__m512d; MAX_STACK_CHUNKS]>::uninit();
    let mut heap_storage: Option<Vec<__m512d>> = None;

    let wregs: &[__m512d] = if n_chunks <= MAX_STACK_CHUNKS {
        let base = stack_storage.as_mut_ptr().cast::<__m512d>();
        for blk in 0..n_chunks {
            // SAFETY: blk < n_chunks <= MAX_STACK_CHUNKS, so the write is
            // within the stack array; the weight load is within the padded
            // buffer.
            unsafe {
                base.add(blk)
                    .write(_mm512_load_pd(weights.as_ptr().add(blk * STEP)));
            }
        }
        // SAFETY: the first n_chunks elements were just initialised.
        unsafe { core::slice::from_raw_parts(base, n_chunks) }
    } else {
        let mut regs = Vec::with_capacity(n_chunks);
        for blk in 0..n_chunks {
            regs.push(_mm512_load_pd(weights.as_ptr().add(blk * STEP)));
        }
        heap_storage = Some(regs);
        heap_storage.as_ref().unwrap().as_slice()
    };
    let w_tail = if tail_len != 0 {
        Some(_mm512_maskz_loadu_pd(
            tail_mask,
            weights.as_ptr().add(n_chunks * STEP),
        ))
    } else {
        None
    };

    let mut data_ptr = data.as_ptr().add(first_valid);
    // Start of the last full window; `data.len() - period` cannot underflow
    // because the caller guarantees period <= data.len(). Comparing against
    // this in-bounds pointer avoids ever computing an out-of-bounds offset.
    let last_start = data.as_ptr().add(data.len() - period);

    let mut dst_ptr = out.as_mut_ptr().add(first_valid + period - 1);

    if tail_len == 0 {
        while data_ptr <= last_start {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();
            let mut s4 = _mm512_setzero_pd();
            let mut s5 = _mm512_setzero_pd();
            let mut s6 = _mm512_setzero_pd();
            let mut s7 = _mm512_setzero_pd();

            // 8-way unrolled main body: independent accumulators keep the
            // FMA pipelines busy.
            for blk in (0..unroll8).step_by(8) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));
                let d4 = _mm512_loadu_pd(data_ptr.add((blk + 4) * STEP));
                let d5 = _mm512_loadu_pd(data_ptr.add((blk + 5) * STEP));
                let d6 = _mm512_loadu_pd(data_ptr.add((blk + 6) * STEP));
                let d7 = _mm512_loadu_pd(data_ptr.add((blk + 7) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
                s4 = _mm512_fmadd_pd(d4, *wregs.get_unchecked(blk + 4), s4);
                s5 = _mm512_fmadd_pd(d5, *wregs.get_unchecked(blk + 5), s5);
                s6 = _mm512_fmadd_pd(d6, *wregs.get_unchecked(blk + 6), s6);
                s7 = _mm512_fmadd_pd(d7, *wregs.get_unchecked(blk + 7), s7);
            }

            // Leftover chunks (fewer than 8) go to s0.
            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            // Pairwise tree reduction of the eight accumulators.
            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let sum45 = _mm512_add_pd(s4, s5);
            let sum67 = _mm512_add_pd(s6, s7);
            let sum0123 = _mm512_add_pd(sum01, sum23);
            let sum4567 = _mm512_add_pd(sum45, sum67);
            let tot = _mm512_add_pd(sum0123, sum4567);

            *dst_ptr = hsum_pd_zmm(tot) * inv_norm;

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    } else {
        let wt = w_tail.expect("tail_len != 0 but w_tail missing");

        while data_ptr <= last_start {
            let mut s0 = _mm512_setzero_pd();
            let mut s1 = _mm512_setzero_pd();
            let mut s2 = _mm512_setzero_pd();
            let mut s3 = _mm512_setzero_pd();
            let mut s4 = _mm512_setzero_pd();
            let mut s5 = _mm512_setzero_pd();
            let mut s6 = _mm512_setzero_pd();
            let mut s7 = _mm512_setzero_pd();

            for blk in (0..unroll8).step_by(8) {
                let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
                let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
                let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
                let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));
                let d4 = _mm512_loadu_pd(data_ptr.add((blk + 4) * STEP));
                let d5 = _mm512_loadu_pd(data_ptr.add((blk + 5) * STEP));
                let d6 = _mm512_loadu_pd(data_ptr.add((blk + 6) * STEP));
                let d7 = _mm512_loadu_pd(data_ptr.add((blk + 7) * STEP));

                s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
                s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
                s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
                s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);
                s4 = _mm512_fmadd_pd(d4, *wregs.get_unchecked(blk + 4), s4);
                s5 = _mm512_fmadd_pd(d5, *wregs.get_unchecked(blk + 5), s5);
                s6 = _mm512_fmadd_pd(d6, *wregs.get_unchecked(blk + 6), s6);
                s7 = _mm512_fmadd_pd(d7, *wregs.get_unchecked(blk + 7), s7);
            }

            for blk in unroll8..n_chunks {
                let d = _mm512_loadu_pd(data_ptr.add(blk * STEP));
                s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(blk), s0);
            }

            // Masked tail (1–7 trailing elements) folds into s0.
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data_ptr.add(n_chunks * STEP));
            s0 = _mm512_fmadd_pd(d_tail, wt, s0);

            let sum01 = _mm512_add_pd(s0, s1);
            let sum23 = _mm512_add_pd(s2, s3);
            let sum45 = _mm512_add_pd(s4, s5);
            let sum67 = _mm512_add_pd(s6, s7);
            let sum0123 = _mm512_add_pd(sum01, sum23);
            let sum4567 = _mm512_add_pd(sum45, sum67);
            let tot = _mm512_add_pd(sum0123, sum4567);

            *dst_ptr = hsum_pd_zmm(tot) * inv_norm;

            data_ptr = data_ptr.add(1);
            dst_ptr = dst_ptr.add(1);
        }
    }
}
863
/// Incremental (one-sample-at-a-time) ALMA state.
///
/// The ring buffer is mirrored back-to-back in `buf2` so that the active
/// window is always one contiguous slice, avoiding a split dot product.
#[derive(Debug, Clone)]
pub struct AlmaStream {
    // Window length.
    period: usize,

    // Precomputed Gaussian weights (length `period`).
    weights: AVec<f64>,
    // 1 / sum(weights).
    inv_norm: f64,

    // Plain ring buffer of the most recent `period` samples.
    buffer: Vec<f64>,

    // Doubled ring buffer: every sample is written at `i` and `i + period`.
    buf2: Vec<f64>,

    // Next write slot, always in `0..period`.
    head: usize,
    // Samples seen so far, saturating at `period`.
    filled: usize,
    // SIMD kernel chosen once at construction time.
    kernel: Kernel,
}
879
880impl AlmaStream {
881 pub fn try_new(params: AlmaParams) -> Result<Self, AlmaError> {
882 let period = params.period.unwrap_or(9);
883 if period == 0 {
884 return Err(AlmaError::InvalidPeriod {
885 period,
886 data_len: 0,
887 });
888 }
889 let offset = params.offset.unwrap_or(0.85);
890 if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
891 return Err(AlmaError::InvalidOffset { offset });
892 }
893 let sigma = params.sigma.unwrap_or(6.0);
894 if sigma <= 0.0 {
895 return Err(AlmaError::InvalidSigma { sigma });
896 }
897
898 let m = offset * (period - 1) as f64;
899 let s = period as f64 / sigma;
900 let s2 = 2.0 * s * s;
901
902 let mut weights = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
903 weights.resize(period, 0.0);
904
905 let mut norm = 0.0;
906 for i in 0..period {
907 let diff = i as f64 - m;
908 let w = (-(diff * diff) / s2).exp();
909 weights[i] = w;
910 norm += w;
911 }
912 let inv_norm = 1.0 / norm;
913
914 let buffer = vec![f64::NAN; period];
915 let buf2 = vec![f64::NAN; period * 2];
916 let kernel = detect_best_kernel();
917
918 Ok(Self {
919 period,
920 weights,
921 inv_norm,
922 buffer,
923 buf2,
924 head: 0,
925 filled: 0,
926 kernel,
927 })
928 }
929
930 #[inline(always)]
931 pub fn update(&mut self, value: f64) -> Option<f64> {
932 let h = self.head;
933
934 self.buffer[h] = value;
935
936 self.buf2[h] = value;
937 self.buf2[h + self.period] = value;
938
939 let mut new_h = h + 1;
940 if new_h == self.period {
941 new_h = 0;
942 }
943 self.head = new_h;
944
945 if self.filled < self.period {
946 self.filled += 1;
947 if self.filled < self.period {
948 return None;
949 }
950 }
951
952 Some(self.dot_at_head())
953 }
954
955 #[inline(always)]
956 fn dot_at_head(&self) -> f64 {
957 let start = self.head;
958 let end = start + self.period;
959 let x = &self.buf2[start..end];
960 let w = &self.weights[..self.period];
961 let acc = dot_contiguous(self.kernel, x, w);
962 acc * self.inv_norm
963 }
964}
965
/// Safe scalar dot product with four independent accumulators, matching the
/// SIMD kernels' accumulation grouping. `x` and `w` must be equal length.
#[inline(always)]
fn dot_scalar_unrolled_safe(x: &[f64], w: &[f64]) -> f64 {
    debug_assert_eq!(x.len(), w.len());
    let len = x.len();
    let split = len & !3;

    let (mut a0, mut a1, mut a2, mut a3) = (0.0f64, 0.0f64, 0.0f64, 0.0f64);
    for (xc, wc) in x[..split].chunks_exact(4).zip(w[..split].chunks_exact(4)) {
        a0 += xc[0] * wc[0];
        a1 += xc[1] * wc[1];
        a2 += xc[2] * wc[2];
        a3 += xc[3] * wc[3];
    }

    // Pairwise combine, then fold in the 0–3 leftover products.
    let mut total = (a0 + a1) + (a2 + a3);
    for (xv, wv) in x[split..].iter().zip(&w[split..]) {
        total += xv * wv;
    }
    total
}
991
/// Horizontal sum of the four lanes of an AVX `f64` vector.
///
/// # Safety
/// Caller must ensure AVX is available at runtime.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum256(v: __m256d) -> f64 {
    let hi = _mm256_extractf128_pd(v, 1);
    let lo = _mm256_castpd256_pd128(v);
    let s = _mm_add_pd(hi, lo);
    let s = _mm_add_sd(s, _mm_unpackhi_pd(s, s));
    _mm_cvtsd_f64(s)
}
1001
/// Horizontal sum of the eight lanes of an AVX-512 `f64` vector.
///
/// # Safety
/// Caller must ensure AVX-512F is available at runtime.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn hsum512(v: __m512d) -> f64 {
    _mm512_reduce_add_pd(v)
}
1007
/// AVX2+FMA dot product over `n` elements (unaligned loads, scalar tail).
///
/// # Safety
/// `x` and `w` must each be valid for reading `n` f64s, and AVX2/FMA must
/// be available. NOTE(review): this is not `#[target_feature]`-gated —
/// presumably the nightly-avx build enables the features globally; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dot_avx2(x: *const f64, w: *const f64, n: usize) -> f64 {
    let mut i = 0usize;
    let n4 = n & !3;
    let mut acc = _mm256_setzero_pd();
    while i < n4 {
        let xv = _mm256_loadu_pd(x.add(i));
        let wv = _mm256_loadu_pd(w.add(i));
        acc = _mm256_fmadd_pd(xv, wv, acc);
        i += 4;
    }
    let mut sum = hsum256(acc);
    while i < n {
        sum += *x.add(i) * *w.add(i);
        i += 1;
    }
    sum
}
1027
/// AVX-512F+FMA dot product over `n` elements (unaligned loads, scalar tail).
///
/// # Safety
/// `x` and `w` must each be valid for reading `n` f64s, and AVX-512F/FMA
/// must be available. NOTE(review): not `#[target_feature]`-gated — same
/// build-wide assumption as `dot_avx2`; confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn dot_avx512(x: *const f64, w: *const f64, n: usize) -> f64 {
    let mut i = 0usize;
    let n8 = n & !7;
    let mut acc = _mm512_setzero_pd();
    while i < n8 {
        let xv = _mm512_loadu_pd(x.add(i));
        let wv = _mm512_loadu_pd(w.add(i));
        acc = _mm512_fmadd_pd(xv, wv, acc);
        i += 8;
    }
    let mut sum = hsum512(acc);
    while i < n {
        sum += *x.add(i) * *w.add(i);
        i += 1;
    }
    sum
}
1047
/// Dot product of two equal-length slices, dispatched on `kernel`.
///
/// AVX paths are used only on nightly-avx x86_64 builds; every other kernel
/// value (and every other build) falls through to the safe scalar version.
#[inline(always)]
fn dot_contiguous(kernel: Kernel, x: &[f64], w: &[f64]) -> f64 {
    debug_assert_eq!(x.len(), w.len());
    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
    {
        // SAFETY: slices are valid for x.len() reads; the kernel value is
        // expected to reflect runtime CPU support — TODO confirm the
        // `detect_best_kernel` contract.
        match kernel {
            Kernel::Avx512 | Kernel::Avx512Batch => unsafe {
                return dot_avx512(x.as_ptr(), w.as_ptr(), x.len());
            },
            Kernel::Avx2 | Kernel::Avx2Batch => unsafe {
                return dot_avx2(x.as_ptr(), w.as_ptr(), x.len());
            },
            _ => {}
        }
    }

    dot_scalar_unrolled_safe(x, w)
}
1066
/// Parameter sweep for batch ALMA; each axis is `(start, end, step)`.
/// A zero `step` (or `start == end`) pins that axis to `start`.
#[derive(Clone, Debug)]
pub struct AlmaBatchRange {
    /// Period sweep: (start, end, step).
    pub period: (usize, usize, usize),
    /// Offset sweep: (start, end, step).
    pub offset: (f64, f64, f64),
    /// Sigma sweep: (start, end, step).
    pub sigma: (f64, f64, f64),
}
1073
1074impl Default for AlmaBatchRange {
1075 fn default() -> Self {
1076 Self {
1077 period: (9, 258, 1),
1078 offset: (0.85, 0.85, 0.0),
1079 sigma: (6.0, 6.0, 0.0),
1080 }
1081 }
1082}
1083
/// Fluent builder for a batch ALMA parameter sweep.
#[derive(Clone, Debug, Default)]
pub struct AlmaBatchBuilder {
    range: AlmaBatchRange,
    kernel: Kernel,
}
1089
1090impl AlmaBatchBuilder {
1091 pub fn new() -> Self {
1092 Self::default()
1093 }
1094
1095 pub fn kernel(mut self, k: Kernel) -> Self {
1096 self.kernel = k;
1097 self
1098 }
1099
1100 #[inline]
1101 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1102 self.range.period = (start, end, step);
1103 self
1104 }
1105 #[inline]
1106 pub fn period_static(mut self, p: usize) -> Self {
1107 self.range.period = (p, p, 0);
1108 self
1109 }
1110
1111 #[inline]
1112 pub fn offset_range(mut self, start: f64, end: f64, step: f64) -> Self {
1113 self.range.offset = (start, end, step);
1114 self
1115 }
1116 #[inline]
1117 pub fn offset_static(mut self, x: f64) -> Self {
1118 self.range.offset = (x, x, 0.0);
1119 self
1120 }
1121
1122 #[inline]
1123 pub fn sigma_range(mut self, start: f64, end: f64, step: f64) -> Self {
1124 self.range.sigma = (start, end, step);
1125 self
1126 }
1127 #[inline]
1128 pub fn sigma_static(mut self, s: f64) -> Self {
1129 self.range.sigma = (s, s, 0.0);
1130 self
1131 }
1132
1133 pub fn apply_slice(self, data: &[f64]) -> Result<AlmaBatchOutput, AlmaError> {
1134 alma_batch_with_kernel(data, &self.range, self.kernel)
1135 }
1136
1137 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<AlmaBatchOutput, AlmaError> {
1138 AlmaBatchBuilder::new().kernel(k).apply_slice(data)
1139 }
1140
1141 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<AlmaBatchOutput, AlmaError> {
1142 let slice = source_type(c, src);
1143 self.apply_slice(slice)
1144 }
1145
1146 pub fn with_default_candles(c: &Candles) -> Result<AlmaBatchOutput, AlmaError> {
1147 AlmaBatchBuilder::new()
1148 .kernel(Kernel::Auto)
1149 .apply_candles(c, "close")
1150 }
1151}
1152
1153pub fn alma_batch_with_kernel(
1154 data: &[f64],
1155 sweep: &AlmaBatchRange,
1156 k: Kernel,
1157) -> Result<AlmaBatchOutput, AlmaError> {
1158 let kernel = match k {
1159 Kernel::Auto => detect_best_batch_kernel(),
1160 other if other.is_batch() => other,
1161 _ => return Err(AlmaError::InvalidKernelForBatch(k)),
1162 };
1163
1164 let simd = match kernel {
1165 Kernel::Avx512Batch => Kernel::Avx512,
1166 Kernel::Avx2Batch => Kernel::Avx2,
1167 Kernel::ScalarBatch => Kernel::Scalar,
1168 _ => unreachable!(),
1169 };
1170 alma_batch_par_slice(data, sweep, simd)
1171}
1172
/// Result of a batch sweep: a row-major `rows x cols` value matrix plus the
/// parameter combination that produced each row.
#[derive(Clone, Debug)]
pub struct AlmaBatchOutput {
    /// Row-major storage; row `r` occupies `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// Parameter combo for each row, in row order.
    pub combos: Vec<AlmaParams>,
    /// Number of parameter combinations (matrix rows).
    pub rows: usize,
    /// Series length (matrix columns).
    pub cols: usize,
}
1180impl AlmaBatchOutput {
1181 pub fn row_for_params(&self, p: &AlmaParams) -> Option<usize> {
1182 self.combos.iter().position(|c| {
1183 c.period.unwrap_or(9) == p.period.unwrap_or(9)
1184 && (c.offset.unwrap_or(0.85) - p.offset.unwrap_or(0.85)).abs() < 1e-12
1185 && (c.sigma.unwrap_or(6.0) - p.sigma.unwrap_or(6.0)).abs() < 1e-12
1186 })
1187 }
1188
1189 pub fn values_for(&self, p: &AlmaParams) -> Option<&[f64]> {
1190 self.row_for_params(p).map(|row| {
1191 let start = row * self.cols;
1192 &self.values[start..start + self.cols]
1193 })
1194 }
1195}
1196
/// Expands an `AlmaBatchRange` into every (period, offset, sigma) combination.
///
/// Each axis is a `(start, end, step)` triple: a zero step (or start == end)
/// yields the single value `start`; start < end walks upward; start > end
/// walks downward. Returns `AlmaError::InvalidRange` when an axis expands to
/// nothing or the combination count would overflow `usize`.
#[inline(always)]
fn expand_grid(r: &AlmaBatchRange) -> Result<Vec<AlmaParams>, AlmaError> {
    // Integer axis: inclusive, in either direction.
    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, AlmaError> {
        if step == 0 || start == end {
            return Ok(vec![start]);
        }
        if start < end {
            return Ok((start..=end).step_by(step.max(1)).collect());
        }

        // Descending: walk start -> end via signed arithmetic.
        // NOTE(review): `start as isize` would wrap for values above
        // isize::MAX — assumed unreachable for realistic periods; confirm.
        let mut v = Vec::new();
        let mut x = start as isize;
        let end_i = end as isize;
        let st = (step as isize).max(1);
        while x >= end_i {
            v.push(x as usize);
            x -= st;
        }
        if v.is_empty() {
            return Err(AlmaError::InvalidRange {
                start: start.to_string(),
                end: end.to_string(),
                step: step.to_string(),
            });
        }
        Ok(v)
    }
    // Float axis: inclusive, with a 1e-12 tolerance so the endpoint is not
    // lost to accumulated floating-point stepping error.
    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, AlmaError> {
        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
            return Ok(vec![start]);
        }
        if start < end {
            // Ascending walk; `step`'s sign is ignored (abs).
            let mut v = Vec::new();
            let mut x = start;
            let st = step.abs();
            while x <= end + 1e-12 {
                v.push(x);
                x += st;
            }
            if v.is_empty() {
                return Err(AlmaError::InvalidRange {
                    start: start.to_string(),
                    end: end.to_string(),
                    step: step.to_string(),
                });
            }
            return Ok(v);
        }
        // Descending walk.
        let mut v = Vec::new();
        let mut x = start;
        let st = step.abs();
        while x + 1e-12 >= end {
            v.push(x);
            x -= st;
        }
        if v.is_empty() {
            return Err(AlmaError::InvalidRange {
                start: start.to_string(),
                end: end.to_string(),
                step: step.to_string(),
            });
        }
        Ok(v)
    }

    let periods = axis_usize(r.period)?;
    let offsets = axis_f64(r.offset)?;
    let sigmas = axis_f64(r.sigma)?;

    // Reject grids whose total size would overflow before allocating.
    let cap = periods
        .len()
        .checked_mul(offsets.len())
        .and_then(|x| x.checked_mul(sigmas.len()))
        .ok_or_else(|| AlmaError::InvalidRange {
            start: "cap".into(),
            end: "overflow".into(),
            step: "mul".into(),
        })?;

    // Cartesian product in period-major, then offset, then sigma order —
    // this ordering defines the row layout of the batch output matrix.
    let mut out = Vec::with_capacity(cap);
    for &p in &periods {
        for &o in &offsets {
            for &s in &sigmas {
                out.push(AlmaParams {
                    period: Some(p),
                    offset: Some(o),
                    sigma: Some(s),
                });
            }
        }
    }
    Ok(out)
}
1290
1291#[inline(always)]
1292pub fn alma_batch_slice(
1293 data: &[f64],
1294 sweep: &AlmaBatchRange,
1295 kern: Kernel,
1296) -> Result<AlmaBatchOutput, AlmaError> {
1297 alma_batch_inner(data, sweep, kern, false)
1298}
1299
1300#[inline(always)]
1301pub fn alma_batch_par_slice(
1302 data: &[f64],
1303 sweep: &AlmaBatchRange,
1304 kern: Kernel,
1305) -> Result<AlmaBatchOutput, AlmaError> {
1306 alma_batch_inner(data, sweep, kern, true)
1307}
1308
/// Rounds `x` up to the next multiple of 8 via the usual power-of-two
/// bit trick (add 7, clear the low three bits).
#[inline]
fn round_up8(x: usize) -> usize {
    const MASK: usize = !(8 - 1);
    (x + 7) & MASK
}
1313
1314#[inline(always)]
1315fn alma_batch_inner(
1316 data: &[f64],
1317 sweep: &AlmaBatchRange,
1318 kern: Kernel,
1319 parallel: bool,
1320) -> Result<AlmaBatchOutput, AlmaError> {
1321 let combos = expand_grid(sweep)?;
1322 let cols = data.len();
1323 let rows = combos.len();
1324
1325 if cols == 0 {
1326 return Err(AlmaError::AllValuesNaN);
1327 }
1328
1329 let _ = rows
1330 .checked_mul(cols)
1331 .ok_or_else(|| AlmaError::InvalidRange {
1332 start: rows.to_string(),
1333 end: cols.to_string(),
1334 step: "rows*cols".into(),
1335 })?;
1336 let mut buf_mu = make_uninit_matrix(rows, cols);
1337
1338 let warm: Vec<usize> = combos
1339 .iter()
1340 .map(|c| data.iter().position(|x| !x.is_nan()).unwrap_or(0) + c.period.unwrap() - 1)
1341 .collect();
1342 init_matrix_prefixes(&mut buf_mu, cols, &warm);
1343
1344 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
1345 let out: &mut [f64] = unsafe {
1346 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
1347 };
1348
1349 alma_batch_inner_into(data, sweep, kern, parallel, out)?;
1350
1351 let values = unsafe {
1352 Vec::from_raw_parts(
1353 buf_guard.as_mut_ptr() as *mut f64,
1354 buf_guard.len(),
1355 buf_guard.capacity(),
1356 )
1357 };
1358
1359 Ok(AlmaBatchOutput {
1360 values,
1361 combos,
1362 rows,
1363 cols,
1364 })
1365}
1366
/// Computes every ALMA parameter combination in `sweep` directly into `out`
/// (a row-major `rows x data.len()` matrix, one row per combo), returning
/// the expanded combo list on success.
///
/// NOTE(review): the warm-up prefixes are NaN-initialised here even though
/// `alma_batch_inner` already did so — redundant for that caller, but keeps
/// this function safe on a raw caller-provided buffer.
#[inline(always)]
fn alma_batch_inner_into(
    data: &[f64],
    sweep: &AlmaBatchRange,
    kern: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<AlmaParams>, AlmaError> {
    let combos = expand_grid(sweep)?;
    if combos.is_empty() {
        return Err(AlmaError::InvalidRange {
            start: "range".into(),
            end: "range".into(),
            step: "empty".into(),
        });
    }

    // Index of the first non-NaN sample; everything before it stays NaN.
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(AlmaError::AllValuesNaN)?;
    // Largest period rounded up to a multiple of 8: this is both the stride
    // of the flattened weight table and the minimum number of valid samples
    // required (slightly stricter than the raw period, for SIMD rows).
    let max_p = combos
        .iter()
        .map(|c| round_up8(c.period.unwrap()))
        .max()
        .unwrap();
    if data.len() - first < max_p {
        return Err(AlmaError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }

    let rows = combos.len();
    let cols = data.len();
    let mut inv_norms = vec![0.0; rows];

    // Flattened rows x max_p weight table, cacheline-aligned so each row
    // (stride max_p, a multiple of 8 f64s) starts 64-byte aligned.
    let cap = rows
        .checked_mul(max_p)
        .ok_or_else(|| AlmaError::InvalidRange {
            start: rows.to_string(),
            end: max_p.to_string(),
            step: "rows*max_p".into(),
        })?;
    let mut flat_w = AVec::<f64>::with_capacity(CACHELINE_ALIGN, cap);
    flat_w.resize(cap, 0.0);

    // Precompute each combo's Gaussian weight row and reciprocal norm:
    // w_i = exp(-(i - m)^2 / (2 s^2)), m = offset*(period-1), s = period/sigma.
    for (row, prm) in combos.iter().enumerate() {
        let period = prm.period.unwrap();
        let offset = prm.offset.unwrap();
        let sigma = prm.sigma.unwrap();

        if sigma <= 0.0 {
            return Err(AlmaError::InvalidSigma { sigma });
        }
        if !(0.0..=1.0).contains(&offset) || offset.is_nan() || offset.is_infinite() {
            return Err(AlmaError::InvalidOffset { offset });
        }

        let m = offset * (period - 1) as f64;
        let s = period as f64 / sigma;
        let s2 = 2.0 * s * s;

        let mut norm = 0.0;
        for i in 0..period {
            let w = (-(i as f64 - m).powi(2) / s2).exp();
            flat_w[row * max_p + i] = w;
            norm += w;
        }
        inv_norms[row] = 1.0 / norm;
    }
    // Reinterpret `out` as MaybeUninit so the shared prefix-init helper can
    // be reused; layout of f64 and MaybeUninit<f64> is identical.
    let out_uninit = unsafe {
        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
    };

    // NaN-fill the warm-up prefix of each row (first + period - 1 samples).
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(out_uninit, cols, &warm);

    // Resolve Auto once, outside the per-row closure.
    let actual_kern = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };

    // Computes one combo row. SAFETY: `w_ptr` addresses this row's max_p
    // weights inside `flat_w`, `row < rows`, and `dst_mu` is exactly one
    // `cols`-long row of the output matrix.
    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
        let period = combos[row].period.unwrap();
        let w_ptr = flat_w.as_ptr().add(row * max_p);
        let inv_n = *inv_norms.get_unchecked(row);

        let dst = core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());

        match actual_kern {
            Kernel::Scalar | Kernel::ScalarBatch => {
                alma_row_scalar(data, first, period, w_ptr, inv_n, dst)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx2 | Kernel::Avx2Batch => {
                alma_row_avx2(data, first, period, w_ptr, inv_n, dst)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_row_avx512(data, first, period, w_ptr, inv_n, dst)
            }
            // Without the AVX build the SIMD kernels silently fall back to scalar.
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
                alma_row_scalar(data, first, period, w_ptr, inv_n, dst)
            }
            Kernel::Auto => unreachable!("Auto kernel should have been resolved"),
        }
    };

    if parallel {
        // One rayon task per output row; rows are disjoint chunks, so no
        // synchronisation is needed.
        #[cfg(not(target_arch = "wasm32"))]
        {
            out_uninit
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| do_row(row, slice));
        }

        // wasm32 has no rayon here: degrade to the sequential loop.
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
                do_row(row, slice);
            }
        }
    } else {
        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
            do_row(row, slice);
        }
    }

    Ok(combos)
}
1503
/// Scalar ALMA convolution for one parameter row.
///
/// Writes `out[i]` for every `i >= first + period - 1`; earlier prefix
/// entries are left untouched (the caller pre-fills them).
///
/// # Safety
/// `w_ptr` must point to at least `period` readable `f64` weights and
/// `out.len()` must equal `data.len()`.
#[inline(always)]
unsafe fn alma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    // Taps covered by the 4-wide unrolled loop (accumulation order is kept
    // exactly as in the blocked form so results are bit-identical).
    let unrolled = period & !3;
    let mut i = first + period - 1;
    while i < data.len() {
        let win = i + 1 - period;
        let mut acc = 0.0;
        let mut j = 0usize;
        while j < unrolled {
            let w = std::slice::from_raw_parts(w_ptr.add(j), 4);
            let d = &data[win + j..win + j + 4];
            acc += d[0] * w[0] + d[1] * w[1] + d[2] * w[2] + d[3] * w[3];
            j += 4;
        }
        while j < period {
            acc += *data.get_unchecked(win + j) * *w_ptr.add(j);
            j += 1;
        }
        out[i] = acc * inv_n;
        i += 1;
    }
}
1528
/// AVX2/FMA ALMA convolution for one parameter row (4 doubles per step).
///
/// # Safety
/// Caller must ensure AVX2 and FMA are available, `w_ptr` points to at
/// least `period` weights, and `out.len() == data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn alma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 4;
    let vec_blocks = period / STEP;
    let tail = period % STEP;
    // Per-lane mask selecting the `tail` leftover taps for maskload.
    let tail_mask = match tail {
        0 => _mm256_setzero_si256(),
        1 => _mm256_setr_epi64x(-1, 0, 0, 0),
        2 => _mm256_setr_epi64x(-1, -1, 0, 0),
        3 => _mm256_setr_epi64x(-1, -1, -1, 0),
        _ => unreachable!(),
    };

    for i in (first + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm256_setzero_pd();

        // Full 4-lane blocks: acc += data * weights via fused multiply-add.
        for blk in 0..vec_blocks {
            let d = _mm256_loadu_pd(data.as_ptr().add(start + blk * STEP));
            let w = _mm256_loadu_pd(w_ptr.add(blk * STEP));
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Masked remainder (< 4 taps); unselected lanes read as zero.
        if tail != 0 {
            let d = _mm256_maskload_pd(data.as_ptr().add(start + vec_blocks * STEP), tail_mask);
            let w = _mm256_maskload_pd(w_ptr.add(vec_blocks * STEP), tail_mask);
            acc = _mm256_fmadd_pd(d, w, acc);
        }

        // Horizontal sum of the 4 accumulator lanes, then normalise.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let s2 = _mm_add_pd(hi, lo);
        let s1 = _mm_add_pd(s2, _mm_unpackhi_pd(s2, s2));
        out[i] = _mm_cvtsd_f64(s1) * inv_n;
    }
}
1573
/// AVX-512 ALMA convolution for one row; dispatches on `period` to the
/// short kernel (period <= 32, weights fit in at most 4 zmm registers) or
/// the long kernel (weights cached in a register array).
///
/// # Safety
/// Requires AVX-512F and FMA. `w_ptr` must point to at least
/// `round_up8(period)` weights — assumed 64-byte aligned (upheld by the
/// caller's cacheline-aligned weight table with a multiple-of-8 row
/// stride) — and `out.len()` must equal `data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
pub unsafe fn alma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    if period <= 32 {
        alma_row_avx512_short(data, first, period, w_ptr, inv_n, out);
    } else {
        alma_row_avx512_long(data, first, period, w_ptr, inv_n, out);
    }
}
1590
/// AVX-512 kernel for short windows (period <= 32): at most 4 full 8-lane
/// blocks plus a masked tail.
///
/// # Safety
/// Same contract as `alma_row_avx512`; additionally `_mm512_load_pd`
/// requires `w_ptr` to be 64-byte aligned.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    debug_assert!(period <= 32);
    const STEP: usize = 8;

    let chunks = period / STEP;
    let tail_len = period % STEP;
    // k-mask with the low `tail_len` bits set (0 when tail_len == 0).
    let tail_mask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    // period < 8: one masked load of the weights, hoisted out of the loop.
    if chunks == 0 {
        let w_tail = _mm512_maskz_loadu_pd(tail_mask, w_ptr);
        for i in (first + period - 1)..data.len() {
            let start = i + 1 - period;
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start));
            let res = hsum_pd_zmm(_mm512_mul_pd(d_tail, w_tail)) * inv_n;
            *out.get_unchecked_mut(i) = res;
        }
        return;
    }

    for i in (first + period - 1)..data.len() {
        let start = i + 1 - period;
        let mut acc = _mm512_setzero_pd();

        // Full 8-lane blocks; aligned load on the weight row, unaligned on data.
        for blk in 0..chunks {
            let w = _mm512_load_pd(w_ptr.add(blk * STEP));
            let d = _mm512_loadu_pd(data.as_ptr().add(start + blk * STEP));
            acc = _mm512_fmadd_pd(d, w, acc);
        }

        // Masked remainder; zeroed lanes contribute nothing to the FMA.
        if tail_len != 0 {
            let w_tail = _mm512_maskz_loadu_pd(tail_mask, w_ptr.add(chunks * STEP));
            let d_tail = _mm512_maskz_loadu_pd(tail_mask, data.as_ptr().add(start + chunks * STEP));
            acc = _mm512_fmadd_pd(d_tail, w_tail, acc);
        }

        let res = hsum_pd_zmm(acc) * inv_n;
        *out.get_unchecked_mut(i) = res;
    }
}
1639
/// AVX-512 kernel for long windows (period > 32): preloads the whole weight
/// row into a zmm register array once, then delegates to the tail/no-tail
/// inner loops. MAX_CHUNKS = 512 bounds the supported period at 4096 taps.
///
/// # Safety
/// Same contract as `alma_row_avx512` (including 64-byte-aligned `w_ptr`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn alma_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    let n_chunks = period / STEP;
    let tail_len = period % STEP;
    let tmask: __mmask8 = (1u8 << tail_len).wrapping_sub(1);

    const MAX_CHUNKS: usize = 512;
    debug_assert!(n_chunks + (tail_len != 0) as usize <= MAX_CHUNKS);

    // SAFETY: an array of MaybeUninit needs no initialisation; elements are
    // written below before the initialised prefix is read.
    let mut wregs: [core::mem::MaybeUninit<__m512d>; MAX_CHUNKS] =
        core::mem::MaybeUninit::uninit().assume_init();

    for blk in 0..n_chunks {
        wregs[blk]
            .as_mut_ptr()
            .write(_mm512_load_pd(w_ptr.add(blk * STEP)));
    }
    // The masked tail block (if any) is stored after the full blocks.
    if tail_len != 0 {
        wregs[n_chunks]
            .as_mut_ptr()
            .write(_mm512_maskz_loadu_pd(tmask, w_ptr.add(n_chunks * STEP)));
    }

    // SAFETY: exactly the first `n_chunks (+1)` entries were initialised above.
    let wregs: &[__m512d] = core::slice::from_raw_parts(
        wregs.as_ptr() as *const __m512d,
        n_chunks + (tail_len != 0) as usize,
    );

    if tail_len == 0 {
        long_kernel_no_tail(data, first, n_chunks, wregs, inv_n, out);
    } else {
        long_kernel_with_tail(data, first, n_chunks, tail_len, tmask, wregs, inv_n, out);
    }
}
1683
/// Long-window inner loop when `period` is an exact multiple of 8: slides a
/// window of `n_chunks * 8` taps over `data`, using four independent FMA
/// accumulators to hide latency.
///
/// # Safety
/// `wregs` holds `n_chunks` preloaded weight vectors; `out.len()` must equal
/// `data.len()` and `data.len() - first >= n_chunks * 8` (checked upstream).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn long_kernel_no_tail(
    data: &[f64],
    first: usize,
    n_chunks: usize,
    wregs: &[__m512d],
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    // Chunks processed by the 4-way unrolled loop.
    let paired = n_chunks & !3;

    let mut data_ptr = data.as_ptr().add(first);
    let stop_ptr = data.as_ptr().add(data.len());
    // Destination starts at the first fully-warmed index: first + period - 1.
    let mut dst_ptr = out.as_mut_ptr().add(first + n_chunks * STEP - 1);

    while data_ptr < stop_ptr {
        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk < paired {
            let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
            let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
            s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
            s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
            s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);

            blk += 4;
        }

        // Up to three leftover chunks.
        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(data_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(r), s0);
        }

        let sum = _mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3));
        let res = hsum_pd_zmm(sum) * inv_n;

        *dst_ptr = res;

        data_ptr = data_ptr.add(1);
        dst_ptr = dst_ptr.add(1);
        // Stop once the next window would run past the end of `data`.
        // NOTE(review): forming `data_ptr.add(n_chunks * STEP)` may create a
        // pointer beyond one-past-the-end before the compare, which violates
        // the `pointer::add` contract — consider a length-based check; confirm.
        if data_ptr.add(n_chunks * STEP) > stop_ptr {
            break;
        }
    }
}
1739
/// Long-window inner loop when `period` has a remainder modulo 8: same
/// 4-accumulator structure as `long_kernel_no_tail` plus one masked tail
/// block per window position.
///
/// # Safety
/// `wregs` holds `n_chunks + 1` preloaded weight vectors (tail last);
/// `out.len()` must equal `data.len()` and enough valid samples must exist
/// for at least one full window (checked upstream).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn long_kernel_with_tail(
    data: &[f64],
    first: usize,
    n_chunks: usize,
    tail_len: usize,
    tmask: __mmask8,
    wregs: &[__m512d],
    inv_n: f64,
    out: &mut [f64],
) {
    const STEP: usize = 8;
    // Chunks processed by the 4-way unrolled loop.
    let paired = n_chunks & !3;

    // The masked tail weight vector was stored after the full blocks.
    let w_tail = *wregs.get_unchecked(n_chunks);

    let mut data_ptr = data.as_ptr().add(first);
    let stop_ptr = data.as_ptr().add(data.len());
    // Destination starts at the first fully-warmed index: first + period - 1.
    let mut dst_ptr = out.as_mut_ptr().add(first + n_chunks * STEP + tail_len - 1);

    while data_ptr < stop_ptr {
        let mut s0 = _mm512_setzero_pd();
        let mut s1 = _mm512_setzero_pd();
        let mut s2 = _mm512_setzero_pd();
        let mut s3 = _mm512_setzero_pd();

        let mut blk = 0;
        while blk < paired {
            let d0 = _mm512_loadu_pd(data_ptr.add((blk + 0) * STEP));
            let d1 = _mm512_loadu_pd(data_ptr.add((blk + 1) * STEP));
            let d2 = _mm512_loadu_pd(data_ptr.add((blk + 2) * STEP));
            let d3 = _mm512_loadu_pd(data_ptr.add((blk + 3) * STEP));

            s0 = _mm512_fmadd_pd(d0, *wregs.get_unchecked(blk + 0), s0);
            s1 = _mm512_fmadd_pd(d1, *wregs.get_unchecked(blk + 1), s1);
            s2 = _mm512_fmadd_pd(d2, *wregs.get_unchecked(blk + 2), s2);
            s3 = _mm512_fmadd_pd(d3, *wregs.get_unchecked(blk + 3), s3);

            blk += 4;
        }

        // Up to three leftover full chunks.
        for r in blk..n_chunks {
            let d = _mm512_loadu_pd(data_ptr.add(r * STEP));
            s0 = _mm512_fmadd_pd(d, *wregs.get_unchecked(r), s0);
        }

        // Masked tail taps; zeroed lanes contribute nothing to the FMA.
        let d_tail = _mm512_maskz_loadu_pd(tmask, data_ptr.add(n_chunks * STEP));
        s0 = _mm512_fmadd_pd(d_tail, w_tail, s0);

        let sum = _mm512_add_pd(_mm512_add_pd(s0, s1), _mm512_add_pd(s2, s3));
        let res = hsum_pd_zmm(sum) * inv_n;

        *dst_ptr = res;

        data_ptr = data_ptr.add(1);
        dst_ptr = dst_ptr.add(1);
        // NOTE(review): same past-the-end pointer concern as in
        // `long_kernel_no_tail` — the `add` may exceed one-past-the-end
        // before the compare; consider a length-based check; confirm.
        if data_ptr.add(n_chunks * STEP + tail_len) > stop_ptr {
            break;
        }
    }
}
1802
1803#[cfg(test)]
1804mod tests {
1805 use super::*;
1806 use crate::skip_if_unsupported;
1807 use crate::utilities::data_loader::read_candles_from_csv;
1808 #[cfg(feature = "proptest")]
1809 use proptest::prelude::*;
1810
1811 fn check_alma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1812 skip_if_unsupported!(kernel, test_name);
1813 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1814 let candles = read_candles_from_csv(file_path)?;
1815
1816 let default_params = AlmaParams {
1817 period: None,
1818 offset: None,
1819 sigma: None,
1820 };
1821 let input = AlmaInput::from_candles(&candles, "close", default_params);
1822 let output = alma_with_kernel(&input, kernel)?;
1823 assert_eq!(output.values.len(), candles.close.len());
1824
1825 Ok(())
1826 }
1827
1828 fn check_alma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1829 skip_if_unsupported!(kernel, test_name);
1830 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1831 let candles = read_candles_from_csv(file_path)?;
1832
1833 let input = AlmaInput::from_candles(&candles, "close", AlmaParams::default());
1834 let result = alma_with_kernel(&input, kernel)?;
1835 let expected_last_five = [
1836 59286.72216704,
1837 59273.53428138,
1838 59204.37290721,
1839 59155.93381742,
1840 59026.92526112,
1841 ];
1842 let start = result.values.len().saturating_sub(5);
1843 for (i, &val) in result.values[start..].iter().enumerate() {
1844 let diff = (val - expected_last_five[i]).abs();
1845 assert!(
1846 diff < 1e-8,
1847 "[{}] ALMA {:?} mismatch at idx {}: got {}, expected {}",
1848 test_name,
1849 kernel,
1850 i,
1851 val,
1852 expected_last_five[i]
1853 );
1854 }
1855 Ok(())
1856 }
1857
1858 fn check_alma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1859 skip_if_unsupported!(kernel, test_name);
1860 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1861 let candles = read_candles_from_csv(file_path)?;
1862
1863 let input = AlmaInput::with_default_candles(&candles);
1864 match input.data {
1865 AlmaData::Candles { source, .. } => assert_eq!(source, "close"),
1866 _ => panic!("Expected AlmaData::Candles"),
1867 }
1868 let output = alma_with_kernel(&input, kernel)?;
1869 assert_eq!(output.values.len(), candles.close.len());
1870
1871 Ok(())
1872 }
1873
1874 fn check_alma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1875 skip_if_unsupported!(kernel, test_name);
1876 let input_data = [10.0, 20.0, 30.0];
1877 let params = AlmaParams {
1878 period: Some(0),
1879 offset: None,
1880 sigma: None,
1881 };
1882 let input = AlmaInput::from_slice(&input_data, params);
1883 let res = alma_with_kernel(&input, kernel);
1884 assert!(
1885 res.is_err(),
1886 "[{}] ALMA should fail with zero period",
1887 test_name
1888 );
1889 Ok(())
1890 }
1891
1892 fn check_alma_period_exceeds_length(
1893 test_name: &str,
1894 kernel: Kernel,
1895 ) -> Result<(), Box<dyn Error>> {
1896 skip_if_unsupported!(kernel, test_name);
1897 let data_small = [10.0, 20.0, 30.0];
1898 let params = AlmaParams {
1899 period: Some(10),
1900 offset: None,
1901 sigma: None,
1902 };
1903 let input = AlmaInput::from_slice(&data_small, params);
1904 let res = alma_with_kernel(&input, kernel);
1905 assert!(
1906 res.is_err(),
1907 "[{}] ALMA should fail with period exceeding length",
1908 test_name
1909 );
1910 Ok(())
1911 }
1912
1913 fn check_alma_very_small_dataset(
1914 test_name: &str,
1915 kernel: Kernel,
1916 ) -> Result<(), Box<dyn Error>> {
1917 skip_if_unsupported!(kernel, test_name);
1918 let single_point = [42.0];
1919 let params = AlmaParams {
1920 period: Some(9),
1921 offset: None,
1922 sigma: None,
1923 };
1924 let input = AlmaInput::from_slice(&single_point, params);
1925 let res = alma_with_kernel(&input, kernel);
1926 assert!(
1927 res.is_err(),
1928 "[{}] ALMA should fail with insufficient data",
1929 test_name
1930 );
1931 Ok(())
1932 }
1933
1934 fn check_alma_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1935 skip_if_unsupported!(kernel, test_name);
1936 let empty: [f64; 0] = [];
1937 let input = AlmaInput::from_slice(&empty, AlmaParams::default());
1938 let res = alma_with_kernel(&input, kernel);
1939 assert!(
1940 matches!(res, Err(AlmaError::EmptyInputData)),
1941 "[{}] ALMA should fail with empty input",
1942 test_name
1943 );
1944 Ok(())
1945 }
1946
1947 fn check_alma_invalid_sigma(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1948 skip_if_unsupported!(kernel, test_name);
1949 let data = [1.0, 2.0, 3.0];
1950 let params = AlmaParams {
1951 period: Some(2),
1952 offset: None,
1953 sigma: Some(0.0),
1954 };
1955 let input = AlmaInput::from_slice(&data, params);
1956 let res = alma_with_kernel(&input, kernel);
1957 assert!(
1958 matches!(res, Err(AlmaError::InvalidSigma { .. })),
1959 "[{}] ALMA should fail with invalid sigma",
1960 test_name
1961 );
1962 Ok(())
1963 }
1964
1965 fn check_alma_invalid_offset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1966 skip_if_unsupported!(kernel, test_name);
1967 let data = [1.0, 2.0, 3.0];
1968 let params = AlmaParams {
1969 period: Some(2),
1970 offset: Some(f64::NAN),
1971 sigma: None,
1972 };
1973 let input = AlmaInput::from_slice(&data, params);
1974 let res = alma_with_kernel(&input, kernel);
1975 assert!(
1976 matches!(res, Err(AlmaError::InvalidOffset { .. })),
1977 "[{}] ALMA should fail with invalid offset",
1978 test_name
1979 );
1980 Ok(())
1981 }
1982
1983 fn check_alma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1984 skip_if_unsupported!(kernel, test_name);
1985 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1986 let candles = read_candles_from_csv(file_path)?;
1987
1988 let first_params = AlmaParams {
1989 period: Some(9),
1990 offset: None,
1991 sigma: None,
1992 };
1993 let first_input = AlmaInput::from_candles(&candles, "close", first_params);
1994 let first_result = alma_with_kernel(&first_input, kernel)?;
1995
1996 let second_params = AlmaParams {
1997 period: Some(9),
1998 offset: None,
1999 sigma: None,
2000 };
2001 let second_input = AlmaInput::from_slice(&first_result.values, second_params);
2002 let second_result = alma_with_kernel(&second_input, kernel)?;
2003
2004 assert_eq!(second_result.values.len(), first_result.values.len());
2005 let expected_last_five = [
2006 59140.73195170,
2007 59211.58090986,
2008 59238.16030697,
2009 59222.63528822,
2010 59165.14427332,
2011 ];
2012 let start = second_result.values.len().saturating_sub(5);
2013 for (i, &val) in second_result.values[start..].iter().enumerate() {
2014 let diff = (val - expected_last_five[i]).abs();
2015 assert!(
2016 diff < 1e-8,
2017 "[{}] ALMA Slice Reinput {:?} mismatch at idx {}: got {}, expected {}",
2018 test_name,
2019 kernel,
2020 i,
2021 val,
2022 expected_last_five[i]
2023 );
2024 }
2025 Ok(())
2026 }
2027
2028 fn check_alma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2029 skip_if_unsupported!(kernel, test_name);
2030 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2031 let candles = read_candles_from_csv(file_path)?;
2032
2033 let input = AlmaInput::from_candles(
2034 &candles,
2035 "close",
2036 AlmaParams {
2037 period: Some(9),
2038 offset: None,
2039 sigma: None,
2040 },
2041 );
2042 let res = alma_with_kernel(&input, kernel)?;
2043 assert_eq!(res.values.len(), candles.close.len());
2044 if res.values.len() > 240 {
2045 for (i, &val) in res.values[240..].iter().enumerate() {
2046 assert!(
2047 !val.is_nan(),
2048 "[{}] Found unexpected NaN at out-index {}",
2049 test_name,
2050 240 + i
2051 );
2052 }
2053 }
2054 Ok(())
2055 }
2056
2057 fn check_alma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2058 skip_if_unsupported!(kernel, test_name);
2059
2060 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2061 let candles = read_candles_from_csv(file_path)?;
2062
2063 let period = 9;
2064 let offset = 0.85;
2065 let sigma = 6.0;
2066
2067 let input = AlmaInput::from_candles(
2068 &candles,
2069 "close",
2070 AlmaParams {
2071 period: Some(period),
2072 offset: Some(offset),
2073 sigma: Some(sigma),
2074 },
2075 );
2076 let batch_output = alma_with_kernel(&input, kernel)?.values;
2077
2078 let mut stream = AlmaStream::try_new(AlmaParams {
2079 period: Some(period),
2080 offset: Some(offset),
2081 sigma: Some(sigma),
2082 })?;
2083
2084 let mut stream_values = Vec::with_capacity(candles.close.len());
2085 for &price in &candles.close {
2086 match stream.update(price) {
2087 Some(alma_val) => stream_values.push(alma_val),
2088 None => stream_values.push(f64::NAN),
2089 }
2090 }
2091
2092 assert_eq!(batch_output.len(), stream_values.len());
2093 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
2094 if b.is_nan() && s.is_nan() {
2095 continue;
2096 }
2097 let diff = (b - s).abs();
2098 assert!(
2099 diff < 1e-9,
2100 "[{}] ALMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
2101 test_name,
2102 i,
2103 b,
2104 s,
2105 diff
2106 );
2107 }
2108 Ok(())
2109 }
2110
    /// Debug-only sweep over many parameter sets asserting that none of the
    /// sentinel "poison" bit patterns written by the uninitialised-buffer
    /// helpers leak into the final output. Each helper fills with a distinct
    /// recognisable pattern: 0x11.. (alloc_with_nan_prefix), 0x22..
    /// (init_matrix_prefixes), 0x33.. (make_uninit_matrix).
    #[cfg(debug_assertions)]
    fn check_alma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        // Spread of short/long periods and extreme offsets/sigmas to exercise
        // scalar, unrolled and SIMD code paths with varying warm-up lengths.
        let test_params = vec![
            AlmaParams::default(),
            AlmaParams {
                period: Some(5),
                offset: Some(0.5),
                sigma: Some(3.0),
            },
            AlmaParams {
                period: Some(5),
                offset: Some(0.85),
                sigma: Some(6.0),
            },
            AlmaParams {
                period: Some(5),
                offset: Some(1.0),
                sigma: Some(10.0),
            },
            AlmaParams {
                period: Some(9),
                offset: Some(0.2),
                sigma: Some(4.0),
            },
            AlmaParams {
                period: Some(9),
                offset: Some(0.85),
                sigma: Some(6.0),
            },
            AlmaParams {
                period: Some(9),
                offset: Some(0.95),
                sigma: Some(8.0),
            },
            AlmaParams {
                period: Some(20),
                offset: Some(0.0),
                sigma: Some(2.0),
            },
            AlmaParams {
                period: Some(20),
                offset: Some(0.5),
                sigma: Some(5.0),
            },
            AlmaParams {
                period: Some(20),
                offset: Some(0.85),
                sigma: Some(6.0),
            },
            AlmaParams {
                period: Some(20),
                offset: Some(1.0),
                sigma: Some(10.0),
            },
            AlmaParams {
                period: Some(2),
                offset: Some(0.0),
                sigma: Some(0.1),
            },
            AlmaParams {
                period: Some(50),
                offset: Some(0.5),
                sigma: Some(15.0),
            },
            AlmaParams {
                period: Some(100),
                offset: Some(0.85),
                sigma: Some(20.0),
            },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = AlmaInput::from_candles(&candles, "close", params.clone());
            let output = alma_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                // NaN is the legitimate warm-up filler; only the exact poison
                // bit patterns below indicate a leaked uninitialised cell.
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, offset={}, sigma={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        params.offset.unwrap_or(0.85),
                        params.sigma.unwrap_or(6.0)
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
                         with params: period={}, offset={}, sigma={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        params.offset.unwrap_or(0.85),
                        params.sigma.unwrap_or(6.0)
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
                         with params: period={}, offset={}, sigma={}",
                        test_name,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(9),
                        params.offset.unwrap_or(0.85),
                        params.sigma.unwrap_or(6.0)
                    );
                }
            }
        }

        Ok(())
    }
2244
    /// Release-build stub: the poison bit patterns only exist when the
    /// allocation helpers run under debug assertions, so there is nothing
    /// to check here.
    #[cfg(not(debug_assertions))]
    fn check_alma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
    /// Property-based cross-check (proptest): for random data and parameters
    /// the kernel under test must stay inside the window min/max envelope,
    /// reduce to identity for period == 1, be constant on constant input,
    /// and match the scalar reference within 1e-9 or 4 ULP.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_alma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Strategy: a period in 1..=64, a finite data series at least
        // `period` long, and offset/sigma drawn from their valid ranges.
        let strat = (1usize..=64).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                0f64..1f64,
                0.1f64..10.0f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, offset, sigma)| {
                let params = AlmaParams {
                    period: Some(period),
                    offset: Some(offset),
                    sigma: Some(sigma),
                };
                let input = AlmaInput::from_slice(&data, params);

                let AlmaOutput { values: out } = alma_with_kernel(&input, kernel).unwrap();
                let AlmaOutput { values: ref_out } =
                    alma_with_kernel(&input, Kernel::Scalar).unwrap();

                for i in (period - 1)..data.len() {
                    let window = &data[i + 1 - period..=i];
                    let lo = window.iter().cloned().fold(f64::INFINITY, f64::min);
                    let hi = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let y = out[i];
                    let r = ref_out[i];

                    // A weighted average must lie within the window's range.
                    prop_assert!(
                        y.is_nan() || (y >= lo - 1e-9 && y <= hi + 1e-9),
                        "idx {i}: {y} ∉ [{lo}, {hi}]"
                    );

                    // period == 1 degenerates to the identity.
                    if period == 1 {
                        prop_assert!((y - data[i]).abs() <= f64::EPSILON);
                    }

                    // Constant input must produce (near-)constant output.
                    if data.windows(2).all(|w| w[0] == w[1]) {
                        prop_assert!((y - data[0]).abs() <= 1e-9);
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    // Non-finite values must agree bit-for-bit with scalar.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch idx {i}: {y} vs {r}"
                        );
                        continue;
                    }

                    // Finite values: absolute 1e-9 or 4-ULP agreement.
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);

                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "mismatch idx {i}: {y} vs {r} (ULP={ulp_diff})"
                    );
                }
                Ok(())
            })
            .unwrap();

        Ok(())
    }
2327
2328 macro_rules! generate_all_alma_tests {
2329 ($($test_fn:ident),*) => {
2330 paste::paste! {
2331 $(
2332 #[test]
2333 fn [<$test_fn _scalar_f64>]() {
2334 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2335 }
2336 )*
2337 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2338 $(
2339 #[test]
2340 fn [<$test_fn _avx2_f64>]() {
2341 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2342 }
2343 #[test]
2344 fn [<$test_fn _avx512_f64>]() {
2345 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2346 }
2347 )*
2348 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2349 $(
2350 #[test]
2351 fn [<$test_fn _simd128_f64>]() {
2352 let _ = $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar);
2353 }
2354 )*
2355 }
2356 }
2357 }
2358
    // Instantiate the full per-kernel test matrix for every ALMA check above.
    generate_all_alma_tests!(
        check_alma_partial_params,
        check_alma_accuracy,
        check_alma_default_candles,
        check_alma_zero_period,
        check_alma_period_exceeds_length,
        check_alma_very_small_dataset,
        check_alma_empty_input,
        check_alma_invalid_sigma,
        check_alma_invalid_offset,
        check_alma_reinput,
        check_alma_nan_handling,
        check_alma_streaming,
        check_alma_no_poison
    );

    // Property-based checks are opt-in behind the `proptest` feature.
    #[cfg(feature = "proptest")]
    generate_all_alma_tests!(check_alma_property);
2377
2378 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2379 #[test]
2380 fn test_alma_into_matches_api() -> Result<(), Box<dyn Error>> {
2381 let mut data = vec![f64::NAN; 3];
2382 data.extend((0..256).map(|i| (i as f64).sin() * 100.0 + (i as f64) * 0.1));
2383
2384 let input = AlmaInput::from_slice(&data, AlmaParams::default());
2385
2386 let baseline = alma_with_kernel(&input, Kernel::Auto)?.values;
2387
2388 let mut out = vec![0.0; data.len()];
2389 alma_into(&input, &mut out)?;
2390
2391 assert_eq!(baseline.len(), out.len());
2392
2393 fn eq_or_both_nan(a: f64, b: f64) -> bool {
2394 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
2395 }
2396
2397 for i in 0..out.len() {
2398 assert!(
2399 eq_or_both_nan(baseline[i], out[i]),
2400 "mismatch at {}: baseline={} out={}",
2401 i,
2402 baseline[i],
2403 out[i]
2404 );
2405 }
2406
2407 Ok(())
2408 }
2409
2410 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2411 skip_if_unsupported!(kernel, test);
2412
2413 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2414 let c = read_candles_from_csv(file)?;
2415
2416 let output = AlmaBatchBuilder::new()
2417 .kernel(kernel)
2418 .apply_candles(&c, "close")?;
2419
2420 let def = AlmaParams::default();
2421 let row = output.values_for(&def).expect("default row missing");
2422
2423 assert_eq!(row.len(), c.close.len());
2424
2425 let expected = [
2426 59286.72216704,
2427 59273.53428138,
2428 59204.37290721,
2429 59155.93381742,
2430 59026.92526112,
2431 ];
2432 let start = row.len() - 5;
2433 for (i, &v) in row[start..].iter().enumerate() {
2434 assert!(
2435 (v - expected[i]).abs() < 1e-8,
2436 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2437 );
2438 }
2439 Ok(())
2440 }
2441
    // Expands one batch check into a `#[test]` per batch kernel variant:
    // scalar always; AVX2/AVX-512 only on nightly-avx x86_64; plus an
    // auto-detect run that exercises runtime kernel selection.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]),
                                     Kernel::Auto);
                }
            }
        };
    }
2463
2464 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2465 skip_if_unsupported!(kernel, test);
2466
2467 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2468 let c = read_candles_from_csv(file)?;
2469
2470 let output = AlmaBatchBuilder::new()
2471 .kernel(kernel)
2472 .period_range(9, 20, 1)
2473 .offset_range(0.5, 1.0, 0.1)
2474 .sigma_range(3.0, 9.0, 1.0)
2475 .apply_candles(&c, "close")?;
2476
2477 let expected_combos = 12 * 6 * 7;
2478 assert_eq!(output.combos.len(), expected_combos);
2479 assert_eq!(output.rows, expected_combos);
2480 assert_eq!(output.cols, c.close.len());
2481
2482 Ok(())
2483 }
2484
2485 #[cfg(debug_assertions)]
2486 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2487 skip_if_unsupported!(kernel, test);
2488
2489 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2490 let c = read_candles_from_csv(file)?;
2491
2492 let test_configs = vec![
2493 (2, 10, 2, 0.0, 1.0, 0.2, 1.0, 10.0, 3.0),
2494 (5, 25, 5, 0.85, 0.85, 0.0, 6.0, 6.0, 0.0),
2495 (10, 10, 0, 0.0, 1.0, 0.1, 5.0, 5.0, 0.0),
2496 (2, 5, 1, 0.5, 0.5, 0.0, 3.0, 8.0, 1.0),
2497 (30, 60, 15, 0.85, 0.85, 0.0, 6.0, 6.0, 0.0),
2498 (9, 15, 3, 0.8, 0.9, 0.1, 6.0, 8.0, 2.0),
2499 (8, 12, 1, 0.7, 0.9, 0.05, 4.0, 8.0, 0.5),
2500 ];
2501
2502 for (cfg_idx, &(p_start, p_end, p_step, o_start, o_end, o_step, s_start, s_end, s_step)) in
2503 test_configs.iter().enumerate()
2504 {
2505 let output = AlmaBatchBuilder::new()
2506 .kernel(kernel)
2507 .period_range(p_start, p_end, p_step)
2508 .offset_range(o_start, o_end, o_step)
2509 .sigma_range(s_start, s_end, s_step)
2510 .apply_candles(&c, "close")?;
2511
2512 for (idx, &val) in output.values.iter().enumerate() {
2513 if val.is_nan() {
2514 continue;
2515 }
2516
2517 let bits = val.to_bits();
2518 let row = idx / output.cols;
2519 let col = idx % output.cols;
2520 let combo = &output.combos[row];
2521
2522 if bits == 0x11111111_11111111 {
2523 panic!(
2524 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2525 at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2526 test,
2527 cfg_idx,
2528 val,
2529 bits,
2530 row,
2531 col,
2532 idx,
2533 combo.period.unwrap_or(9),
2534 combo.offset.unwrap_or(0.85),
2535 combo.sigma.unwrap_or(6.0)
2536 );
2537 }
2538
2539 if bits == 0x22222222_22222222 {
2540 panic!(
2541 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2542 at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2543 test,
2544 cfg_idx,
2545 val,
2546 bits,
2547 row,
2548 col,
2549 idx,
2550 combo.period.unwrap_or(9),
2551 combo.offset.unwrap_or(0.85),
2552 combo.sigma.unwrap_or(6.0)
2553 );
2554 }
2555
2556 if bits == 0x33333333_33333333 {
2557 panic!(
2558 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2559 at row {} col {} (flat index {}) with params: period={}, offset={}, sigma={}",
2560 test,
2561 cfg_idx,
2562 val,
2563 bits,
2564 row,
2565 col,
2566 idx,
2567 combo.period.unwrap_or(9),
2568 combo.offset.unwrap_or(0.85),
2569 combo.sigma.unwrap_or(6.0)
2570 );
2571 }
2572 }
2573 }
2574
2575 Ok(())
2576 }
2577
    /// Release-mode stub: poison fills only exist under `debug_assertions`,
    /// so there is nothing to scan for in release builds.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2582
    // Instantiate the per-kernel batch test matrix for each batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
2586
2587 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2588 #[test]
2589 fn test_alma_simd128_correctness() {
2590 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2591 let period = 5;
2592 let offset = 0.85;
2593 let sigma = 6.0;
2594
2595 let params = AlmaParams {
2596 period: Some(period),
2597 offset: Some(offset),
2598 sigma: Some(sigma),
2599 };
2600 let input = AlmaInput::from_slice(&data, params);
2601 let scalar_output = alma_with_kernel(&input, Kernel::Scalar).unwrap();
2602
2603 let simd128_output = alma_with_kernel(&input, Kernel::Scalar).unwrap();
2604
2605 assert_eq!(scalar_output.values.len(), simd128_output.values.len());
2606 for (i, (scalar_val, simd_val)) in scalar_output
2607 .values
2608 .iter()
2609 .zip(simd128_output.values.iter())
2610 .enumerate()
2611 {
2612 assert!(
2613 (scalar_val - simd_val).abs() < 1e-10,
2614 "SIMD128 mismatch at index {}: scalar={}, simd128={}",
2615 i,
2616 scalar_val,
2617 simd_val
2618 );
2619 }
2620 }
2621}
2622
/// Python binding for ALMA: computes the Arnaud Legoux Moving Average of a
/// 1-D float64 array and returns a new NumPy array of the same length.
///
/// * `data`   - input series (non-contiguous arrays are copied first).
/// * `period` - window length.
/// * `offset` - Gaussian offset parameter.
/// * `sigma`  - Gaussian width parameter.
/// * `kernel` - optional kernel name, validated for single-series use.
///
/// Errors surface as Python `ValueError`; the computation releases the GIL.
#[cfg(feature = "python")]
#[pyfunction(name = "alma")]
#[pyo3(signature = (data, period, offset, sigma, kernel=None))]
pub fn alma_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    offset: f64,
    sigma: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let kern = validate_kernel(kernel, false)?;
    let params = AlmaParams {
        period: Some(period),
        offset: Some(offset),
        sigma: Some(sigma),
    };

    let result_vec: Vec<f64> = if let Ok(slice_in) = data.as_slice() {
        // Fast path: the NumPy array is already contiguous.
        let alma_in = AlmaInput::from_slice(slice_in, params);
        py.allow_threads(|| alma_with_kernel(&alma_in, kern).map(|o| o.values))
            .map_err(|e| PyValueError::new_err(e.to_string()))?
    } else {
        // Non-contiguous input: copy into an owned contiguous buffer first.
        // (Previously bound to a temporary and returned — clippy
        // `let_and_return`; the expression is now the branch value directly.)
        let owned = data.as_array().to_owned();
        let slice_in = owned.as_slice().expect("owned array should be contiguous");
        let alma_in = AlmaInput::from_slice(slice_in, params);
        py.allow_threads(|| alma_with_kernel(&alma_in, kern).map(|o| o.values))
            .map_err(|e| PyValueError::new_err(e.to_string()))?
    };

    Ok(result_vec.into_pyarray(py))
}
2660
/// Python wrapper around the incremental `AlmaStream` state machine,
/// exposed to Python as `AlmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "AlmaStream")]
pub struct AlmaStreamPy {
    // Underlying Rust streaming ALMA state.
    stream: AlmaStream,
}
2666
#[cfg(feature = "python")]
#[pymethods]
impl AlmaStreamPy {
    /// Creates a stream with the given ALMA parameters.
    /// Raises `ValueError` if `AlmaStream::try_new` rejects them.
    #[new]
    fn new(period: usize, offset: f64, sigma: f64) -> PyResult<Self> {
        let params = AlmaParams {
            period: Some(period),
            offset: Some(offset),
            sigma: Some(sigma),
        };
        let stream =
            AlmaStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(AlmaStreamPy { stream })
    }

    /// Pushes one sample into the stream and returns the next ALMA value,
    /// or `None` while the stream is still warming up (presumably until
    /// `period` samples have been seen — see `AlmaStream::update`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
2686
/// Python binding for a batched ALMA parameter sweep.
///
/// Expands the three `(start, end, step)` ranges into a parameter grid and
/// computes one output row per combination directly into a preallocated
/// NumPy buffer. Returns a dict with keys `values` (rows x cols matrix),
/// `periods`, `offsets`, and `sigmas` (one entry per combination).
///
/// Raises `ValueError` on invalid kernel names, grid expansion failure, or
/// `rows * cols` overflow. The computation releases the GIL.
#[cfg(feature = "python")]
#[pyfunction(name = "alma_batch")]
#[pyo3(signature = (data, period_range, offset_range, sigma_range, kernel=None))]
pub fn alma_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    offset_range: (f64, f64, f64),
    sigma_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;

    let sweep = AlmaBatchRange {
        period: period_range,
        offset: offset_range,
        sigma: sigma_range,
    };

    // Expand the grid up front so the output buffer can be sized exactly.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: allocates an uninitialized array of `total` elements; every
    // element is written by `alma_batch_inner_into` below before the array
    // is handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created, is contiguous, and is not aliased.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            // Resolve Auto to the best batch kernel for this CPU, then map
            // the batch kernel to its single-series SIMD counterpart.
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                // NOTE(review): assumes `validate_kernel(_, true)` only admits
                // batch kernels — otherwise this arm is reachable; confirm.
                _ => unreachable!(),
            };
            alma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Per-combination parameter vectors, parallel to the value rows.
    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "offsets",
        combos
            .iter()
            .map(|p| p.offset.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "sigmas",
        combos
            .iter()
            .map(|p| p.sigma.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2767
/// CUDA batch sweep that keeps its result resident on the GPU.
///
/// Runs the ALMA parameter sweep on `device_id` and returns a device-side
/// f32 array handle (`DeviceArrayF32Py`) instead of copying back to host.
/// Raises `ValueError` if CUDA is unavailable, the input is non-contiguous,
/// or any CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, offset_range, sigma_range, device_id=0))]
pub fn alma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    offset_range: (f64, f64, f64),
    sigma_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaAlma;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let slice_in: &[f32] = data_f32.as_slice()?;
    let sweep = AlmaBatchRange {
        period: period_range,
        offset: offset_range,
        sigma: sigma_range,
    };

    // The context Arc is returned alongside the buffer and stored in `_ctx`,
    // presumably to keep the CUDA context alive while Python holds the array.
    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaAlma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = device_id as u32;
        cuda.alma_batch_dev(slice_in, &sweep)
            .map(|inner| (inner, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2808
/// CUDA kernel applying a single ALMA parameter set to many series at once.
///
/// `data_tm_f32` is a 2-D time-major f32 matrix (shape `[rows, cols]`); the
/// result stays on `device_id` as a `DeviceArrayF32Py` handle.
/// Raises `ValueError` if CUDA is unavailable, the input is non-contiguous,
/// or any CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, offset, sigma, device_id=0))]
pub fn alma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    offset: f64,
    sigma: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::moving_averages::CudaAlma;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in: &[f32] = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = AlmaParams {
        period: Some(period),
        offset: Some(offset),
        sigma: Some(sigma),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaAlma::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id = device_id as u32;
        // BUG FIX: the reference operator was mojibake ("¶ms"); restored to
        // `&params` so this compiles.
        // NOTE(review): arguments are passed as (data, cols, rows) — confirm
        // against `alma_multi_series_one_param_time_major_dev`'s signature.
        cuda.alma_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map(|inner| (inner, ctx, dev_id))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev_id),
    })
}
2852
/// Registers all ALMA Python entry points on the given module: `alma`,
/// `alma_batch`, the `AlmaStream` class, and — when built with the `cuda`
/// feature — the device-array class and CUDA functions.
#[cfg(feature = "python")]
pub fn register_alma_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(alma_py, m)?)?;
    m.add_function(wrap_pyfunction!(alma_batch_py, m)?)?;
    m.add_class::<AlmaStreamPy>()?;

    #[cfg(feature = "cuda")]
    {
        m.add_class::<DeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(alma_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(alma_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
2867
/// wasm-bindgen entry point: computes ALMA over `data` with the best
/// available kernel and returns a freshly allocated result vector of the
/// same length. Errors are converted to `JsValue` strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alma_js(data: &[f64], period: usize, offset: f64, sigma: f64) -> Result<Vec<f64>, JsValue> {
    let input = AlmaInput::from_slice(
        data,
        AlmaParams {
            period: Some(period),
            offset: Some(offset),
            sigma: Some(sigma),
        },
    );

    let mut out = vec![0.0; data.len()];
    alma_into_slice(&mut out, &input, detect_best_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(out)
}
2885
/// JS-facing batch sweep configuration; each field is a `(start, end, step)`
/// range, deserialized from a plain JS object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlmaBatchConfig {
    pub period_range: (usize, usize, usize),
    pub offset_range: (f64, f64, f64),
    pub sigma_range: (f64, f64, f64),
}
2893
/// JS-facing batch sweep result: `values` is a flattened rows x cols matrix,
/// with `combos` holding one parameter set per row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlmaBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<AlmaParams>,
    pub rows: usize,
    pub cols: usize,
}
2902
/// JS-facing batch sweep: deserializes an `AlmaBatchConfig` from `config`,
/// runs the sweep with the auto-detected kernel, and serializes the full
/// `AlmaBatchJsOutput` back to a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alma_batch)]
pub fn alma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let cfg: AlmaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = AlmaBatchRange {
        period: cfg.period_range,
        offset: cfg.offset_range,
        sigma: cfg.sigma_range,
    };

    let batch = alma_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    serde_wasm_bindgen::to_value(&AlmaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2928
/// Allocates an uninitialized f64 buffer of `len` elements on the wasm heap
/// and leaks it to the JS caller. The pointer must either be filled and
/// consumed via the `*_into` entry points or released with `alma_free`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alma_alloc(len: usize) -> *mut f64 {
    let mut vec = Vec::<f64>::with_capacity(len);
    let ptr = vec.as_mut_ptr();
    // Leak the Vec so JS owns the memory; `alma_free` reconstructs it later.
    std::mem::forget(vec);
    ptr
}
2937
/// Releases a buffer previously returned by `alma_alloc`.
/// `len` must be the exact value passed to `alma_alloc` for this pointer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alma_free(ptr: *mut f64, len: usize) {
    unsafe {
        // SAFETY: sound only for pointers obtained from `alma_alloc(len)`,
        // where the Vec was created with capacity `len` and then leaked;
        // rebuilding it here lets Drop return the allocation.
        let _ = Vec::from_raw_parts(ptr, len, len);
    }
}
2945
/// Pointer-based ALMA for the JS fast path: reads `len` f64s from `in_ptr`
/// and writes `len` results to `out_ptr`. Exact in-place operation
/// (`in_ptr == out_ptr`) is supported via a scratch buffer; partially
/// overlapping buffers are NOT detected and remain the caller's
/// responsibility.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    offset: f64,
    sigma: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to alma_into"));
    }

    unsafe {
        // SAFETY: the JS caller guarantees `in_ptr` addresses `len` valid
        // f64 elements (e.g. a buffer from `alma_alloc`).
        let data = std::slice::from_raw_parts(in_ptr, len);

        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let params = AlmaParams {
            period: Some(period),
            offset: Some(offset),
            sigma: Some(sigma),
        };
        let input = AlmaInput::from_slice(data, params);

        if in_ptr == out_ptr {
            // In-place: compute into a scratch buffer first so the window
            // never reads values that were already overwritten.
            let mut temp = vec![0.0; len];
            alma_into_slice(&mut temp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // SAFETY: caller guarantees `out_ptr` addresses `len` f64s.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            // SAFETY: caller guarantees `out_ptr` addresses `len` f64s and
            // (per the documented contract) does not partially overlap input.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            alma_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
2989
/// Deprecated wasm helper that caches precomputed ALMA weights between calls.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct AlmaContext {
    // Cache-line-aligned Gaussian weights, one per window position.
    weights: AVec<f64>,
    // 1 / sum(weights); multiplied in at compute time instead of dividing.
    inv_norm: f64,
    // Window length the weights were built for.
    period: usize,
    // NOTE(review): set to 0 at construction and apparently never updated;
    // `update_into` recomputes its own `first` per call — confirm this
    // field is still needed.
    first: usize,
    // Kernel selected once at construction via `detect_best_kernel`.
    kernel: Kernel,
}
3003
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl AlmaContext {
    /// Precomputes the Gaussian weight vector for the given parameters.
    ///
    /// Errors on `period == 0`, `offset` outside `[0, 1]`, or
    /// non-positive/NaN `sigma`.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize, offset: f64, sigma: f64) -> Result<AlmaContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        // The range test is already false for NaN and infinities, so the
        // previous explicit is_nan()/is_infinite() checks were redundant.
        if !(0.0..=1.0).contains(&offset) {
            return Err(JsValue::from_str(&format!("Invalid offset: {}", offset)));
        }
        // BUG FIX: `sigma <= 0.0` is false for NaN, silently accepting a NaN
        // sigma; reject NaN explicitly.
        if sigma <= 0.0 || sigma.is_nan() {
            return Err(JsValue::from_str(&format!("Invalid sigma: {}", sigma)));
        }

        // Standard ALMA weights: w_i = exp(-(i - m)^2 / (2 s^2)).
        let m = offset * (period - 1) as f64;
        let s = period as f64 / sigma;
        let s2 = 2.0 * s * s;

        let mut weights: AVec<f64> = AVec::with_capacity(CACHELINE_ALIGN, period);
        weights.resize(period, 0.0);
        let mut norm = 0.0;

        for i in 0..period {
            let w = (-(i as f64 - m).powi(2) / s2).exp();
            weights[i] = w;
            norm += w;
        }

        // Normalization is applied multiplicatively at compute time.
        let inv_norm = 1.0 / norm;

        Ok(AlmaContext {
            weights,
            inv_norm,
            period,
            first: 0,
            kernel: detect_best_kernel(),
        })
    }

    /// Computes ALMA from `in_ptr` into `out_ptr` (both `len` elements),
    /// then stamps NaN over the warm-up prefix. Exact in-place operation
    /// (`in_ptr == out_ptr`) is supported via a scratch buffer; partially
    /// overlapping buffers are the caller's responsibility.
    ///
    /// The JS caller must guarantee both pointers address `len` valid f64s.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }

        unsafe {
            // SAFETY: caller guarantees both pointers are valid for `len` f64s.
            let data = std::slice::from_raw_parts(in_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, len);

            // Warm-up is measured from the first non-NaN sample.
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

            if in_ptr == out_ptr {
                let mut temp = vec![0.0; len];
                alma_compute_into(
                    data,
                    self.weights.as_slice(),
                    self.period,
                    first,
                    self.inv_norm,
                    self.kernel,
                    &mut temp,
                );

                out.copy_from_slice(&temp);
            } else {
                alma_compute_into(
                    data,
                    self.weights.as_slice(),
                    self.period,
                    first,
                    self.inv_norm,
                    self.kernel,
                    out,
                );
            }

            // BUG FIX: a long NaN prefix can make `first + period - 1`
            // exceed `len`, which previously indexed out of bounds; clamp
            // the NaN-stamping range to the buffer length.
            for i in 0..(first + self.period - 1).min(len) {
                out[i] = f64::NAN;
            }
        }

        Ok(())
    }

    /// Number of leading outputs that are warm-up NaNs (counted from the
    /// first non-NaN input sample).
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1
    }
}
3102
/// Pointer-based batch sweep for JS: expands the three `(start, end, step)`
/// ranges, computes one output row per parameter combination into `out_ptr`,
/// and returns the number of rows. The caller must have allocated
/// `rows * len` f64s at `out_ptr` (rows is derivable from the same ranges on
/// the JS side).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    offset_start: f64,
    offset_end: f64,
    offset_step: f64,
    sigma_start: f64,
    sigma_end: f64,
    sigma_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to alma_batch_into"));
    }

    unsafe {
        // SAFETY: the JS caller guarantees `in_ptr` addresses `len` valid f64s.
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = AlmaBatchRange {
            period: (period_start, period_end, period_step),
            offset: (offset_start, offset_end, offset_step),
            sigma: (sigma_start, sigma_end, sigma_step),
        };

        // Expand the grid to size the output: rows combos x cols samples.
        let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
        let rows = combos.len();
        let cols = len;
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

        // SAFETY: requires the caller to have allocated at least `total`
        // f64s at `out_ptr` — TODO confirm the JS glue sizes buffers this way.
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        alma_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}