Skip to main content

// vector_ta/indicators/moving_averages/jma.rs

1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaJma;
5
6#[cfg(all(feature = "python", feature = "cuda"))]
7use crate::cuda::moving_averages::jma_wrapper::DeviceArrayF32Jma;
8
9use crate::utilities::data_loader::{source_type, Candles};
10use crate::utilities::enums::Kernel;
11use crate::utilities::helpers::{
12    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
13    make_uninit_matrix,
14};
15#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
16use core::arch::x86_64::*;
17#[cfg(not(target_arch = "wasm32"))]
18use rayon::prelude::*;
19use std::convert::AsRef;
20use std::error::Error;
21use std::mem::MaybeUninit;
22use thiserror::Error;
23
24impl<'a> AsRef<[f64]> for JmaInput<'a> {
25    #[inline(always)]
26    fn as_ref(&self) -> &[f64] {
27        match &self.data {
28            JmaData::Slice(slice) => slice,
29            JmaData::Candles { candles, source } => source_type(candles, source),
30        }
31    }
32}
33
34#[derive(Debug, Clone)]
35pub enum JmaData<'a> {
36    Candles {
37        candles: &'a Candles,
38        source: &'a str,
39    },
40    Slice(&'a [f64]),
41}
42
/// Result of a single JMA run.
#[derive(Clone, Debug)]
pub struct JmaOutput {
    /// One output value per input element.
    pub values: Vec<f64>,
}
47
/// Tunable JMA parameters.
///
/// Each field is optional; the accessors and constructors in this file
/// substitute `7`, `50.0` and `2` respectively when a field is `None`.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct JmaParams {
    /// Smoothing period.
    pub period: Option<usize>,
    /// Phase; values outside `[-100, 100]` are clamped by the kernels.
    pub phase: Option<f64>,
    /// Integer exponent applied to the smoothing factor.
    pub power: Option<u32>,
}

impl Default for JmaParams {
    /// Returns the fully-populated defaults: period 7, phase 50.0, power 2.
    fn default() -> Self {
        let (period, phase, power) = (7_usize, 50.0_f64, 2_u32);
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        }
    }
}
68
69#[derive(Debug, Clone)]
70pub struct JmaInput<'a> {
71    pub data: JmaData<'a>,
72    pub params: JmaParams,
73}
74
75impl<'a> JmaInput<'a> {
76    #[inline]
77    pub fn from_candles(c: &'a Candles, s: &'a str, p: JmaParams) -> Self {
78        Self {
79            data: JmaData::Candles {
80                candles: c,
81                source: s,
82            },
83            params: p,
84        }
85    }
86    #[inline]
87    pub fn from_slice(sl: &'a [f64], p: JmaParams) -> Self {
88        Self {
89            data: JmaData::Slice(sl),
90            params: p,
91        }
92    }
93    #[inline]
94    pub fn with_default_candles(c: &'a Candles) -> Self {
95        Self::from_candles(c, "close", JmaParams::default())
96    }
97    #[inline]
98    pub fn get_period(&self) -> usize {
99        self.params.period.unwrap_or(7)
100    }
101    #[inline]
102    pub fn get_phase(&self) -> f64 {
103        self.params.phase.unwrap_or(50.0)
104    }
105    #[inline]
106    pub fn get_power(&self) -> u32 {
107        self.params.power.unwrap_or(2)
108    }
109}
110
111#[derive(Copy, Clone, Debug)]
112pub struct JmaBuilder {
113    period: Option<usize>,
114    phase: Option<f64>,
115    power: Option<u32>,
116    kernel: Kernel,
117}
118
119impl Default for JmaBuilder {
120    fn default() -> Self {
121        Self {
122            period: None,
123            phase: None,
124            power: None,
125            kernel: Kernel::Auto,
126        }
127    }
128}
129
130impl JmaBuilder {
131    #[inline(always)]
132    pub fn new() -> Self {
133        Self::default()
134    }
135    #[inline(always)]
136    pub fn period(mut self, n: usize) -> Self {
137        self.period = Some(n);
138        self
139    }
140    #[inline(always)]
141    pub fn phase(mut self, x: f64) -> Self {
142        self.phase = Some(x);
143        self
144    }
145    #[inline(always)]
146    pub fn power(mut self, p: u32) -> Self {
147        self.power = Some(p);
148        self
149    }
150    #[inline(always)]
151    pub fn kernel(mut self, k: Kernel) -> Self {
152        self.kernel = k;
153        self
154    }
155
156    #[inline(always)]
157    pub fn apply(self, c: &Candles) -> Result<JmaOutput, JmaError> {
158        let p = JmaParams {
159            period: self.period,
160            phase: self.phase,
161            power: self.power,
162        };
163        let i = JmaInput::from_candles(c, "close", p);
164        jma_with_kernel(&i, self.kernel)
165    }
166
167    #[inline(always)]
168    pub fn apply_slice(self, d: &[f64]) -> Result<JmaOutput, JmaError> {
169        let p = JmaParams {
170            period: self.period,
171            phase: self.phase,
172            power: self.power,
173        };
174        let i = JmaInput::from_slice(d, p);
175        jma_with_kernel(&i, self.kernel)
176    }
177
178    #[inline(always)]
179    pub fn into_stream(self) -> Result<JmaStream, JmaError> {
180        let p = JmaParams {
181            period: self.period,
182            phase: self.phase,
183            power: self.power,
184        };
185        JmaStream::try_new(p)
186    }
187}
188
189#[derive(Debug, Error)]
190pub enum JmaError {
191    #[error("jma: All values are NaN.")]
192    AllValuesNaN,
193    #[error("jma: Empty input data.")]
194    EmptyInputData,
195    #[error("jma: Invalid period: period = {period}, data length = {data_len}")]
196    InvalidPeriod { period: usize, data_len: usize },
197    #[error("jma: Not enough valid data: needed = {needed}, valid = {valid}")]
198    NotEnoughValidData { needed: usize, valid: usize },
199    #[error("jma: Invalid phase: {phase}")]
200    InvalidPhase { phase: f64 },
201    #[error("jma: Output length mismatch: expected = {expected}, got = {got}")]
202    OutputLengthMismatch { expected: usize, got: usize },
203    #[error("jma: Invalid range (usize): start={start}, end={end}, step={step}")]
204    InvalidRangeUsize {
205        start: usize,
206        end: usize,
207        step: usize,
208    },
209    #[error("jma: Invalid range (f64): start={start}, end={end}, step={step}")]
210    InvalidRangeF64 { start: f64, end: f64, step: f64 },
211    #[error("jma: Invalid kernel for batch: {0:?}")]
212    InvalidKernelForBatch(Kernel),
213    #[error("jma: Invalid input: {0}")]
214    InvalidInput(String),
215}
216
217#[inline]
218pub fn jma(input: &JmaInput) -> Result<JmaOutput, JmaError> {
219    jma_with_kernel(input, Kernel::Auto)
220}
221
222pub fn jma_with_kernel(input: &JmaInput, kernel: Kernel) -> Result<JmaOutput, JmaError> {
223    let data: &[f64] = match &input.data {
224        JmaData::Candles { candles, source } => source_type(candles, source),
225        JmaData::Slice(sl) => sl,
226    };
227    let len = data.len();
228    if len == 0 {
229        return Err(JmaError::EmptyInputData);
230    }
231    let first = data
232        .iter()
233        .position(|x| !x.is_nan())
234        .ok_or(JmaError::AllValuesNaN)?;
235    let period = input.get_period();
236    let phase = input.get_phase();
237    let power = input.get_power();
238
239    if period == 0 || period > len {
240        return Err(JmaError::InvalidPeriod {
241            period,
242            data_len: len,
243        });
244    }
245    if (len - first) < period {
246        return Err(JmaError::NotEnoughValidData {
247            needed: period,
248            valid: len - first,
249        });
250    }
251    if phase.is_nan() || phase.is_infinite() {
252        return Err(JmaError::InvalidPhase { phase });
253    }
254
255    let chosen = match kernel {
256        Kernel::Auto => Kernel::Scalar,
257        other => other,
258    };
259
260    let mut out = alloc_with_nan_prefix(len, first);
261    unsafe {
262        match chosen {
263            Kernel::Scalar | Kernel::ScalarBatch => {
264                jma_scalar(data, period, phase, power, first, &mut out)
265            }
266            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
267            Kernel::Avx2 | Kernel::Avx2Batch => {
268                jma_avx2(data, period, phase, power, first, &mut out)
269            }
270            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
271            Kernel::Avx512 | Kernel::Avx512Batch => {
272                jma_avx512(data, period, phase, power, first, &mut out)
273            }
274            _ => unreachable!(),
275        }
276    }
277    Ok(JmaOutput { values: out })
278}
279
280pub fn jma_with_kernel_into(
281    input: &JmaInput,
282    kernel: Kernel,
283    out: &mut [f64],
284) -> Result<(), JmaError> {
285    let data: &[f64] = match &input.data {
286        JmaData::Candles { candles, source } => source_type(candles, source),
287        JmaData::Slice(sl) => sl,
288    };
289    let len = data.len();
290
291    if out.len() != len {
292        return Err(JmaError::OutputLengthMismatch {
293            expected: len,
294            got: out.len(),
295        });
296    }
297    if len == 0 {
298        return Err(JmaError::EmptyInputData);
299    }
300
301    let first = data
302        .iter()
303        .position(|x| !x.is_nan())
304        .ok_or(JmaError::AllValuesNaN)?;
305    let period = input.get_period();
306    let phase = input.get_phase();
307    let power = input.get_power();
308
309    if period == 0 || period > len {
310        return Err(JmaError::InvalidPeriod {
311            period,
312            data_len: len,
313        });
314    }
315    if (len - first) < period {
316        return Err(JmaError::NotEnoughValidData {
317            needed: period,
318            valid: len - first,
319        });
320    }
321    if phase.is_nan() || phase.is_infinite() {
322        return Err(JmaError::InvalidPhase { phase });
323    }
324
325    let chosen = match kernel {
326        Kernel::Auto => Kernel::Scalar,
327        other => other,
328    };
329
330    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
331    out[..first].fill(qnan);
332
333    unsafe {
334        match chosen {
335            Kernel::Scalar | Kernel::ScalarBatch => {
336                jma_scalar(data, period, phase, power, first, out)
337            }
338            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
339            Kernel::Avx2 | Kernel::Avx2Batch => jma_avx2(data, period, phase, power, first, out),
340            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
341            Kernel::Avx512 | Kernel::Avx512Batch => {
342                jma_avx512(data, period, phase, power, first, out)
343            }
344            _ => unreachable!(),
345        }
346    }
347    Ok(())
348}
349
350#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
351pub fn jma_into(input: &JmaInput, out: &mut [f64]) -> Result<(), JmaError> {
352    jma_with_kernel_into(input, Kernel::Auto, out)
353}
354
/// Scalar JMA kernel.
///
/// Seeds the filter with `data[first_valid]`, writes that value to
/// `output[first_valid]`, then fills `output[first_valid + 1..]` by running
/// the three-stage recurrence over the remaining inputs. Elements of
/// `output` before `first_valid` are left untouched.
///
/// * `period` – sets `beta = 0.45 * (period - 1) / (0.45 * (period - 1) + 2)`.
/// * `phase` – mapped to a ratio `phase / 100 + 1.5`, clamped to `[0.5, 2.5]`.
/// * `power` – exponent applied to `beta` to obtain `alpha`. Values above
///   `i32::MAX` are saturated rather than wrapped to a negative exponent.
///
/// # Panics
/// Panics if `data.len() != output.len()` or `first_valid >= data.len()`.
#[inline]
pub fn jma_scalar(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    output: &mut [f64],
) {
    assert_eq!(data.len(), output.len());
    assert!(first_valid < data.len());

    let pr = if phase < -100.0 {
        0.5
    } else if phase > 100.0 {
        2.5
    } else {
        phase / 100.0 + 1.5
    };

    let beta = {
        let num = 0.45 * (period as f64 - 1.0);
        num / (num + 2.0)
    };
    let one_minus_beta = 1.0 - beta;

    // `power as i32` would wrap to a negative exponent for huge powers;
    // saturating keeps the result on the correct side.
    let exponent = power.min(i32::MAX as u32) as i32;
    let alpha = beta.powi(exponent);
    let one_minus_alpha = 1.0 - alpha;
    let alpha_sq = alpha * alpha;
    let oma_sq = one_minus_alpha * one_minus_alpha;

    let seed = data[first_valid];
    let mut e0 = seed;
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = seed;

    output[first_valid] = seed;

    // The recurrence is strictly serial, so a plain zipped loop does the same
    // work as a hand-unrolled pointer walk without any out-of-bounds pointer
    // arithmetic.
    let inputs = &data[first_valid + 1..];
    let outputs = &mut output[first_valid + 1..];
    for (&x, slot) in inputs.iter().zip(outputs.iter_mut()) {
        e0 = one_minus_alpha.mul_add(x, alpha * e0);
        e1 = (x - e0).mul_add(one_minus_beta, beta * e1);
        let d = e0 + pr * e1 - j_prev;
        e2 = d.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *slot = j_prev;
    }
}
455
/// JMA kernel compiled with AVX2 and FMA enabled.
///
/// Runs the same serial recurrence as the scalar kernel: seeds with
/// `data[first_valid]`, writes it to `output[first_valid]`, then fills
/// `output[first_valid + 1..]`. Elements before `first_valid` are untouched.
/// A `power` above `i32::MAX` is saturated rather than wrapped to a negative
/// exponent.
///
/// # Safety
/// The executing CPU must support the `avx2` and `fma` target features.
///
/// # Panics
/// Panics if `data.len() != output.len()` or `first_valid >= data.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
#[inline]
pub unsafe fn jma_avx2(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    output: &mut [f64],
) {
    assert_eq!(data.len(), output.len());
    assert!(first_valid < data.len());

    let pr = if phase < -100.0 {
        0.5
    } else if phase > 100.0 {
        2.5
    } else {
        phase / 100.0 + 1.5
    };

    let beta = {
        let num = 0.45 * (period as f64 - 1.0);
        num / (num + 2.0)
    };
    let one_minus_beta = 1.0 - beta;

    // `power as i32` would wrap to a negative exponent for huge powers.
    let exponent = power.min(i32::MAX as u32) as i32;
    let alpha = beta.powi(exponent);
    let one_minus_alpha = 1.0 - alpha;
    let alpha_sq = alpha * alpha;
    let oma_sq = one_minus_alpha * one_minus_alpha;

    let seed = data[first_valid];
    let mut e0 = seed;
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = seed;

    output[first_valid] = seed;

    // Safe zipped iteration: no raw-pointer offsets that could step past the
    // end of the slice.
    let inputs = &data[first_valid + 1..];
    let outputs = &mut output[first_valid + 1..];
    for (&x, slot) in inputs.iter().zip(outputs.iter_mut()) {
        e0 = one_minus_alpha.mul_add(x, alpha * e0);
        e1 = (x - e0).mul_add(one_minus_beta, beta * e1);
        let d = e0 + pr * e1 - j_prev;
        e2 = d.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *slot = j_prev;
    }
}
558
/// Derives the constants of the JMA recurrence from its parameters.
///
/// Returns, in order:
/// `(phase_ratio, beta, alpha, alpha², (1 - alpha)², 1 - alpha, 1 - beta)`
/// where
/// * `phase_ratio` is `phase / 100 + 1.5`, clamped to `[0.5, 2.5]`,
/// * `beta = n / (n + 2)` with `n = 0.45 * (period - 1)`,
/// * `alpha = beta ^ power`.
///
/// A `power` above `i32::MAX` is saturated rather than cast directly, which
/// would wrap to a negative exponent.
#[inline(always)]
fn jma_consts(period: usize, phase: f64, power: u32) -> (f64, f64, f64, f64, f64, f64, f64) {
    let pr = if phase < -100.0 {
        0.5
    } else if phase > 100.0 {
        2.5
    } else {
        phase / 100.0 + 1.5
    };

    let beta = {
        let num = 0.45 * (period as f64 - 1.0);
        num / (num + 2.0)
    };

    let exponent = power.min(i32::MAX as u32) as i32;
    let alpha = beta.powi(exponent);
    let one_minus_alpha = 1.0 - alpha;
    let one_minus_beta = 1.0 - beta;

    (
        pr,
        beta,
        alpha,
        alpha * alpha,
        one_minus_alpha * one_minus_alpha,
        one_minus_alpha,
        one_minus_beta,
    )
}
584
/// JMA kernel compiled with the AVX-512 (F/DQ/VL) and FMA features enabled.
///
/// Runs the same serial recurrence as the other kernels: seeds with
/// `data[first_valid]`, writes it to `out[first_valid]`, then fills
/// `out[first_valid + 1..]`. Elements before `first_valid` are untouched.
/// The recurrence constants come from `jma_consts`.
///
/// # Safety
/// The executing CPU must support the `avx512f`, `avx512dq`, `avx512vl` and
/// `fma` target features.
///
/// # Panics
/// Panics if `first_valid >= data.len()` or `first_valid >= out.len()`.
/// Debug builds additionally assert `data.len() == out.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,avx512vl,fma")]
#[inline]
pub unsafe fn jma_avx512(
    data: &[f64],
    period: usize,
    phase: f64,
    power: u32,
    first_valid: usize,
    out: &mut [f64],
) {
    debug_assert!(data.len() == out.len() && first_valid < data.len());

    let (pr, beta, alpha, alpha_sq, oma_sq, one_minus_alpha, one_minus_beta) =
        jma_consts(period, phase, power);

    let seed = data[first_valid];
    let mut e0 = seed;
    let mut e1 = 0.0;
    let mut e2 = 0.0;
    let mut j_prev = seed;

    out[first_valid] = seed;

    // Each step depends on the previous one, so there is nothing to
    // vectorise across elements; a bounds-checked zip replaces the unchecked
    // indexing that was only guarded by a debug assertion.
    let inputs = &data[first_valid + 1..];
    let outputs = &mut out[first_valid + 1..];
    for (&price, slot) in inputs.iter().zip(outputs.iter_mut()) {
        e0 = one_minus_alpha.mul_add(price, alpha * e0);
        e1 = (price - e0).mul_add(one_minus_beta, beta * e1);
        let diff = e0 + pr * e1 - j_prev;
        e2 = diff.mul_add(oma_sq, alpha_sq * e2);
        j_prev += e2;
        *slot = j_prev;
    }
}
656
/// Incremental JMA state: feed one value at a time through `update`.
#[derive(Clone, Debug)]
pub struct JmaStream {
    // Parameters as resolved at construction time.
    period: usize,
    phase: f64,
    power: u32,

    // Constants derived once from the parameters.
    alpha: f64,
    beta: f64,
    phase_ratio: f64,
    one_minus_alpha: f64,
    one_minus_beta: f64,
    alpha_sq: f64,
    oma_sq: f64,

    // Running filter state; `initialized` flips on the first non-NaN value.
    initialized: bool,
    e0: f64,
    e1: f64,
    e2: f64,
    jma_prev: f64,
}
677
678impl JmaStream {
679    #[inline(always)]
680    pub fn try_new(params: JmaParams) -> Result<Self, JmaError> {
681        let period = params.period.unwrap_or(7);
682        if period == 0 {
683            return Err(JmaError::InvalidPeriod {
684                period,
685                data_len: 0,
686            });
687        }
688        let phase = params.phase.unwrap_or(50.0);
689        if phase.is_nan() || phase.is_infinite() {
690            return Err(JmaError::InvalidPhase { phase });
691        }
692        let power = params.power.unwrap_or(2);
693
694        let clamped = phase.max(-100.0).min(100.0);
695        let phase_ratio = clamped / 100.0 + 1.5;
696
697        let numerator = 0.45 * (period as f64 - 1.0);
698        let denominator = numerator + 2.0;
699        let beta = if denominator.abs() < f64::EPSILON {
700            0.0
701        } else {
702            numerator / denominator
703        };
704
705        let alpha = pow_u32(beta, power);
706
707        let one_minus_alpha = 1.0 - alpha;
708        let one_minus_beta = 1.0 - beta;
709        let alpha_sq = alpha * alpha;
710        let oma_sq = one_minus_alpha * one_minus_alpha;
711
712        Ok(Self {
713            period,
714            phase,
715            power,
716            alpha,
717            beta,
718            phase_ratio,
719            one_minus_alpha,
720            one_minus_beta,
721            alpha_sq,
722            oma_sq,
723            initialized: false,
724            e0: f64::NAN,
725            e1: 0.0,
726            e2: 0.0,
727            jma_prev: f64::NAN,
728        })
729    }
730
731    #[inline(always)]
732    pub fn update(&mut self, value: f64) -> Option<f64> {
733        if !self.initialized {
734            if value.is_nan() {
735                return None;
736            }
737            self.initialized = true;
738            self.e0 = value;
739            self.e1 = 0.0;
740            self.e2 = 0.0;
741            self.jma_prev = value;
742            return Some(value);
743        }
744
745        self.e0 = self.one_minus_alpha * value + self.alpha * self.e0;
746        self.e1 = (value - self.e0) * self.one_minus_beta + self.beta * self.e1;
747        let diff = self.e0 + self.phase_ratio * self.e1 - self.jma_prev;
748        self.e2 = diff * self.oma_sq + self.alpha_sq * self.e2;
749        self.jma_prev += self.e2;
750        Some(self.jma_prev)
751    }
752
753    #[inline(always)]
754    pub fn update_fma(&mut self, value: f64) -> Option<f64> {
755        if !self.initialized {
756            if value.is_nan() {
757                return None;
758            }
759            self.initialized = true;
760            self.e0 = value;
761            self.e1 = 0.0;
762            self.e2 = 0.0;
763            self.jma_prev = value;
764            return Some(value);
765        }
766
767        self.e0 = self.one_minus_alpha.mul_add(value, self.alpha * self.e0);
768        self.e1 = (value - self.e0).mul_add(self.one_minus_beta, self.beta * self.e1);
769        let diff = self.e0 + self.phase_ratio * self.e1 - self.jma_prev;
770        self.e2 = diff.mul_add(self.oma_sq, self.alpha_sq * self.e2);
771        self.jma_prev += self.e2;
772        Some(self.jma_prev)
773    }
774}
775
/// Raises `base` to the non-negative integer power `exp` by binary
/// exponentiation (square-and-multiply, least-significant bit first).
///
/// `exp == 0` yields `1.0` for every `base`.
#[inline(always)]
fn pow_u32(mut base: f64, mut exp: u32) -> f64 {
    let mut result = 1.0;
    loop {
        if exp == 0 {
            return result;
        }
        if exp & 1 == 1 {
            result *= base;
        }
        base *= base;
        exp >>= 1;
    }
}
788
/// Parameter sweep for batch runs; every axis is `(start, end, step)`.
///
/// A zero step (or `start == end`) makes the axis a single value.
#[derive(Debug, Clone)]
pub struct JmaBatchRange {
    /// Period axis.
    pub period: (usize, usize, usize),
    /// Phase axis.
    pub phase: (f64, f64, f64),
    /// Power axis.
    pub power: (u32, u32, u32),
}

impl Default for JmaBatchRange {
    /// Periods 7 through 256 in steps of 1, with phase fixed at 50.0 and
    /// power fixed at 2.
    fn default() -> Self {
        let period = (7, 256, 1);
        let phase = (50.0, 50.0, 0.0);
        let power = (2, 2, 0);
        JmaBatchRange {
            period,
            phase,
            power,
        }
    }
}
805
806#[derive(Clone, Debug, Default)]
807pub struct JmaBatchBuilder {
808    range: JmaBatchRange,
809    kernel: Kernel,
810}
811
812impl JmaBatchBuilder {
813    pub fn new() -> Self {
814        Self::default()
815    }
816    pub fn kernel(mut self, k: Kernel) -> Self {
817        self.kernel = k;
818        self
819    }
820    #[inline]
821    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
822        self.range.period = (start, end, step);
823        self
824    }
825    #[inline]
826    pub fn period_static(mut self, p: usize) -> Self {
827        self.range.period = (p, p, 0);
828        self
829    }
830    #[inline]
831    pub fn phase_range(mut self, start: f64, end: f64, step: f64) -> Self {
832        self.range.phase = (start, end, step);
833        self
834    }
835    #[inline]
836    pub fn phase_static(mut self, x: f64) -> Self {
837        self.range.phase = (x, x, 0.0);
838        self
839    }
840    #[inline]
841    pub fn power_range(mut self, start: u32, end: u32, step: u32) -> Self {
842        self.range.power = (start, end, step);
843        self
844    }
845    #[inline]
846    pub fn power_static(mut self, p: u32) -> Self {
847        self.range.power = (p, p, 0);
848        self
849    }
850    pub fn apply_slice(self, data: &[f64]) -> Result<JmaBatchOutput, JmaError> {
851        jma_batch_with_kernel(data, &self.range, self.kernel)
852    }
853    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<JmaBatchOutput, JmaError> {
854        JmaBatchBuilder::new().kernel(k).apply_slice(data)
855    }
856    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<JmaBatchOutput, JmaError> {
857        let slice = source_type(c, src);
858        self.apply_slice(slice)
859    }
860    pub fn with_default_candles(c: &Candles) -> Result<JmaBatchOutput, JmaError> {
861        JmaBatchBuilder::new()
862            .kernel(Kernel::Auto)
863            .apply_candles(c, "close")
864    }
865}
866
867pub fn jma_batch_with_kernel(
868    data: &[f64],
869    sweep: &JmaBatchRange,
870    k: Kernel,
871) -> Result<JmaBatchOutput, JmaError> {
872    let kernel = match k {
873        Kernel::Auto => detect_best_batch_kernel(),
874        other if other.is_batch() => other,
875        other => return Err(JmaError::InvalidKernelForBatch(other)),
876    };
877    let simd = match kernel {
878        Kernel::Avx512Batch => Kernel::Avx512,
879        Kernel::Avx2Batch => Kernel::Avx2,
880        Kernel::ScalarBatch => Kernel::Scalar,
881        _ => unreachable!(),
882    };
883    jma_batch_par_slice(data, sweep, simd)
884}
885
886#[derive(Clone, Debug)]
887pub struct JmaBatchOutput {
888    pub values: Vec<f64>,
889    pub combos: Vec<JmaParams>,
890    pub rows: usize,
891    pub cols: usize,
892}
893
894impl JmaBatchOutput {
895    pub fn row_for_params(&self, p: &JmaParams) -> Option<usize> {
896        self.combos.iter().position(|c| {
897            c.period.unwrap_or(7) == p.period.unwrap_or(7)
898                && (c.phase.unwrap_or(50.0) - p.phase.unwrap_or(50.0)).abs() < 1e-12
899                && c.power.unwrap_or(2) == p.power.unwrap_or(2)
900        })
901    }
902    pub fn values_for(&self, p: &JmaParams) -> Option<&[f64]> {
903        self.row_for_params(p).map(|row| {
904            let start = row * self.cols;
905            &self.values[start..start + self.cols]
906        })
907    }
908}
909
910#[inline(always)]
911fn expand_grid(r: &JmaBatchRange) -> Vec<JmaParams> {
912    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
913        if step == 0 || start == end {
914            return vec![start];
915        }
916        let mut v = Vec::new();
917        if start < end {
918            let mut x = start;
919            while x <= end {
920                v.push(x);
921                match x.checked_add(step) {
922                    Some(n) => x = n,
923                    None => break,
924                }
925            }
926        } else {
927            let mut x = start;
928            while x >= end {
929                v.push(x);
930                if x < end + step {
931                    break;
932                }
933                x -= step;
934            }
935        }
936        v
937    }
938    fn axis_f64((start, end, step): (f64, f64, f64)) -> Vec<f64> {
939        const EPS: f64 = 1e-12;
940        if step.abs() < EPS || (start - end).abs() < EPS {
941            return vec![start];
942        }
943        let s = step.abs();
944        let mut v = Vec::new();
945        if start < end {
946            let mut x = start;
947            while x <= end + EPS {
948                v.push(x);
949                x += s;
950            }
951        } else {
952            let mut x = start;
953            while x >= end - EPS {
954                v.push(x);
955                x -= s;
956            }
957        }
958        v
959    }
960    fn axis_u32((start, end, step): (u32, u32, u32)) -> Vec<u32> {
961        if step == 0 || start == end {
962            return vec![start];
963        }
964        let mut v = Vec::new();
965        if start < end {
966            let mut x = start;
967            while x <= end {
968                v.push(x);
969                x = x.saturating_add(step);
970                if x == *v.last().unwrap() {
971                    break;
972                }
973            }
974        } else {
975            let mut x = start;
976            while x >= end {
977                v.push(x);
978                if x < end + step {
979                    break;
980                }
981                x = x.saturating_sub(step);
982                if x == *v.last().unwrap() {
983                    break;
984                }
985            }
986        }
987        v
988    }
989    let periods = axis_usize(r.period);
990    let phases = axis_f64(r.phase);
991    let powers = axis_u32(r.power);
992    let mut out = Vec::with_capacity(periods.len() * phases.len() * powers.len());
993    for &p in &periods {
994        for &ph in &phases {
995            for &po in &powers {
996                out.push(JmaParams {
997                    period: Some(p),
998                    phase: Some(ph),
999                    power: Some(po),
1000                });
1001            }
1002        }
1003    }
1004    out
1005}
1006
1007#[inline(always)]
1008pub fn jma_batch_slice(
1009    data: &[f64],
1010    sweep: &JmaBatchRange,
1011    kern: Kernel,
1012) -> Result<JmaBatchOutput, JmaError> {
1013    jma_batch_inner(data, sweep, kern, false)
1014}
1015
1016#[inline(always)]
1017pub fn jma_batch_par_slice(
1018    data: &[f64],
1019    sweep: &JmaBatchRange,
1020    kern: Kernel,
1021) -> Result<JmaBatchOutput, JmaError> {
1022    jma_batch_inner(data, sweep, kern, true)
1023}
1024
1025#[inline(always)]
1026fn jma_batch_inner(
1027    data: &[f64],
1028    sweep: &JmaBatchRange,
1029    kern: Kernel,
1030    parallel: bool,
1031) -> Result<JmaBatchOutput, JmaError> {
1032    let combos = expand_grid(sweep);
1033    if combos.is_empty() {
1034        return Err(JmaError::InvalidInput("no parameter combinations".into()));
1035    }
1036    let cols = data.len();
1037    if cols == 0 {
1038        return Err(JmaError::EmptyInputData);
1039    }
1040    let first = data
1041        .iter()
1042        .position(|x| !x.is_nan())
1043        .ok_or(JmaError::AllValuesNaN)?;
1044    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1045    if data.len() - first < max_p {
1046        return Err(JmaError::NotEnoughValidData {
1047            needed: max_p,
1048            valid: data.len() - first,
1049        });
1050    }
1051    let rows = combos.len();
1052
1053    let warm: Vec<usize> = vec![first; rows];
1054
1055    let _cap = rows
1056        .checked_mul(cols)
1057        .ok_or_else(|| JmaError::InvalidInput("rows * cols overflow".into()))?;
1058    let mut raw = make_uninit_matrix(rows, cols);
1059    init_matrix_prefixes(&mut raw, cols, &warm);
1060
1061    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
1062        let prm = &combos[row];
1063        let period = prm.period.unwrap();
1064        let phase = prm.phase.unwrap();
1065        let power = prm.power.unwrap();
1066
1067        let out_row =
1068            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
1069
1070        match kern {
1071            Kernel::Scalar => jma_row_scalar(data, first, period, phase, power, out_row),
1072            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1073            Kernel::Avx2 => jma_row_avx2(data, first, period, phase, power, out_row),
1074            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1075            Kernel::Avx512 => jma_row_avx512(data, first, period, phase, power, out_row),
1076            _ => unreachable!(),
1077        }
1078    };
1079
1080    if parallel {
1081        #[cfg(not(target_arch = "wasm32"))]
1082        {
1083            raw.par_chunks_mut(cols)
1084                .enumerate()
1085                .for_each(|(row, slice)| do_row(row, slice));
1086        }
1087
1088        #[cfg(target_arch = "wasm32")]
1089        {
1090            for (row, slice) in raw.chunks_mut(cols).enumerate() {
1091                do_row(row, slice);
1092            }
1093        }
1094    } else {
1095        for (row, slice) in raw.chunks_mut(cols).enumerate() {
1096            do_row(row, slice);
1097        }
1098    }
1099
1100    let mut guard = core::mem::ManuallyDrop::new(raw);
1101    let values = unsafe {
1102        Vec::from_raw_parts(
1103            guard.as_mut_ptr() as *mut f64,
1104            guard.len(),
1105            guard.capacity(),
1106        )
1107    };
1108
1109    Ok(JmaBatchOutput {
1110        values,
1111        combos,
1112        rows,
1113        cols,
1114    })
1115}
1116
1117#[inline(always)]
1118fn jma_batch_inner_into(
1119    data: &[f64],
1120    sweep: &JmaBatchRange,
1121    kern: Kernel,
1122    parallel: bool,
1123    out: &mut [f64],
1124) -> Result<(Vec<JmaParams>, usize, usize), JmaError> {
1125    let combos = expand_grid(sweep);
1126    if combos.is_empty() {
1127        return Err(JmaError::InvalidInput("no parameter combinations".into()));
1128    }
1129    let cols = data.len();
1130    if cols == 0 {
1131        return Err(JmaError::EmptyInputData);
1132    }
1133    let first = data
1134        .iter()
1135        .position(|x| !x.is_nan())
1136        .ok_or(JmaError::AllValuesNaN)?;
1137    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1138    if cols - first < max_p {
1139        return Err(JmaError::NotEnoughValidData {
1140            needed: max_p,
1141            valid: cols - first,
1142        });
1143    }
1144    let rows = combos.len();
1145
1146    let expected = rows
1147        .checked_mul(cols)
1148        .ok_or_else(|| JmaError::InvalidInput("rows * cols overflow".into()))?;
1149    if out.len() != expected {
1150        return Err(JmaError::OutputLengthMismatch {
1151            expected,
1152            got: out.len(),
1153        });
1154    }
1155
1156    let warm: Vec<usize> = vec![first; rows];
1157    let out_uninit = unsafe {
1158        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1159    };
1160    init_matrix_prefixes(out_uninit, cols, &warm);
1161
1162    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
1163        let prm = &combos[row];
1164        let period = prm.period.unwrap();
1165        let phase = prm.phase.unwrap();
1166        let power = prm.power.unwrap();
1167
1168        match kern {
1169            Kernel::Scalar => jma_row_scalar(data, first, period, phase, power, out_row),
1170            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1171            Kernel::Avx2 => jma_row_avx2(data, first, period, phase, power, out_row),
1172            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1173            Kernel::Avx512 => jma_row_avx512(data, first, period, phase, power, out_row),
1174            _ => unreachable!(),
1175        }
1176    };
1177
1178    if parallel {
1179        #[cfg(not(target_arch = "wasm32"))]
1180        {
1181            out.par_chunks_mut(cols)
1182                .enumerate()
1183                .for_each(|(row, slice)| do_row(row, slice));
1184        }
1185        #[cfg(target_arch = "wasm32")]
1186        {
1187            for (row, slice) in out.chunks_mut(cols).enumerate() {
1188                do_row(row, slice);
1189            }
1190        }
1191    } else {
1192        for (row, slice) in out.chunks_mut(cols).enumerate() {
1193            do_row(row, slice);
1194        }
1195    }
1196
1197    Ok((combos, rows, cols))
1198}
1199
/// Row adapter for the scalar kernel used by the batch functions.
///
/// Forwards to `jma_scalar`, moving `first` from the second position of this
/// signature to the position after `power` that `jma_scalar` expects.
///
/// # Safety
/// NOTE(review): the preconditions are those of `jma_scalar`, which is defined
/// outside this block — presumably `out.len() == data.len()` and
/// `first < data.len()`; confirm against that function.
#[inline(always)]
unsafe fn jma_row_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    phase: f64,
    power: u32,
    out: &mut [f64],
) {
    jma_scalar(data, period, phase, power, first, out);
}
1211
/// Row adapter for the AVX2 kernel used by the batch functions; only compiled
/// with the `nightly-avx` feature on `x86_64`.
///
/// Forwards to `jma_avx2`, moving `first` from the second position of this
/// signature to the position after `power` that `jma_avx2` expects.
///
/// # Safety
/// NOTE(review): the preconditions are those of `jma_avx2`, which is defined
/// outside this block — presumably the CPU must support the instructions it
/// uses and `out.len() == data.len()`; confirm against that function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jma_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    phase: f64,
    power: u32,
    out: &mut [f64],
) {
    jma_avx2(data, period, phase, power, first, out);
}
1224
/// Row adapter for the AVX-512 kernel used by the batch functions; only
/// compiled with the `nightly-avx` feature on `x86_64`.
///
/// Forwards to `jma_avx512`, moving `first` from the second position of this
/// signature to the position after `power` that `jma_avx512` expects.
///
/// # Safety
/// NOTE(review): the preconditions are those of `jma_avx512`, which is defined
/// outside this block — presumably the CPU must support the instructions it
/// uses and `out.len() == data.len()`; confirm against that function.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jma_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    phase: f64,
    power: u32,
    out: &mut [f64],
) {
    jma_avx512(data, period, phase, power, first, out);
}
1237
/// Public entry point to the private `expand_grid` helper: returns the list
/// of `JmaParams` combinations that `r` expands to, exactly as `expand_grid`
/// produces them.
#[inline(always)]
pub fn expand_grid_jma(r: &JmaBatchRange) -> Vec<JmaParams> {
    expand_grid(r)
}
1242
1243#[cfg(test)]
1244mod tests {
1245    use super::*;
1246    use crate::skip_if_unsupported;
1247    use crate::utilities::data_loader::read_candles_from_csv;
1248    #[cfg(feature = "proptest")]
1249    use proptest::prelude::*;
1250
1251    fn check_jma_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1252        skip_if_unsupported!(kernel, test_name);
1253        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1254        let candles = read_candles_from_csv(file_path)?;
1255
1256        let default_params = JmaParams {
1257            period: None,
1258            phase: None,
1259            power: None,
1260        };
1261        let input = JmaInput::from_candles(&candles, "close", default_params);
1262        let output = jma_with_kernel(&input, kernel)?;
1263        assert_eq!(output.values.len(), candles.close.len());
1264
1265        Ok(())
1266    }
1267
1268    fn check_jma_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1269        skip_if_unsupported!(kernel, test_name);
1270        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1271        let candles = read_candles_from_csv(file_path)?;
1272        let input = JmaInput::from_candles(&candles, "close", JmaParams::default());
1273        let result = jma_with_kernel(&input, kernel)?;
1274        let expected_last_five = [
1275            59305.04794668568,
1276            59261.270455005455,
1277            59156.791263606865,
1278            59128.30656791065,
1279            58918.89223153998,
1280        ];
1281        let start = result.values.len().saturating_sub(5);
1282        for (i, &val) in result.values[start..].iter().enumerate() {
1283            let diff = (val - expected_last_five[i]).abs();
1284            assert!(
1285                diff < 1e-6,
1286                "[{}] JMA {:?} mismatch at idx {}: got {}, expected {}",
1287                test_name,
1288                kernel,
1289                i,
1290                val,
1291                expected_last_five[i]
1292            );
1293        }
1294        Ok(())
1295    }
1296
1297    fn check_jma_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1298        skip_if_unsupported!(kernel, test_name);
1299        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1300        let candles = read_candles_from_csv(file_path)?;
1301        let input = JmaInput::with_default_candles(&candles);
1302        match input.data {
1303            JmaData::Candles { source, .. } => assert_eq!(source, "close"),
1304            _ => panic!("Expected JmaData::Candles"),
1305        }
1306        let output = jma_with_kernel(&input, kernel)?;
1307        assert_eq!(output.values.len(), candles.close.len());
1308        Ok(())
1309    }
1310
1311    fn check_jma_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1312        skip_if_unsupported!(kernel, test_name);
1313        let input_data = [10.0, 20.0, 30.0];
1314        let params = JmaParams {
1315            period: Some(0),
1316            phase: None,
1317            power: None,
1318        };
1319        let input = JmaInput::from_slice(&input_data, params);
1320        let res = jma_with_kernel(&input, kernel);
1321        assert!(
1322            res.is_err(),
1323            "[{}] JMA should fail with zero period",
1324            test_name
1325        );
1326        Ok(())
1327    }
1328
1329    fn check_jma_period_exceeds_length(
1330        test_name: &str,
1331        kernel: Kernel,
1332    ) -> Result<(), Box<dyn Error>> {
1333        skip_if_unsupported!(kernel, test_name);
1334        let data_small = [10.0, 20.0, 30.0];
1335        let params = JmaParams {
1336            period: Some(10),
1337            phase: None,
1338            power: None,
1339        };
1340        let input = JmaInput::from_slice(&data_small, params);
1341        let res = jma_with_kernel(&input, kernel);
1342        assert!(
1343            res.is_err(),
1344            "[{}] JMA should fail with period exceeding length",
1345            test_name
1346        );
1347        Ok(())
1348    }
1349
1350    fn check_jma_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1351        skip_if_unsupported!(kernel, test_name);
1352        let single_point = [42.0];
1353        let params = JmaParams {
1354            period: Some(7),
1355            phase: None,
1356            power: None,
1357        };
1358        let input = JmaInput::from_slice(&single_point, params);
1359        let res = jma_with_kernel(&input, kernel);
1360        assert!(
1361            res.is_err(),
1362            "[{}] JMA should fail with insufficient data",
1363            test_name
1364        );
1365        Ok(())
1366    }
1367
1368    fn check_jma_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1369        skip_if_unsupported!(kernel, test_name);
1370        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1371        let candles = read_candles_from_csv(file_path)?;
1372        let first_params = JmaParams {
1373            period: Some(7),
1374            phase: None,
1375            power: None,
1376        };
1377        let first_input = JmaInput::from_candles(&candles, "close", first_params);
1378        let first_result = jma_with_kernel(&first_input, kernel)?;
1379        let second_params = JmaParams {
1380            period: Some(7),
1381            phase: None,
1382            power: None,
1383        };
1384        let second_input = JmaInput::from_slice(&first_result.values, second_params);
1385        let second_result = jma_with_kernel(&second_input, kernel)?;
1386        assert_eq!(second_result.values.len(), first_result.values.len());
1387        Ok(())
1388    }
1389
1390    fn check_jma_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1391        skip_if_unsupported!(kernel, test_name);
1392        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1393        let candles = read_candles_from_csv(file_path)?;
1394        let input = JmaInput::from_candles(
1395            &candles,
1396            "close",
1397            JmaParams {
1398                period: Some(7),
1399                phase: None,
1400                power: None,
1401            },
1402        );
1403        let res = jma_with_kernel(&input, kernel)?;
1404        assert_eq!(res.values.len(), candles.close.len());
1405        if res.values.len() > 240 {
1406            for (i, &val) in res.values[240..].iter().enumerate() {
1407                assert!(
1408                    !val.is_nan(),
1409                    "[{}] Found unexpected NaN at out-index {}",
1410                    test_name,
1411                    240 + i
1412                );
1413            }
1414        }
1415        Ok(())
1416    }
1417
1418    fn check_jma_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1419        skip_if_unsupported!(kernel, test_name);
1420        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1421        let candles = read_candles_from_csv(file_path)?;
1422        let period = 7;
1423        let phase = 50.0;
1424        let power = 2;
1425        let input = JmaInput::from_candles(
1426            &candles,
1427            "close",
1428            JmaParams {
1429                period: Some(period),
1430                phase: Some(phase),
1431                power: Some(power),
1432            },
1433        );
1434        let batch_output = jma_with_kernel(&input, kernel)?.values;
1435        let mut stream = JmaStream::try_new(JmaParams {
1436            period: Some(period),
1437            phase: Some(phase),
1438            power: Some(power),
1439        })?;
1440        let mut stream_values = Vec::with_capacity(candles.close.len());
1441        for &price in &candles.close {
1442            match stream.update(price) {
1443                Some(jma_val) => stream_values.push(jma_val),
1444                None => stream_values.push(f64::NAN),
1445            }
1446        }
1447        assert_eq!(batch_output.len(), stream_values.len());
1448        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1449            if b.is_nan() && s.is_nan() {
1450                continue;
1451            }
1452            let diff = (b - s).abs();
1453            assert!(
1454                diff < 1e-8,
1455                "[{}] JMA streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1456                test_name,
1457                i,
1458                b,
1459                s,
1460                diff
1461            );
1462        }
1463        Ok(())
1464    }
1465
1466    #[cfg(debug_assertions)]
1467    fn check_jma_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1468        skip_if_unsupported!(kernel, test_name);
1469
1470        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1471        let candles = read_candles_from_csv(file_path)?;
1472
1473        let test_params = vec![
1474            JmaParams::default(),
1475            JmaParams {
1476                period: Some(3),
1477                phase: Some(0.0),
1478                power: Some(1),
1479            },
1480            JmaParams {
1481                period: Some(3),
1482                phase: Some(50.0),
1483                power: Some(2),
1484            },
1485            JmaParams {
1486                period: Some(3),
1487                phase: Some(100.0),
1488                power: Some(3),
1489            },
1490            JmaParams {
1491                period: Some(7),
1492                phase: Some(25.0),
1493                power: Some(1),
1494            },
1495            JmaParams {
1496                period: Some(7),
1497                phase: Some(50.0),
1498                power: Some(2),
1499            },
1500            JmaParams {
1501                period: Some(7),
1502                phase: Some(75.0),
1503                power: Some(3),
1504            },
1505            JmaParams {
1506                period: Some(10),
1507                phase: Some(0.0),
1508                power: Some(2),
1509            },
1510            JmaParams {
1511                period: Some(14),
1512                phase: Some(100.0),
1513                power: Some(2),
1514            },
1515            JmaParams {
1516                period: Some(20),
1517                phase: Some(50.0),
1518                power: Some(1),
1519            },
1520            JmaParams {
1521                period: Some(30),
1522                phase: Some(50.0),
1523                power: Some(2),
1524            },
1525            JmaParams {
1526                period: Some(50),
1527                phase: Some(50.0),
1528                power: Some(3),
1529            },
1530            JmaParams {
1531                period: Some(1),
1532                phase: Some(0.0),
1533                power: Some(1),
1534            },
1535            JmaParams {
1536                period: Some(100),
1537                phase: Some(100.0),
1538                power: Some(5),
1539            },
1540            JmaParams {
1541                period: Some(10),
1542                phase: Some(-100.0),
1543                power: Some(2),
1544            },
1545            JmaParams {
1546                period: Some(10),
1547                phase: Some(200.0),
1548                power: Some(2),
1549            },
1550        ];
1551
1552        for (param_idx, params) in test_params.iter().enumerate() {
1553            let input = JmaInput::from_candles(&candles, "close", params.clone());
1554            let output = jma_with_kernel(&input, kernel)?;
1555
1556            for (i, &val) in output.values.iter().enumerate() {
1557                if val.is_nan() {
1558                    continue;
1559                }
1560
1561                let bits = val.to_bits();
1562
1563                if bits == 0x11111111_11111111 {
1564                    panic!(
1565                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1566                        with params: period={}, phase={}, power={}",
1567                        test_name,
1568                        val,
1569                        bits,
1570                        i,
1571                        params.period.unwrap_or(7),
1572                        params.phase.unwrap_or(50.0),
1573                        params.power.unwrap_or(2)
1574                    );
1575                }
1576
1577                if bits == 0x22222222_22222222 {
1578                    panic!(
1579                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1580                        with params: period={}, phase={}, power={}",
1581                        test_name,
1582                        val,
1583                        bits,
1584                        i,
1585                        params.period.unwrap_or(7),
1586                        params.phase.unwrap_or(50.0),
1587                        params.power.unwrap_or(2)
1588                    );
1589                }
1590
1591                if bits == 0x33333333_33333333 {
1592                    panic!(
1593                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1594                        with params: period={}, phase={}, power={}",
1595                        test_name,
1596                        val,
1597                        bits,
1598                        i,
1599                        params.period.unwrap_or(7),
1600                        params.phase.unwrap_or(50.0),
1601                        params.power.unwrap_or(2)
1602                    );
1603                }
1604            }
1605        }
1606
1607        Ok(())
1608    }
1609
    /// Stand-in used when `debug_assertions` is off: performs no checks and
    /// returns `Ok(())`, so the test list can name `check_jma_no_poison` in
    /// every build profile.
    #[cfg(not(debug_assertions))]
    fn check_jma_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1614
    /// Property-based test (only with the `proptest` feature) that runs JMA on
    /// randomly generated inputs and checks a series of invariants: output
    /// length, NaN prefix, finiteness, variance reduction, constant-input
    /// behaviour, agreement with the scalar kernel, and the direction of the
    /// `phase` and `power` effects.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_jma_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // Generated cases: period in 2..=100, between `period` and 399 finite
        // values in [-1e6, 1e6), phase in [-100, 100], power in 1..=10.
        let strat = (2usize..=100).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                    period..400,
                ),
                Just(period),
                -100f64..=100f64,
                1u32..=10,
            )
        });

        proptest::test_runner::TestRunner::default()
			.run(&strat, |(data, period, phase, power)| {
				let params = JmaParams {
					period: Some(period),
					phase: Some(phase),
					power: Some(power),
				};
				let input = JmaInput::from_slice(&data, params);

				// `out` comes from the kernel under test, `ref_out` from the scalar kernel.
				let JmaOutput { values: out } = jma_with_kernel(&input, kernel).unwrap();
				let JmaOutput { values: ref_out } = jma_with_kernel(&input, Kernel::Scalar).unwrap();

				// Property: one output value per input value.
				prop_assert_eq!(out.len(), data.len());

				// Index of the first non-NaN input. The strategy above only yields
				// finite values, so this is 0 for every generated case.
				let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(data.len());

				// Property: every output before the first valid input is NaN.
				for i in 0..first_valid.min(out.len()) {
					prop_assert!(
						out[i].is_nan(),
						"idx {}: expected NaN before first valid input, got {}", i, out[i]
					);
				}

				// Property: the output at the first valid index is finite.
				if first_valid < out.len() {
					prop_assert!(
						out[first_valid].is_finite(),
						"JMA should output a finite value at first_valid index {}", first_valid
					);
				}

				// Property: after the first valid index, a finite input with no NaN
				// since `first_valid` yields a finite output.
				for i in (first_valid + 1)..data.len() {
					if data[i].is_finite() && !data[first_valid..=i].iter().any(|x| x.is_nan()) {
						prop_assert!(
							out[i].is_finite(),
							"idx {}: expected finite value, got {}", i, out[i]
						);
					}
				}

				// Property: over a window of up to 50 values starting at
				// `first_valid + period`, the output variance is at most 10% above
				// the input variance.
				let warmup_estimate = first_valid + period;
				if warmup_estimate + 20 < data.len() {
					let window_start = warmup_estimate;
					let window_end = (warmup_estimate + 50).min(data.len());

					let input_slice = &data[window_start..window_end];
					let output_slice = &out[window_start..window_end];

					if input_slice.iter().all(|x| x.is_finite()) && output_slice.iter().all(|x| x.is_finite()) {
						let input_mean: f64 = input_slice.iter().sum::<f64>() / input_slice.len() as f64;
						let output_mean: f64 = output_slice.iter().sum::<f64>() / output_slice.len() as f64;

						let input_var: f64 = input_slice.iter()
							.map(|x| (x - input_mean).powi(2))
							.sum::<f64>() / input_slice.len() as f64;
						let output_var: f64 = output_slice.iter()
							.map(|x| (x - output_mean).powi(2))
							.sum::<f64>() / output_slice.len() as f64;

						// Skip (near-)constant windows, where the comparison is meaningless.
						if input_var > 1e-10 {
							prop_assert!(
								output_var <= input_var * 1.1,
								"JMA should smooth data: input_var={}, output_var={}, period={}, phase={}, power={}",
								input_var, output_var, period, phase, power
							);
						}
					}
				}

				// Property: with period 1 the output stays within 10% (+1e-6) of the input.
				// NOTE(review): the strategy generates periods in 2..=100, so this
				// branch is never taken by the cases produced here.
				if period == 1 {

					for i in first_valid..data.len().min(first_valid + 20) {
						if data[i].is_finite() && out[i].is_finite() {
							prop_assert!(
								(out[i] - data[i]).abs() <= data[i].abs() * 0.1 + 1e-6,
								"period=1 should closely track input: idx={}, data={}, out={}, diff={}",
								i, data[i], out[i], (out[i] - data[i]).abs()
							);
						}
					}
				}

				// Property: a constant input (adjacent values differ by < 1e-10)
				// produces that same constant from `first_valid + period + 5` onward.
				let warmup_estimate = first_valid + period;
				if data[first_valid..].windows(2).all(|w| (w[0] - w[1]).abs() < 1e-10) && warmup_estimate + 5 < data.len() {
					let constant_val = data[first_valid];
					for i in (warmup_estimate + 5)..data.len() {
						if out[i].is_finite() {
							prop_assert!(
								(out[i] - constant_val).abs() <= 1e-6,
								"constant input should produce constant output: idx={}, expected={}, got={}",
								i, constant_val, out[i]
							);
						}
					}
				}

				// Property: non-scalar kernels agree with the scalar reference.
				if kernel != Kernel::Scalar {
					for i in first_valid..data.len() {
						if out[i].is_finite() && ref_out[i].is_finite() {
							// Distance between the raw bit patterns of the two results.
							let y_bits = out[i].to_bits();
							let r_bits = ref_out[i].to_bits();
							let diff_bits = if y_bits > r_bits {
								y_bits - r_bits
							} else {
								r_bits - y_bits
							};

							// Relative difference, falling back to the absolute difference
							// when the reference is close to zero.
							let abs_diff = (out[i] - ref_out[i]).abs();
							let rel_diff = if ref_out[i].abs() > 1e-10 {
								abs_diff / ref_out[i].abs()
							} else {
								abs_diff
							};

							// Passing any one of the three tolerances is sufficient.
							prop_assert!(
								diff_bits <= 1000 || abs_diff < 1e-9 || rel_diff < 1e-12,
								"kernel consistency failed at idx {}: {:?}={}, Scalar={}, diff_bits={}, abs_diff={}, rel_diff={}",
								i, kernel, out[i], ref_out[i], diff_bits, abs_diff, rel_diff
							);
						}
					}
				}

				// Property: compared with a phase-0 run on the same data, a phase
				// above 10 shows at least one "leading" sample and a phase below -10
				// at least one "lagging" sample within the checked window.
				let warmup_estimate = first_valid + period;
				if phase.abs() > 10.0 && warmup_estimate + 30 < data.len() {
					let params_neutral = JmaParams {
						period: Some(period),
						phase: Some(0.0),
						power: Some(power),
					};
					let input_neutral = JmaInput::from_slice(&data, params_neutral);
					if let Ok(JmaOutput { values: out_neutral }) = jma_with_kernel(&input_neutral, kernel) {

						let check_start = warmup_estimate + 10;
						let check_end = (warmup_estimate + 30).min(data.len() - 1);

						// Samples where the phased output sits on the side of the neutral
						// output that matches (lead) or opposes (lag) the price move.
						let mut lead_count = 0;
						let mut lag_count = 0;

						for i in check_start..check_end {
							if data[i].is_finite() && data[i-1].is_finite() &&
							   out[i].is_finite() && out_neutral[i].is_finite() {
								let data_change = data[i] - data[i-1];
								if data_change.abs() > 1e-10 {
									let phase_diff = out[i] - out_neutral[i];
									if data_change > 0.0 {
										// Rising input: leading means above neutral, lagging below.
										if phase > 0.0 && phase_diff > 0.0 {
											lead_count += 1;
										} else if phase < 0.0 && phase_diff < 0.0 {
											lag_count += 1;
										}
									} else {
										// Falling input: leading means below neutral, lagging above.
										if phase > 0.0 && phase_diff < 0.0 {
											lead_count += 1;
										} else if phase < 0.0 && phase_diff > 0.0 {
											lag_count += 1;
										}
									}
								}
							}
						}

						// One matching sample in the window is enough to pass.
						if phase > 10.0 {
							prop_assert!(
								lead_count > 0,
								"Positive phase should show leading behavior: phase={}, lead_count={}",
								phase, lead_count
							);
						} else if phase < -10.0 {
							prop_assert!(
								lag_count > 0,
								"Negative phase should show lagging behavior: phase={}, lag_count={}",
								phase, lag_count
							);
						}
					}
				}

				// Property: compare the mean absolute distance to the input for the
				// requested power against a power-1 run on the same data.
				let warmup_estimate2 = first_valid + period;
				if power > 1 && warmup_estimate2 + 20 < data.len() {
					let params_low_power = JmaParams {
						period: Some(period),
						phase: Some(phase),
						power: Some(1),
					};
					let input_low_power = JmaInput::from_slice(&data, params_low_power);
					if let Ok(JmaOutput { values: out_low_power }) = jma_with_kernel(&input_low_power, kernel) {

						let check_start = warmup_estimate2;
						let check_end = (warmup_estimate2 + 30).min(data.len());

						// Accumulated |output - input| for each run; averaged below.
						let mut high_power_responsiveness = 0.0;
						let mut low_power_responsiveness = 0.0;
						let mut count = 0;

						for i in check_start..check_end {
							if data[i].is_finite() && out[i].is_finite() && out_low_power[i].is_finite() {
								high_power_responsiveness += (out[i] - data[i]).abs();
								low_power_responsiveness += (out_low_power[i] - data[i]).abs();
								count += 1;
							}
						}

						if count > 0 {
							high_power_responsiveness /= count as f64;
							low_power_responsiveness /= count as f64;

							// NOTE(review): both averages are non-negative, so the guard
							// `low > high * 1.5` already implies `high <= low`; this
							// assertion appears unable to fail as written.
							if low_power_responsiveness > high_power_responsiveness * 1.5 {
								prop_assert!(
									high_power_responsiveness <= low_power_responsiveness,
									"Higher power should be more responsive: power={} resp={}, power=1 resp={}",
									power, high_power_responsiveness, low_power_responsiveness
								);
							}
						}
					}
				}

				Ok(())
			})
			.unwrap();

        Ok(())
    }
1883
    /// Stand-in used when the `proptest` feature is disabled: performs no
    /// checks and returns `Ok(())`, so the test list can name
    /// `check_jma_property` regardless of the feature set.
    #[cfg(not(feature = "proptest"))]
    fn check_jma_property(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1888
1889    macro_rules! generate_all_jma_tests {
1890        ($($test_fn:ident),*) => {
1891            paste::paste! {
1892                $(
1893                    #[test]
1894                    fn [<$test_fn _scalar_f64>]() {
1895                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1896                    }
1897                )*
1898                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1899                $(
1900                    #[test]
1901                    fn [<$test_fn _avx2_f64>]() {
1902                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1903                    }
1904                    #[test]
1905                    fn [<$test_fn _avx512_f64>]() {
1906                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1907                    }
1908                )*
1909            }
1910        }
1911    }
1912
    // Instantiate the per-kernel `#[test]` functions for every single-series
    // check defined above. To register a new check, add its function name to
    // this list.
    generate_all_jma_tests!(
        check_jma_partial_params,
        check_jma_accuracy,
        check_jma_default_candles,
        check_jma_zero_period,
        check_jma_period_exceeds_length,
        check_jma_very_small_dataset,
        check_jma_reinput,
        check_jma_nan_handling,
        check_jma_streaming,
        check_jma_property,
        check_jma_no_poison
    );
1926
1927    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1928        skip_if_unsupported!(kernel, test);
1929
1930        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1931        let c = read_candles_from_csv(file)?;
1932
1933        let output = JmaBatchBuilder::new()
1934            .kernel(kernel)
1935            .apply_candles(&c, "close")?;
1936
1937        let def = JmaParams::default();
1938        let row = output.values_for(&def).expect("default row missing");
1939
1940        assert_eq!(row.len(), c.close.len());
1941
1942        let expected = [
1943            59305.04794668568,
1944            59261.270455005455,
1945            59156.791263606865,
1946            59128.30656791065,
1947            58918.89223153998,
1948        ];
1949        let start = row.len() - 5;
1950        for (i, &v) in row[start..].iter().enumerate() {
1951            assert!(
1952                (v - expected[i]).abs() < 1e-6,
1953                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1954            );
1955        }
1956        Ok(())
1957    }
1958
1959    #[cfg(debug_assertions)]
1960    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1961        skip_if_unsupported!(kernel, test);
1962
1963        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1964        let c = read_candles_from_csv(file)?;
1965
1966        let batch_configs = vec![
1967            (5, 20, 5, 0.0, 100.0, 50.0, 1, 3, 1),
1968            (1, 10, 3, -100.0, 200.0, 100.0, 1, 5, 2),
1969            (50, 100, 25, -50.0, 50.0, 25.0, 2, 4, 1),
1970            (14, 14, 1, -100.0, 200.0, 50.0, 1, 5, 1),
1971            (1, 1, 1, 0.0, 0.0, 1.0, 1, 1, 1),
1972            (10, 30, 10, 50.0, 50.0, 1.0, 1, 5, 2),
1973            (5, 50, 5, -50.0, 150.0, 25.0, 1, 3, 1),
1974            (7, 21, 7, -100.0, 100.0, 40.0, 2, 2, 1),
1975            (80, 100, 20, 100.0, 200.0, 50.0, 4, 5, 1),
1976        ];
1977
1978        for (idx, config) in batch_configs.iter().enumerate() {
1979            let output = JmaBatchBuilder::new()
1980                .kernel(kernel)
1981                .period_range(config.0, config.1, config.2)
1982                .phase_range(config.3, config.4, config.5)
1983                .power_range(config.6, config.7, config.8)
1984                .apply_candles(&c, "close")?;
1985
1986            for (val_idx, &val) in output.values.iter().enumerate() {
1987                if val.is_nan() {
1988                    continue;
1989                }
1990
1991                let bits = val.to_bits();
1992                let row = val_idx / output.cols;
1993                let col = val_idx % output.cols;
1994                let combo = &output.combos[row];
1995
1996                if bits == 0x11111111_11111111 {
1997                    panic!(
1998                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) \
1999                        in batch config #{}: period_range=({},{},{}), phase_range=({},{},{}), power_range=({},{},{}) \
2000                        combo params: period={}, phase={}, power={}",
2001                        test, val, bits, row, col, val_idx, idx,
2002                        config.0, config.1, config.2, config.3, config.4, config.5, config.6, config.7, config.8,
2003                        combo.period.unwrap_or(7), combo.phase.unwrap_or(50.0), combo.power.unwrap_or(2)
2004                    );
2005                }
2006
2007                if bits == 0x22222222_22222222 {
2008                    panic!(
2009						"[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) \
2010                        in batch config #{}: period_range=({},{},{}), phase_range=({},{},{}), power_range=({},{},{}) \
2011                        combo params: period={}, phase={}, power={}",
2012						test,
2013						val,
2014						bits,
2015						row,
2016						col,
2017						val_idx,
2018						idx,
2019						config.0,
2020						config.1,
2021						config.2,
2022						config.3,
2023						config.4,
2024						config.5,
2025						config.6,
2026						config.7,
2027						config.8,
2028						combo.period.unwrap_or(7),
2029						combo.phase.unwrap_or(50.0),
2030						combo.power.unwrap_or(2)
2031					);
2032                }
2033
2034                if bits == 0x33333333_33333333 {
2035                    panic!(
2036						"[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) \
2037                        in batch config #{}: period_range=({},{},{}), phase_range=({},{},{}), power_range=({},{},{}) \
2038                        combo params: period={}, phase={}, power={}",
2039						test,
2040						val,
2041						bits,
2042						row,
2043						col,
2044						val_idx,
2045						idx,
2046						config.0,
2047						config.1,
2048						config.2,
2049						config.3,
2050						config.4,
2051						config.5,
2052						config.6,
2053						config.7,
2054						config.8,
2055						combo.period.unwrap_or(7),
2056						combo.phase.unwrap_or(50.0),
2057						combo.power.unwrap_or(2)
2058					);
2059                }
2060            }
2061        }
2062
2063        Ok(())
2064    }
2065
2066    #[cfg(not(debug_assertions))]
2067    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2068        Ok(())
2069    }
2070
    // Generates `#[test]` wrappers for a batch check function `$fn_name`:
    // `<name>_scalar`, `<name>_avx2`, `<name>_avx512` and `<name>_auto_detect`,
    // each passing its own stringified name plus the matching batch kernel.
    // The AVX variants exist only with the `nightly-avx` feature on x86_64.
    //
    // NOTE(review): the `Result` returned by the check is discarded with
    // `let _ =`, so an `Err` (as opposed to a panic) does not fail the test.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
    // Instantiate the per-kernel wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
2093
2094    #[test]
2095    fn test_jma_into_matches_api() -> Result<(), Box<dyn Error>> {
2096        let len = 256usize;
2097        let mut data = Vec::with_capacity(len);
2098        for _ in 0..5 {
2099            data.push(f64::from_bits(0x7ff8_0000_0000_0000));
2100        }
2101        for i in 0..(len - 5) {
2102            let x = i as f64;
2103
2104            data.push((x * 0.1).sin() * 10.0 + x * 0.01);
2105        }
2106
2107        let input = JmaInput::from_slice(&data, JmaParams::default());
2108
2109        let baseline = jma(&input)?.values;
2110
2111        let mut out = vec![0.0; data.len()];
2112
2113        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2114        {
2115            jma_into(&input, &mut out)?;
2116
2117            assert_eq!(out.len(), baseline.len());
2118            for (a, b) in out.iter().zip(baseline.iter()) {
2119                if a.is_nan() || b.is_nan() {
2120                    assert!(a.is_nan() && b.is_nan());
2121                } else {
2122                    assert_eq!(a, b);
2123                }
2124            }
2125        }
2126
2127        Ok(())
2128    }
2129}
2130
2131#[cfg(feature = "python")]
2132use crate::utilities::kernel_validation::validate_kernel;
2133#[cfg(all(feature = "python", feature = "cuda"))]
2134use numpy::PyUntypedArrayMethods;
2135#[cfg(feature = "python")]
2136use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
2137#[cfg(feature = "python")]
2138use pyo3::exceptions::PyValueError;
2139#[cfg(feature = "python")]
2140use pyo3::prelude::*;
2141#[cfg(feature = "python")]
2142use pyo3::types::PyDict;
2143
// Python entry point `jma(data, period, phase=50.0, power=2, kernel=None)`.
// Validates the kernel name, runs the indicator inside `allow_threads`, and
// hands the result back as a 1-D NumPy array. Any indicator error is surfaced
// as a Python `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction(name = "jma")]
#[pyo3(signature = (data, period, phase=50.0, power=2, kernel=None))]
pub fn jma_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    phase: f64,
    power: u32,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = JmaInput::from_slice(
        prices,
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        },
    );

    let computed = py.allow_threads(|| jma_with_kernel(&input, chosen_kernel));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
2171
/// Python entry point `jma_batch`: evaluates JMA for every parameter
/// combination described by the three `(start, end, step)` ranges.
///
/// Returns a dict with:
/// * `"values"`  – 2-D array of shape `(rows, cols)`, one row per combination
///   and one column per input sample;
/// * `"periods"`, `"phases"`, `"powers"` – 1-D arrays giving the parameters of
///   each row.
///
/// Raises `ValueError` when the ranges expand to no combinations, when
/// `rows * cols` does not fit in `usize`, when the kernel name is rejected, or
/// when the batch computation itself reports an error.
#[cfg(feature = "python")]
#[pyfunction(name = "jma_batch")]
#[pyo3(signature = (data, period_range, phase_range=(50.0, 50.0, 0.0), power_range=(2, 2, 0), kernel=None))]
pub fn jma_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    phase_range: (f64, f64, f64),
    power_range: (u32, u32, u32),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    let combos = expand_grid(&sweep);
    if combos.is_empty() {
        return Err(PyValueError::new_err("Invalid parameter ranges"));
    }

    let rows = combos.len();
    let cols = slice_in.len();

    // The flat buffer length must not wrap: a wrapped product would allocate a
    // buffer smaller than the `rows x cols` matrix it is later reshaped into.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // The array is created without initialisation and filled in place below.
    // NOTE(review): relies on `jma_batch_inner_into` writing every element
    // before the array is exposed to Python — confirm against its contract.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let (combos_result, _, _) = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };

            // The inner routine takes the per-row kernel, so batch kernel
            // variants are mapped to their single-series counterparts.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => kernel,
            };

            jma_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    dict.set_item(
        "periods",
        combos_result
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "phases",
        combos_result
            .iter()
            .map(|p| p.phase.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "powers",
        combos_result
            .iter()
            .map(|p| p.power.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
2251
// Python entry point `jma_cuda_batch_dev`: runs the parameter sweep on a CUDA
// device and returns a handle wrapping the device-side result. Fails with
// `ValueError` when CUDA is unavailable or when the CUDA wrapper reports an
// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jma_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, phase_range=(50.0, 50.0, 0.0), power_range=(2, 2, 0), device_id=0))]
pub fn jma_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    phase_range: (f64, f64, f64),
    power_range: (u32, u32, u32),
    device_id: usize,
) -> PyResult<JmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    // Context creation and the launch both happen inside `allow_threads`.
    let device_array = py.allow_threads(|| {
        CudaJma::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.jma_batch_dev(prices, &sweep)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(JmaDeviceArrayF32Py {
        inner: Some(device_array),
    })
}
2282
// Python entry point `jma_cuda_many_series_one_param_dev`: applies a single
// parameter set to a 2-D f32 price matrix on a CUDA device and returns a
// handle wrapping the device-side result. Fails with `ValueError` when CUDA is
// unavailable, the matrix is not 2-D, or the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jma_cuda_many_series_one_param_dev")]
#[pyo3(signature = (prices_tm_f32, period, phase=50.0, power=2, device_id=0))]
pub fn jma_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    prices_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    phase: f64,
    power: u32,
    device_id: usize,
) -> PyResult<JmaDeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let (rows, cols) = match prices_tm_f32.shape() {
        &[r, c] => (r, c),
        _ => return Err(PyValueError::new_err("prices matrix must be 2-dimensional")),
    };
    let prices_flat = prices_tm_f32.as_slice()?;

    let params = JmaParams {
        period: Some(period),
        phase: Some(phase),
        power: Some(power),
    };

    // Note the argument order: the column count is passed before the row count.
    let device_array = py.allow_threads(|| {
        CudaJma::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .and_then(|cuda| {
                cuda.jma_many_series_one_param_time_major_dev(prices_flat, cols, rows, &params)
                    .map_err(|e| PyValueError::new_err(e.to_string()))
            })
    })?;

    Ok(JmaDeviceArrayF32Py {
        inner: Some(device_array),
    })
}
2319
// Python-visible handle around a device-resident f32 result matrix.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct JmaDeviceArrayF32Py {
    // `Some` while this object owns the device buffer; `__dlpack__` takes the
    // value out, after which the accessors report the buffer as exported.
    pub(crate) inner: Option<DeviceArrayF32Jma>,
}
2325
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl JmaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict (version 3) describing the
    // buffer as a row-major `<f4` matrix of shape `(rows, cols)`. Errors once
    // the buffer has been handed out through `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let inner = match self.inner.as_ref() {
            Some(handle) => handle,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let elem_size = std::mem::size_of::<f32>();
        let iface = PyDict::new(py);
        iface.set_item("shape", (inner.rows, inner.cols))?;
        iface.set_item("typestr", "<f4")?;
        // Byte strides: one full row, then one element.
        iface.set_item("strides", (inner.cols * elem_size, elem_size))?;
        // (pointer, read_only) pair; the buffer is advertised as writable.
        iface.set_item("data", (inner.device_ptr() as usize, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns `(device_type, device_id)` with device type 2; the id falls back
    // to 0 when the buffer has already been exported.
    fn __dlpack_device__(&self) -> (i32, i32) {
        let device_id = match self.inner.as_ref() {
            Some(handle) => handle.device_id as i32,
            None => 0,
        };
        (2, device_id)
    }

    // Exports the buffer as a DLPack capsule, transferring ownership out of
    // this object. A `dl_device` that parses as `(i32, i32)` and differs from
    // the allocation device is rejected; `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A request that cannot be parsed as a device tuple is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some(target) = requested {
            if target != (kdl, alloc_dev) {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }

        let _ = stream;

        // Taking the handle makes any second export fail with the error below.
        let inner = match self.inner.take() {
            Some(handle) => handle,
            None => {
                return Err(PyValueError::new_err(
                    "buffer already exported via __dlpack__",
                ))
            }
        };

        let (rows, cols) = (inner.rows, inner.cols);
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, inner.buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2403
// Python class `JmaStream`: a thin wrapper that owns a Rust `JmaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "JmaStream")]
pub struct JmaStreamPy {
    // The wrapped streaming state; `update` calls are forwarded to it.
    inner: JmaStream,
}
2409
#[cfg(feature = "python")]
#[pymethods]
impl JmaStreamPy {
    // Constructor `JmaStream(period, phase=50.0, power=2)`. A rejected
    // parameter set is reported as a Python `ValueError`.
    #[new]
    #[pyo3(signature = (period, phase=50.0, power=2))]
    fn new(period: usize, phase: f64, power: u32) -> PyResult<Self> {
        let built = JmaStream::try_new(JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        });

        match built {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Forwards one sample to the wrapped stream and returns its result
    // unchanged (`None` maps to Python `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.inner;
        stream.update(value)
    }
}
2432
2433#[inline]
2434pub fn jma_into_slice(dst: &mut [f64], input: &JmaInput, kern: Kernel) -> Result<(), JmaError> {
2435    let data: &[f64] = match &input.data {
2436        JmaData::Candles { candles, source } => source_type(candles, source),
2437        JmaData::Slice(sl) => sl,
2438    };
2439
2440    if dst.len() != data.len() {
2441        return Err(JmaError::OutputLengthMismatch {
2442            expected: data.len(),
2443            got: dst.len(),
2444        });
2445    }
2446
2447    jma_with_kernel_into(input, kern, dst)
2448}
2449
2450#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2451use serde::{Deserialize, Serialize};
2452#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2453use wasm_bindgen::prelude::*;
2454
// wasm export: computes JMA over `data` with the automatically selected kernel
// and returns a freshly allocated result vector. Errors are converted to a JS
// string value holding the error's display text.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_js(data: &[f64], period: usize, phase: f64, power: u32) -> Result<Vec<f64>, JsValue> {
    let input = JmaInput::from_slice(
        data,
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        },
    );

    let mut output = vec![0.0; data.len()];
    match jma_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2472
// wasm export: reserves room for `len` f64 values and returns the raw pointer.
// The allocation is deliberately not dropped here; it is expected to be
// released later through `jma_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2481
/// Releases a buffer previously obtained from `jma_alloc`.
///
/// `len` must be the same element count that was passed to `jma_alloc`.
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller passes a pointer returned by `jma_alloc(len)`, so it
    // denotes a `Vec<f64>` allocation with capacity `len`. The vector is
    // rebuilt with length 0 because the caller may never have written the
    // elements, and `Vec::from_raw_parts` requires the first `length` values
    // to be initialised. Dropping it frees the whole capacity either way.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2491
// wasm export: computes JMA for `len` samples at `in_ptr` and writes `len`
// results to `out_ptr`. When both pointers are equal the result is computed
// into a temporary buffer first and then copied over the input.
// Rejects null pointers and a period that is zero or larger than `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    phase: f64,
    power: u32,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jma_into"));
    }

    // NOTE(review): assumes the caller guarantees both pointers are valid for
    // `len` f64 values; nothing here can verify that.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let input = JmaInput::from_slice(
        data,
        JmaParams {
            period: Some(period),
            phase: Some(phase),
            power: Some(power),
        },
    );

    if in_ptr == out_ptr {
        // In-place call: never hand out a mutable view of the input while the
        // computation is still reading it.
        let mut scratch = vec![0.0; len];
        jma_into_slice(&mut scratch, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        jma_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
2536
/// Batch sweep configuration deserialized from the JS `config` object passed
/// to `jma_batch`. Each field is copied into the matching `JmaBatchRange`
/// field.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JmaBatchConfig {
    /// Period axis triple.
    // presumably (start, end, step), matching the Python binding — verify
    pub period_range: (usize, usize, usize),
    /// Phase axis triple.
    pub phase_range: (f64, f64, f64),
    /// Power axis triple.
    pub power_range: (u32, u32, u32),
}
2544
/// Result object serialized back to JS by `jma_batch`; every field is copied
/// straight from the batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JmaBatchJsOutput {
    /// Flat result values.
    // presumably row-major with `rows * cols` entries — TODO confirm
    pub values: Vec<f64>,
    /// Parameter sets reported by the batch output.
    pub combos: Vec<JmaParams>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
2553
// wasm export `jma_batch`: deserializes the sweep configuration, runs the batch
// computation with the detected kernel (anything that is not an AVX batch
// kernel falls back to scalar) and serializes values, combos and dimensions
// back to a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jma_batch)]
pub fn jma_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let JmaBatchConfig {
        period_range,
        phase_range,
        power_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = JmaBatchRange {
        period: period_range,
        phase: phase_range,
        power: power_range,
    };

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    let output = match jma_batch_inner(data, &sweep, simd, false) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let js_output = JmaBatchJsOutput {
        values: output.values,
        combos: output.combos,
        rows: output.rows,
        cols: output.cols,
    };

    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2586
// Deprecated wasm export: runs the batch computation with the scalar kernel
// and returns only the flat value buffer. Superseded by `jma_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(since = "1.0.0", note = "Use jma_batch instead")]
pub fn jma_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let phase = (phase_start, phase_end, phase_step);
    let power = (power_start, power_end, power_step);
    let sweep = JmaBatchRange {
        period,
        phase,
        power,
    };

    match jma_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(output) => Ok(output.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2612
/// Enumerates the parameter grid as a flat list of `[period, phase, power]`
/// triples, iterating period outermost and power innermost.
///
/// A zero step yields exactly one value for the phase and power axes (even
/// when start exceeds end); for the period axis a zero step yields one value
/// only when `period_start <= period_end`.
///
/// Every axis is guaranteed to terminate: an axis stops as soon as adding the
/// step no longer moves the value forward (negative, NaN or vanishingly small
/// phase step) or would overflow the integer type.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Vec<f64> {
    let mut metadata = Vec::new();

    let mut current_period = period_start;
    while current_period <= period_end {
        let mut current_phase = phase_start;
        while current_phase <= phase_end || (phase_step == 0.0 && current_phase == phase_start) {
            let mut current_power = power_start;
            while current_power <= power_end || (power_step == 0 && current_power == power_start) {
                metadata.extend_from_slice(&[
                    current_period as f64,
                    current_phase,
                    current_power as f64,
                ]);

                if power_step == 0 {
                    break;
                }
                // Stop instead of wrapping when the next value does not fit in u32.
                match current_power.checked_add(power_step) {
                    Some(next) => current_power = next,
                    None => break,
                }
            }

            // Covers the zero step (single value) and also any step that fails
            // to advance the value, which previously looped without end.
            let next_phase = current_phase + phase_step;
            let advances = next_phase > current_phase;
            if !advances {
                break;
            }
            current_phase = next_phase;
        }

        if period_step == 0 {
            break;
        }
        // Stop instead of wrapping when the next value does not fit in usize.
        match current_period.checked_add(period_step) {
            Some(next) => current_period = next,
            None => break,
        }
    }

    metadata
}
2656
/// Runs the batch sweep over `len` samples at `in_ptr` with the scalar kernel
/// and writes the results to `out_ptr`, which must have room for
/// `rows * len` values. Returns the number of rows (parameter combinations).
///
/// Fails when either pointer is null, when `rows * len` overflows `usize`, or
/// when the batch computation reports an error.
///
/// NOTE(review): unlike `jma_into`, no in-place handling is done here, so the
/// input and output regions are assumed not to overlap — confirm with callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jma_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    phase_start: f64,
    phase_end: f64,
    phase_step: f64,
    power_start: u32,
    power_end: u32,
    power_step: u32,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jma_batch_into"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        let sweep = JmaBatchRange {
            period: (period_start, period_end, period_step),
            phase: (phase_start, phase_end, phase_step),
            power: (power_start, power_end, power_step),
        };

        let combos = expand_grid_jma(&sweep);
        let rows = combos.len();
        let cols = len;

        // A wrapped product would build an output slice shorter than the
        // matrix being written, so reject it up front.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        jma_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

        Ok(rows)
    }
}