Skip to main content

vector_ta/indicators/moving_averages/
jsa.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
10use core::arch::x86_64::*;
11#[cfg(not(target_arch = "wasm32"))]
12use rayon::prelude::*;
13use std::convert::AsRef;
14use std::error::Error;
15use std::mem::MaybeUninit;
16use thiserror::Error;
17
18impl<'a> AsRef<[f64]> for JsaInput<'a> {
19    #[inline(always)]
20    fn as_ref(&self) -> &[f64] {
21        match &self.data {
22            JsaData::Slice(slice) => slice,
23            JsaData::Candles { candles, source } => source_type(candles, source),
24        }
25    }
26}
27
28#[derive(Debug, Clone)]
29pub enum JsaData<'a> {
30    Candles {
31        candles: &'a Candles,
32        source: &'a str,
33    },
34    Slice(&'a [f64]),
35}
36
/// Output of a single JSA run.
#[derive(Clone, Debug)]
pub struct JsaOutput {
    /// Computed series; `jsa_with_kernel` sizes it to the input length.
    pub values: Vec<f64>,
}
41
/// Tunable parameters for JSA.
#[derive(Clone, Debug)]
pub struct JsaParams {
    /// Lag distance in samples; consumers in this file treat `None` as 30.
    pub period: Option<usize>,
}
46
47impl Default for JsaParams {
48    fn default() -> Self {
49        Self { period: Some(30) }
50    }
51}
52
53#[derive(Debug, Clone)]
54pub struct JsaInput<'a> {
55    pub data: JsaData<'a>,
56    pub params: JsaParams,
57}
58
59impl<'a> JsaInput<'a> {
60    #[inline(always)]
61    pub fn from_candles(c: &'a Candles, s: &'a str, p: JsaParams) -> Self {
62        Self {
63            data: JsaData::Candles {
64                candles: c,
65                source: s,
66            },
67            params: p,
68        }
69    }
70    #[inline(always)]
71    pub fn from_slice(sl: &'a [f64], p: JsaParams) -> Self {
72        Self {
73            data: JsaData::Slice(sl),
74            params: p,
75        }
76    }
77    #[inline(always)]
78    pub fn with_default_candles(c: &'a Candles) -> Self {
79        Self::from_candles(c, "close", JsaParams::default())
80    }
81    #[inline(always)]
82    pub fn get_period(&self) -> usize {
83        self.params.period.unwrap_or(30)
84    }
85}
86
87#[derive(Copy, Clone, Debug)]
88pub struct JsaBuilder {
89    period: Option<usize>,
90    kernel: Kernel,
91}
92
93impl Default for JsaBuilder {
94    fn default() -> Self {
95        Self {
96            period: None,
97            kernel: Kernel::Auto,
98        }
99    }
100}
101
102impl JsaBuilder {
103    #[inline(always)]
104    pub fn new() -> Self {
105        Self::default()
106    }
107    #[inline(always)]
108    pub fn period(mut self, n: usize) -> Self {
109        self.period = Some(n);
110        self
111    }
112    #[inline(always)]
113    pub fn kernel(mut self, k: Kernel) -> Self {
114        self.kernel = k;
115        self
116    }
117    #[inline(always)]
118    pub fn apply(self, c: &Candles) -> Result<JsaOutput, JsaError> {
119        let p = JsaParams {
120            period: self.period,
121        };
122        let i = JsaInput::from_candles(c, "close", p);
123        jsa_with_kernel(&i, self.kernel)
124    }
125    #[inline(always)]
126    pub fn apply_slice(self, d: &[f64]) -> Result<JsaOutput, JsaError> {
127        let p = JsaParams {
128            period: self.period,
129        };
130        let i = JsaInput::from_slice(d, p);
131        jsa_with_kernel(&i, self.kernel)
132    }
133    #[inline(always)]
134    pub fn into_stream(self) -> Result<JsaStream, JsaError> {
135        let p = JsaParams {
136            period: self.period,
137        };
138        JsaStream::try_new(p)
139    }
140}
141
142#[derive(Debug, Error)]
143pub enum JsaError {
144    #[error("jsa: Input data slice is empty.")]
145    EmptyInputData,
146
147    #[error("jsa: All values are NaN.")]
148    AllValuesNaN,
149
150    #[error("jsa: Invalid period: period = {period}, data length = {data_len}")]
151    InvalidPeriod { period: usize, data_len: usize },
152
153    #[error("jsa: Not enough valid data: needed = {needed}, valid = {valid}")]
154    NotEnoughValidData { needed: usize, valid: usize },
155
156    #[error("jsa: output length mismatch: expected = {expected}, got = {got}")]
157    OutputLengthMismatch { expected: usize, got: usize },
158
159    #[error("jsa: invalid kernel for batch op: {kernel:?}")]
160    InvalidKernel { kernel: Kernel },
161
162    #[error("jsa: invalid range expansion: start={start}, end={end}, step={step}")]
163    InvalidRange {
164        start: usize,
165        end: usize,
166        step: usize,
167    },
168
169    #[error("jsa: arithmetic overflow while computing sizes")]
170    ArithmeticOverflow,
171
172    #[error("jsa: invalid kernel used for batch op: {kernel:?}")]
173    InvalidKernelForBatch { kernel: Kernel },
174}
175
176#[inline]
177pub fn jsa(input: &JsaInput) -> Result<JsaOutput, JsaError> {
178    jsa_with_kernel(input, Kernel::Auto)
179}
180
181#[inline(always)]
182fn jsa_compute_into(data: &[f64], period: usize, first: usize, k: Kernel, out: &mut [f64]) {
183    unsafe {
184        match k {
185            Kernel::Scalar | Kernel::ScalarBatch => jsa_scalar(data, period, first, out),
186            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
187            Kernel::Avx2 | Kernel::Avx2Batch => jsa_avx2(data, period, first, out),
188            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
189            Kernel::Avx512 | Kernel::Avx512Batch => jsa_avx512(data, period, first, out),
190            _ => unreachable!(),
191        }
192    }
193}
194
195pub fn jsa_with_kernel(input: &JsaInput, kernel: Kernel) -> Result<JsaOutput, JsaError> {
196    let data: &[f64] = match &input.data {
197        JsaData::Candles { candles, source } => source_type(candles, source),
198        JsaData::Slice(sl) => sl,
199    };
200
201    if data.is_empty() {
202        return Err(JsaError::EmptyInputData);
203    }
204
205    let first = data
206        .iter()
207        .position(|x| !x.is_nan())
208        .ok_or(JsaError::AllValuesNaN)?;
209    let len = data.len();
210    let period = input.get_period();
211
212    if period == 0 || period > len {
213        return Err(JsaError::InvalidPeriod {
214            period,
215            data_len: len,
216        });
217    }
218    if (len - first) < period {
219        return Err(JsaError::NotEnoughValidData {
220            needed: period,
221            valid: len - first,
222        });
223    }
224
225    let warm = first
226        .checked_add(period)
227        .ok_or(JsaError::ArithmeticOverflow)?;
228    let mut out = alloc_with_nan_prefix(len, warm);
229    let chosen = match kernel {
230        Kernel::Auto => Kernel::Scalar,
231        k => k,
232    };
233    jsa_compute_into(data, period, first, chosen, &mut out);
234    Ok(JsaOutput { values: out })
235}
236
237#[inline]
238#[inline]
239pub fn jsa_into(input: &JsaInput, out: &mut [f64]) -> Result<(), JsaError> {
240    jsa_with_kernel_into(input, Kernel::Auto, out)
241}
242
243#[inline]
244pub fn jsa_into_slice(dst: &mut [f64], input: &JsaInput, kern: Kernel) -> Result<(), JsaError> {
245    jsa_with_kernel_into(input, kern, dst)
246}
247
248pub fn jsa_with_kernel_into(
249    input: &JsaInput,
250    kernel: Kernel,
251    out: &mut [f64],
252) -> Result<(), JsaError> {
253    let data: &[f64] = match &input.data {
254        JsaData::Candles { candles, source } => source_type(candles, source),
255        JsaData::Slice(sl) => sl,
256    };
257
258    if data.is_empty() {
259        return Err(JsaError::EmptyInputData);
260    }
261
262    let len = data.len();
263
264    if out.len() != len {
265        return Err(JsaError::OutputLengthMismatch {
266            expected: len,
267            got: out.len(),
268        });
269    }
270
271    let first = data
272        .iter()
273        .position(|x| !x.is_nan())
274        .ok_or(JsaError::AllValuesNaN)?;
275    let period = input.get_period();
276
277    if period == 0 || period > len {
278        return Err(JsaError::InvalidPeriod {
279            period,
280            data_len: len,
281        });
282    }
283    if (len - first) < period {
284        return Err(JsaError::NotEnoughValidData {
285            needed: period,
286            valid: len - first,
287        });
288    }
289
290    let warm = first + period;
291
292    out[..warm].fill(f64::NAN);
293
294    let chosen = match kernel {
295        Kernel::Auto => Kernel::Scalar,
296        k => k,
297    };
298    jsa_compute_into(data, period, first, chosen, out);
299    Ok(())
300}
301
/// AVX-512 JSA kernel entry point.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first_valid + period` to the end of `data`. Periods up to 32 use the
/// "short" kernel, larger ones the "long" kernel.
///
/// # Panics
/// Panics if `out` is shorter than `data`; the underlying kernels write
/// through raw pointers and would otherwise write out of bounds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn jsa_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert!(
        out.len() >= data.len(),
        "jsa_avx512: output buffer is shorter than the input"
    );
    // An overflowing start index lies past the end of any slice: nothing to compute.
    if first_valid.checked_add(period).is_none() {
        return;
    }

    // SAFETY: `out` covers every index of `data` (asserted above) and
    // `first_valid + period` does not overflow, so the kernels only touch
    // indices in `first_valid..data.len()`.
    // NOTE(review): CPU support for AVX-512 is not checked here — presumably
    // callers only pick this kernel after runtime detection; confirm.
    unsafe {
        if period <= 32 {
            jsa_avx512_short(data, period, first_valid, out)
        } else {
            jsa_avx512_long(data, period, first_valid, out)
        }
    }
}
311
/// Scalar JSA kernel.
///
/// For every index `i` from `first_val + period` to the end of `data`, stores
/// the mean of `data[i]` and `data[i - period]` in `out[i]`. Elements of `out`
/// before that index are not touched.
///
/// # Panics
/// Panics if `out` is too short to hold an index that gets written.
#[inline]
pub fn jsa_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let start = first_val + period;
    for (i, &current) in data.iter().enumerate().skip(start) {
        out[i] = (current + data[i - period]) * 0.5;
    }
}
318
/// AVX2 JSA kernel: four outputs per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first_val + period` to the end of `data`; earlier elements of `out` are
/// not touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first_val + period` must not overflow `usize`.
/// - The executing CPU must support the AVX intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx2(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;

    let n = data.len();
    let begin = first_val + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm256_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm256_loadu_pd(src.add(i));
        let lagged = _mm256_loadu_pd(src.add(i - period));
        let mean = _mm256_mul_pd(_mm256_add_pd(current, lagged), half);
        _mm256_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for the last `n % LANES`-ish elements.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
357
/// AVX-512 JSA kernel used by `jsa_avx512` for periods up to 32: eight outputs
/// per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first_val + period` to the end of `data`; earlier elements of `out` are
/// not touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first_val + period` must not overflow `usize`.
/// - The executing CPU must support the AVX-512 intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx512_short(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let n = data.len();
    let begin = first_val + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm512_loadu_pd(src.add(i));
        let lagged = _mm512_loadu_pd(src.add(i - period));
        let mean = _mm512_mul_pd(_mm512_add_pd(current, lagged), half);
        _mm512_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for whatever does not fill a full vector.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
396
/// AVX-512 JSA kernel used by `jsa_avx512` for periods above 32: eight outputs
/// per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first_val + period` to the end of `data`; earlier elements of `out` are
/// not touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first_val + period` must not overflow `usize`.
/// - The executing CPU must support the AVX-512 intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn jsa_avx512_long(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let n = data.len();
    let begin = first_val + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm512_loadu_pd(src.add(i));
        let lagged = _mm512_loadu_pd(src.add(i - period));
        let mean = _mm512_mul_pd(_mm512_add_pd(current, lagged), half);
        _mm512_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for whatever does not fill a full vector.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
434
/// Incremental JSA state: a ring buffer holding the last `period` inputs.
#[derive(Clone, Debug)]
pub struct JsaStream {
    /// Lag distance; also the ring-buffer length.
    period: usize,
    /// Ring buffer of recent inputs.
    buffer: Vec<f64>,
    /// Next slot to read and then overwrite.
    head: usize,
    /// Set once the buffer has been written all the way around.
    filled: bool,
}
442
443impl JsaStream {
444    pub fn try_new(params: JsaParams) -> Result<Self, JsaError> {
445        let period = params.period.unwrap_or(30);
446        if period == 0 {
447            return Err(JsaError::InvalidPeriod {
448                period,
449                data_len: 0,
450            });
451        }
452        Ok(Self {
453            period,
454            buffer: vec![f64::NAN; period],
455            head: 0,
456            filled: false,
457        })
458    }
459
460    #[inline(always)]
461    pub fn update(&mut self, value: f64) -> Option<f64> {
462        let out = if self.filled {
463            let past = self.buffer[self.head];
464            Some((value + past) * 0.5)
465        } else {
466            None
467        };
468
469        self.buffer[self.head] = value;
470
471        let next = self.head + 1;
472        if next == self.period {
473            self.head = 0;
474            if !self.filled {
475                self.filled = true;
476            }
477        } else {
478            self.head = next;
479        }
480
481        out
482    }
483}
484
/// Parameter sweep for batch JSA.
#[derive(Debug, Clone)]
pub struct JsaBatchRange {
    /// `(start, end, step)` sweep over the period.
    pub period: (usize, usize, usize),
}
489
490impl Default for JsaBatchRange {
491    fn default() -> Self {
492        Self {
493            period: (30, 279, 1),
494        }
495    }
496}
497
498#[derive(Clone, Debug, Default)]
499pub struct JsaBatchBuilder {
500    range: JsaBatchRange,
501    kernel: Kernel,
502}
503
504impl JsaBatchBuilder {
505    pub fn new() -> Self {
506        Self::default()
507    }
508    pub fn kernel(mut self, k: Kernel) -> Self {
509        self.kernel = k;
510        self
511    }
512    #[inline]
513    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
514        self.range.period = (start, end, step);
515        self
516    }
517    #[inline]
518    pub fn period_static(mut self, p: usize) -> Self {
519        self.range.period = (p, p, 0);
520        self
521    }
522    pub fn apply_slice(self, data: &[f64]) -> Result<JsaBatchOutput, JsaError> {
523        jsa_batch_with_kernel(data, &self.range, self.kernel)
524    }
525    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<JsaBatchOutput, JsaError> {
526        JsaBatchBuilder::new().kernel(k).apply_slice(data)
527    }
528    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<JsaBatchOutput, JsaError> {
529        let slice = source_type(c, src);
530        self.apply_slice(slice)
531    }
532    pub fn with_default_candles(c: &Candles) -> Result<JsaBatchOutput, JsaError> {
533        JsaBatchBuilder::new()
534            .kernel(Kernel::Auto)
535            .apply_candles(c, "close")
536    }
537}
538
539pub fn jsa_batch_with_kernel(
540    data: &[f64],
541    sweep: &JsaBatchRange,
542    k: Kernel,
543) -> Result<JsaBatchOutput, JsaError> {
544    let kernel = match k {
545        Kernel::Auto => detect_best_batch_kernel(),
546        other if other.is_batch() => other,
547        other => return Err(JsaError::InvalidKernelForBatch { kernel: other }),
548    };
549
550    let simd = match kernel {
551        Kernel::Avx512Batch => Kernel::Avx512,
552        Kernel::Avx2Batch => Kernel::Avx2,
553        Kernel::ScalarBatch => Kernel::Scalar,
554        _ => unreachable!(),
555    };
556    jsa_batch_par_slice(data, sweep, simd)
557}
558
559#[derive(Clone, Debug)]
560pub struct JsaBatchOutput {
561    pub values: Vec<f64>,
562    pub combos: Vec<JsaParams>,
563    pub rows: usize,
564    pub cols: usize,
565}
566impl JsaBatchOutput {
567    pub fn row_for_params(&self, p: &JsaParams) -> Option<usize> {
568        self.combos
569            .iter()
570            .position(|c| c.period.unwrap_or(30) == p.period.unwrap_or(30))
571    }
572    pub fn values_for(&self, p: &JsaParams) -> Option<&[f64]> {
573        self.row_for_params(p).map(|row| {
574            let start = row * self.cols;
575            &self.values[start..start + self.cols]
576        })
577    }
578}
579
580#[inline(always)]
581fn expand_grid(r: &JsaBatchRange) -> Result<Vec<JsaParams>, JsaError> {
582    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, JsaError> {
583        if step == 0 || start == end {
584            return Ok(vec![start]);
585        }
586        let mut v = Vec::new();
587        if start < end {
588            let mut cur = start;
589            while cur <= end {
590                v.push(cur);
591
592                match cur.checked_add(step) {
593                    Some(next) => cur = next,
594                    None => break,
595                }
596            }
597        } else {
598            let mut cur = start;
599            while cur >= end {
600                v.push(cur);
601
602                if cur < end + step {
603                    break;
604                }
605                cur -= step;
606                if cur == usize::MAX {
607                    break;
608                }
609            }
610        }
611        if v.is_empty() {
612            return Err(JsaError::InvalidRange { start, end, step });
613        }
614        Ok(v)
615    }
616    let periods = axis_usize(r.period)?;
617    let mut out = Vec::with_capacity(periods.len());
618    for &p in &periods {
619        out.push(JsaParams { period: Some(p) });
620    }
621    Ok(out)
622}
623
624#[inline(always)]
625pub fn jsa_batch_slice(
626    data: &[f64],
627    sweep: &JsaBatchRange,
628    kern: Kernel,
629) -> Result<JsaBatchOutput, JsaError> {
630    jsa_batch_inner(data, sweep, kern, false)
631}
632
633#[inline(always)]
634pub fn jsa_batch_par_slice(
635    data: &[f64],
636    sweep: &JsaBatchRange,
637    kern: Kernel,
638) -> Result<JsaBatchOutput, JsaError> {
639    jsa_batch_inner(data, sweep, kern, true)
640}
641
642#[inline(always)]
643fn jsa_batch_inner(
644    data: &[f64],
645    sweep: &JsaBatchRange,
646    kern: Kernel,
647    parallel: bool,
648) -> Result<JsaBatchOutput, JsaError> {
649    if data.is_empty() {
650        return Err(JsaError::EmptyInputData);
651    }
652
653    let combos = expand_grid(sweep)?;
654    let first = data
655        .iter()
656        .position(|x| !x.is_nan())
657        .ok_or(JsaError::AllValuesNaN)?;
658    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
659    if data.len() - first < max_p {
660        return Err(JsaError::NotEnoughValidData {
661            needed: max_p,
662            valid: data.len() - first,
663        });
664    }
665    let rows = combos.len();
666    let cols = data.len();
667
668    let _total = rows.checked_mul(cols).ok_or(JsaError::ArithmeticOverflow)?;
669    let mut warm: Vec<usize> = Vec::with_capacity(rows);
670    for c in &combos {
671        let p = c.period.unwrap();
672        let w = first.checked_add(p).ok_or(JsaError::ArithmeticOverflow)?;
673        warm.push(w);
674    }
675
676    let mut raw = make_uninit_matrix(rows, cols);
677
678    init_matrix_prefixes(&mut raw, cols, &warm);
679
680    let actual_kern = match kern {
681        Kernel::Auto => detect_best_batch_kernel(),
682        k => k,
683    };
684
685    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
686        let period = combos[row].period.unwrap();
687
688        let out_row =
689            core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
690
691        match actual_kern {
692            Kernel::ScalarBatch | Kernel::Scalar => jsa_row_scalar(data, first, period, out_row),
693            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
694            Kernel::Avx2Batch | Kernel::Avx2 => jsa_row_avx2(data, first, period, out_row),
695            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
696            Kernel::Avx512Batch | Kernel::Avx512 => jsa_row_avx512(data, first, period, out_row),
697            _ => unreachable!(),
698        }
699    };
700
701    if parallel {
702        #[cfg(not(target_arch = "wasm32"))]
703        {
704            raw.par_chunks_mut(cols)
705                .enumerate()
706                .for_each(|(row, slice)| do_row(row, slice));
707        }
708
709        #[cfg(target_arch = "wasm32")]
710        {
711            for (row, slice) in raw.chunks_mut(cols).enumerate() {
712                do_row(row, slice);
713            }
714        }
715    } else {
716        for (row, slice) in raw.chunks_mut(cols).enumerate() {
717            do_row(row, slice);
718        }
719    }
720
721    use core::mem::ManuallyDrop;
722
723    let mut buf_guard = ManuallyDrop::new(raw);
724    let values = unsafe {
725        Vec::from_raw_parts(
726            buf_guard.as_mut_ptr() as *mut f64,
727            buf_guard.len(),
728            buf_guard.capacity(),
729        )
730    };
731
732    Ok(JsaBatchOutput {
733        values,
734        combos,
735        rows,
736        cols,
737    })
738}
739
740#[inline(always)]
741fn jsa_batch_inner_into(
742    data: &[f64],
743    sweep: &JsaBatchRange,
744    kern: Kernel,
745    parallel: bool,
746    out: &mut [f64],
747) -> Result<(Vec<JsaParams>, usize, usize), JsaError> {
748    if data.is_empty() {
749        return Err(JsaError::EmptyInputData);
750    }
751
752    let combos = expand_grid(sweep)?;
753    let first = data
754        .iter()
755        .position(|x| !x.is_nan())
756        .ok_or(JsaError::AllValuesNaN)?;
757    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
758    if data.len() - first < max_p {
759        return Err(JsaError::NotEnoughValidData {
760            needed: max_p,
761            valid: data.len() - first,
762        });
763    }
764    let rows = combos.len();
765    let cols = data.len();
766
767    let expected = rows.checked_mul(cols).ok_or(JsaError::ArithmeticOverflow)?;
768    if out.len() != expected {
769        return Err(JsaError::OutputLengthMismatch {
770            expected,
771            got: out.len(),
772        });
773    }
774    let mut warm: Vec<usize> = Vec::with_capacity(rows);
775    for c in &combos {
776        let p = c.period.unwrap();
777        let w = first.checked_add(p).ok_or(JsaError::ArithmeticOverflow)?;
778        warm.push(w);
779    }
780
781    let out_uninit = unsafe {
782        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
783    };
784    init_matrix_prefixes(out_uninit, cols, &warm);
785
786    let actual_kern = match kern {
787        Kernel::Auto => detect_best_batch_kernel(),
788        k => k,
789    };
790
791    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
792        let period = combos[row].period.unwrap();
793
794        match actual_kern {
795            Kernel::ScalarBatch | Kernel::Scalar => jsa_row_scalar(data, first, period, out_row),
796            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
797            Kernel::Avx2Batch | Kernel::Avx2 => jsa_row_avx2(data, first, period, out_row),
798            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
799            Kernel::Avx512Batch | Kernel::Avx512 => jsa_row_avx512(data, first, period, out_row),
800            _ => unreachable!(),
801        }
802    };
803
804    if parallel {
805        #[cfg(not(target_arch = "wasm32"))]
806        {
807            out.par_chunks_mut(cols)
808                .enumerate()
809                .for_each(|(row, slice)| do_row(row, slice));
810        }
811        #[cfg(target_arch = "wasm32")]
812        {
813            for (row, slice) in out.chunks_mut(cols).enumerate() {
814                do_row(row, slice);
815            }
816        }
817    } else {
818        for (row, slice) in out.chunks_mut(cols).enumerate() {
819            do_row(row, slice);
820        }
821    }
822
823    Ok((combos, rows, cols))
824}
825
/// Scalar row kernel for the batch path.
///
/// For every index `i` from `first + period` to the end of `data`, stores the
/// mean of `data[i]` and `data[i - period]` in `out[i]`. Elements of `out`
/// before that index are not touched.
///
/// # Safety
/// The body uses only bounds-checked indexing; the function is `unsafe` to
/// share a signature shape with the SIMD row kernels.
#[inline(always)]
unsafe fn jsa_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let start = first + period;
    for (i, &current) in data.iter().enumerate().skip(start) {
        out[i] = (current + data[i - period]) * 0.5;
    }
}
832
/// AVX2 row kernel for the batch path: four outputs per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first + period` to the end of `data`; earlier elements of `out` are not
/// touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first + period` must not overflow `usize`.
/// - The executing CPU must support the AVX intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 4;

    let n = data.len();
    let begin = first + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm256_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm256_loadu_pd(src.add(i));
        let lagged = _mm256_loadu_pd(src.add(i - period));
        let mean = _mm256_mul_pd(_mm256_add_pd(current, lagged), half);
        _mm256_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for whatever does not fill a full vector.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
870
/// AVX-512 row kernel dispatcher for the batch path: periods up to 32 go to
/// the "short" kernel, larger ones to the "long" kernel.
///
/// # Safety
/// Forwards its arguments unchanged, so the callee's requirements apply.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    match period {
        0..=32 => jsa_row_avx512_short(data, first, period, out),
        _ => jsa_row_avx512_long(data, first, period, out),
    }
}
880
/// AVX-512 row kernel used by `jsa_row_avx512` for periods up to 32: eight
/// outputs per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first + period` to the end of `data`; earlier elements of `out` are not
/// touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first + period` must not overflow `usize`.
/// - The executing CPU must support the AVX-512 intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let n = data.len();
    let begin = first + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm512_loadu_pd(src.add(i));
        let lagged = _mm512_loadu_pd(src.add(i - period));
        let mean = _mm512_mul_pd(_mm512_add_pd(current, lagged), half);
        _mm512_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for whatever does not fill a full vector.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
918
/// AVX-512 row kernel used by `jsa_row_avx512` for periods above 32: eight
/// outputs per vector step, scalar tail.
///
/// Writes `out[i] = (data[i] + data[i - period]) * 0.5` for every `i` from
/// `first + period` to the end of `data`; earlier elements of `out` are not
/// touched.
///
/// # Safety
/// - `out.len()` must be at least `data.len()`; writes go through a raw pointer.
/// - `first + period` must not overflow `usize`.
/// - The executing CPU must support the AVX-512 intrinsics used here.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn jsa_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;
    const LANES: usize = 8;

    let n = data.len();
    let begin = first + period;
    if begin >= n {
        return;
    }
    debug_assert!(out.len() >= n);

    let src = data.as_ptr();
    let dst = out.as_mut_ptr();
    let half = _mm512_set1_pd(0.5);

    // Compare indices rather than `ptr.add(LANES) <= end`: forming a pointer
    // more than one element past the allocation violates `pointer::add`.
    let mut i = begin;
    while i + LANES <= n {
        let current = _mm512_loadu_pd(src.add(i));
        let lagged = _mm512_loadu_pd(src.add(i - period));
        let mean = _mm512_mul_pd(_mm512_add_pd(current, lagged), half);
        _mm512_storeu_pd(dst.add(i), mean);
        i += LANES;
    }

    // Scalar tail for whatever does not fill a full vector.
    while i < n {
        *dst.add(i) = (*src.add(i) + *src.add(i - period)) * 0.5;
        i += 1;
    }
}
956
957#[cfg(feature = "python")]
958use numpy::PyUntypedArrayMethods;
959#[cfg(feature = "python")]
960use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
961#[cfg(feature = "python")]
962use pyo3::exceptions::PyValueError;
963#[cfg(feature = "python")]
964use pyo3::prelude::*;
965#[cfg(feature = "python")]
966use pyo3::types::PyDict;
967
968#[cfg(all(feature = "python", feature = "cuda"))]
969use crate::cuda::moving_averages::jsa_wrapper::JsaDeviceHandle;
970#[cfg(all(feature = "python", feature = "cuda"))]
971use crate::cuda::{cuda_available, moving_averages::CudaJsa};
972#[cfg(all(feature = "python", feature = "cuda"))]
973use cust::context::Context;
974#[cfg(all(feature = "python", feature = "cuda"))]
975use cust::memory::DeviceBuffer;
976#[cfg(all(feature = "python", feature = "cuda"))]
977use std::sync::Arc;
978
/// Python binding `jsa(data, period, kernel=None)`.
///
/// Validates the kernel name, then computes JSA with the GIL released.
/// C-contiguous input is read in place and the result is written straight
/// into a new NumPy array; any other layout is first copied into an owned
/// array and the result is returned from a temporary `Vec`.
///
/// Errors from the computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "jsa")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn jsa_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::PyArrayMethods;

    let kern = validate_kernel(kernel, false)?;
    let params = JsaParams {
        period: Some(period),
    };

    if !data.is_c_contiguous() {
        // Strided input: copy it so the kernel sees a plain slice.
        let owned = data.as_array().to_owned();
        let series = owned.as_slice().expect("owned array should be contiguous");
        let input = JsaInput::from_slice(series, params);
        let mut buf = vec![f64::NAN; series.len()];
        py.allow_threads(|| jsa_with_kernel_into(&input, kern, &mut buf))
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        return Ok(PyArray1::from_vec(py, buf));
    }

    let series = data.as_slice()?;
    let out_arr = unsafe { PyArray1::<f64>::new(py, [series.len()], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };
    let input = JsaInput::from_slice(series, params);
    py.allow_threads(|| jsa_with_kernel_into(&input, kern, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(out_arr)
}
1018
/// Python binding `jsa_batch(data, period_start, period_end, period_step, kernel=None)`.
///
/// Expands the period sweep, allocates a flat NumPy array of
/// `rows * cols` elements, fills it with the GIL released, and returns a dict
/// with:
/// - `"values"`: the result reshaped to `(rows, cols)`;
/// - `"periods"`: the period used for each row, as `u64`.
///
/// Errors from the sweep expansion or the computation are raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "jsa_batch")]
#[pyo3(signature = (data, period_start, period_end, period_step, kernel=None))]
pub fn jsa_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = JsaBatchRange {
        period: (period_start, period_end, period_step),
    };

    // Expand once up front only to size the output array.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = match rows.checked_mul(cols) {
        Some(n) => n,
        None => return Err(PyValueError::new_err("size overflow")),
    };

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };

        // Map batch kernels to their per-row counterparts; pass others through.
        let row_kernel = match batch_kernel {
            Kernel::ScalarBatch => Kernel::Scalar,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::Avx512Batch => Kernel::Avx512,
            other => other,
        };

        jsa_batch_inner_into(slice_in, &sweep, row_kernel, true, slice_out)
    });
    let (combos, _, _) = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    Ok(dict)
}
1082
/// Python entry point `jsa_cuda_batch_dev`: runs the JSA period sweep through
/// `CudaJsa` on the selected device and wraps the returned device handle in a
/// `JsaDeviceArrayF32Py`.
///
/// Raises `ValueError` when `cuda_available()` is false or when the CUDA call
/// reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jsa_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range=(30, 30, 0), device_id=0))]
pub fn jsa_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<JsaDeviceArrayF32Py> {
    use crate::cuda::moving_averages::CudaJsaError;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let range = JsaBatchRange {
        period: period_range,
    };

    // Device setup and the batch call both run with the GIL released.
    let launched = py.allow_threads(|| -> Result<_, CudaJsaError> {
        let cuda = CudaJsa::new(device_id)?;
        cuda.jsa_batch_dev_handle(prices, &range)
    });
    let handle: JsaDeviceHandle = match launched {
        Ok(h) => h,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(JsaDeviceArrayF32Py::from_handle_rust(handle))
}
1111
/// Python entry point `jsa_cuda_many_series_one_param_dev`: applies one JSA
/// period to a 2-D `f32` array through `CudaJsa` and wraps the returned device
/// handle in a `JsaDeviceArrayF32Py`.
///
/// The first array dimension is passed on as the series length and the second
/// as the number of series.
///
/// Raises `ValueError` when CUDA is unavailable, when `period` is zero, when
/// the array is not two-dimensional, or when the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "jsa_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn jsa_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<JsaDeviceArrayF32Py> {
    use crate::cuda::moving_averages::CudaJsaError;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if period == 0 {
        return Err(PyValueError::new_err("period must be positive"));
    }

    let flat = data_tm_f32.as_slice()?;
    let (series_len, num_series) = match data_tm_f32.shape() {
        &[rows, cols] => (rows, cols),
        _ => return Err(PyValueError::new_err("expected a 2D array")),
    };
    let params = JsaParams {
        period: Some(period),
    };

    // Device setup and the kernel call both run with the GIL released.
    let launched = py.allow_threads(|| -> Result<_, CudaJsaError> {
        let cuda = CudaJsa::new(device_id)?;
        cuda.jsa_many_series_one_param_time_major_dev_handle(
            flat, num_series, series_len, &params,
        )
    });
    let handle: JsaDeviceHandle = match launched {
        Ok(h) => h,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(JsaDeviceArrayF32Py::from_handle_rust(handle))
}
1151
/// Python class `JsaStream`: a thin wrapper that owns a Rust `JsaStream` and
/// forwards calls to it.
#[cfg(feature = "python")]
#[pyclass(name = "JsaStream")]
pub struct JsaStreamPy {
    /// The wrapped streaming state.
    inner: JsaStream,
}
1157
#[cfg(feature = "python")]
#[pymethods]
impl JsaStreamPy {
    /// Builds a stream for the given `period`; a construction failure reported
    /// by `JsaStream::try_new` is raised as `ValueError`.
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        JsaStream::try_new(JsaParams {
            period: Some(period),
        })
        .map(|inner| JsaStreamPy { inner })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value into the wrapped stream and returns whatever it yields
    /// (`None` when the stream produces no output for this call).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1175
1176#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1177use serde::{Deserialize, Serialize};
1178#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1179use wasm_bindgen::prelude::*;
1180
/// WebAssembly binding: computes JSA over `data` with the given `period` and
/// returns a freshly allocated vector of the same length.
///
/// Any error from `jsa_into` is converted to a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = JsaInput::from_slice(
        data,
        JsaParams {
            period: Some(period),
        },
    );

    // Zero-filled destination that `jsa_into` writes the result into.
    let mut values = vec![0.0; data.len()];

    match jsa_into(&input, &mut values) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1195
/// Configuration object that `jsa_batch_js` deserializes from its JavaScript
/// `config` argument.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JsaBatchConfig {
    /// Period sweep forwarded verbatim to `JsaBatchRange::period`; the sibling
    /// `jsa_batch_simple` fills the same tuple as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
}
1201
/// Result object that `jsa_batch_js` serializes back to JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct JsaBatchJsOutput {
    /// Flat result buffer taken from the batch output.
    /// NOTE(review): presumably row-major with `rows * cols` entries - confirm.
    pub values: Vec<f64>,
    /// Period of each parameter combination, in combination order.
    pub periods: Vec<usize>,
    /// Row count reported by the batch output.
    pub rows: usize,
    /// Column count reported by the batch output.
    pub cols: usize,
}
1210
/// WebAssembly binding exported as `jsa_batch`: decodes a `JsaBatchConfig`
/// from `config`, runs the batch computation over `data`, and returns a
/// serialized `JsaBatchJsOutput`.
///
/// Config decoding, computation and serialization failures are each returned
/// as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_batch)]
pub fn jsa_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: JsaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let range = JsaBatchRange {
        period: parsed.period_range,
    };

    let batch = match jsa_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(b) => b,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // A combination without an explicit period is reported as 30.
    let periods: Vec<usize> = batch
        .combos
        .iter()
        .map(|combo| combo.period.unwrap_or(30))
        .collect();

    let payload = JsaBatchJsOutput {
        values: batch.values,
        periods,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1238
/// WebAssembly binding exported as `jsa_batch_simple`: runs the batch
/// computation for the sweep `(period_start, period_end, period_step)` and
/// returns only the flat value buffer.
///
/// A computation failure is returned as a `JsValue` string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_batch_simple)]
pub fn jsa_batch_simple(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let range = JsaBatchRange {
        period: (period_start, period_end, period_step),
    };

    match jsa_batch_inner(data, &range, Kernel::Auto, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1255
/// WebAssembly binding: reserves room for `len` `f64` values and hands the raw
/// pointer to the caller without freeing it.
///
/// The memory is not initialised. Release it with `jsa_free`, passing the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor, so the allocation
    // outlives this call and ownership passes to the caller.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1264
/// WebAssembly binding: releases a buffer previously obtained from
/// `jsa_alloc`. A null `ptr` is ignored.
///
/// `ptr` and `len` must be exactly the pointer returned by `jsa_alloc` and the
/// `len` it was called with; anything else is undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `jsa_alloc(len)`, i.e. from
    // a `Vec<f64>` with capacity `len`. The Vec is rebuilt with length 0 rather
    // than `len`, because the buffer may never have been fully initialised and
    // `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. Deallocation depends only on the capacity, and `f64` has no
    // destructor, so the memory released is the same.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1274
/// WebAssembly binding exported as `jsa_into`: computes JSA from a raw input
/// buffer into a raw output buffer.
///
/// `in_ptr` and `out_ptr` must each address `len` valid `f64` values. The two
/// buffers may be identical or may partially overlap: in either case the
/// result is computed into a scratch vector and then copied to `out_ptr`, so
/// the input is never borrowed as `&[f64]` while the same memory is borrowed
/// as `&mut [f64]`.
///
/// # Errors
/// Returns a `JsValue` string when either pointer is null, when `period` is
/// zero or greater than `len` (`"Invalid period"`), or when the computation
/// itself fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = jsa_into)]
pub fn jsa_into_wasm(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jsa_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Compare the byte ranges [addr, addr + span) of both buffers. Checking
    // only `in_ptr == out_ptr` would miss partially overlapping buffers.
    // Saturating arithmetic keeps the comparison defined for any `len`.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let buffers_overlap =
        in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span);

    // SAFETY: both pointers are non-null and the caller guarantees each one
    // addresses `len` valid `f64` values. A mutable slice over `out_ptr` is
    // created while `data` is still in use only when the ranges are disjoint;
    // otherwise it is created after the last read of `data`.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = JsaInput::from_slice(
            data,
            JsaParams {
                period: Some(period),
            },
        );
        if buffers_overlap {
            let mut temp = vec![0.0; len];
            jsa_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&temp);
        } else {
            jsa_into_slice(
                std::slice::from_raw_parts_mut(out_ptr, len),
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1313
/// Deprecated WebAssembly export kept for existing callers; it forwards its
/// arguments unchanged to `jsa_into_wasm` (exported as `jsa_into`) and returns
/// that function's result.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "use jsa_into")]
pub fn jsa_fast(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    jsa_into_wasm(in_ptr, out_ptr, len, period)
}
1325
/// WebAssembly binding: runs the JSA period sweep
/// `(period_start, period_end, period_step)` over a raw input buffer and
/// writes the results into a raw output buffer.
///
/// `in_ptr` must address `len` valid `f64` values and `out_ptr` must address
/// `rows * len` values, where `rows` is the number of expanded parameter
/// combinations. The two buffers must not overlap, because they are borrowed
/// as a shared and a mutable slice at the same time.
///
/// Returns the number of rows on success.
///
/// # Errors
/// Returns a `JsValue` string when either pointer is null, when grid expansion
/// fails, when `rows * len` overflows `usize` (`"size overflow"`), or when the
/// batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn jsa_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to jsa_batch_into"));
    }

    let sweep = JsaBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();

    // `usize` is 32 bits on wasm32, so an unchecked `rows * len` can wrap and
    // produce an output slice of the wrong length. Reject it explicitly, as
    // the Python batch binding does.
    let total_size = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("size overflow"))?;

    // SAFETY: both pointers are non-null and the caller guarantees that
    // `in_ptr` addresses `len` values, `out_ptr` addresses `total_size`
    // values, and the two regions do not overlap.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total_size),
        )
    };

    jsa_batch_inner_into(data, &sweep, Kernel::Auto, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1358
/// Registers the JSA Python bindings on module `m`: the `jsa_py` and
/// `jsa_batch_py` functions and the `JsaStreamPy` class, plus the device-array
/// class and the two CUDA entry points when the `cuda` feature is enabled.
///
/// Any registration error is propagated to the caller.
#[cfg(feature = "python")]
pub fn register_jsa_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(jsa_py, m)?)?;
    m.add_function(wrap_pyfunction!(jsa_batch_py, m)?)?;
    m.add_class::<JsaStreamPy>()?;
    // CUDA-backed bindings are only compiled in with the `cuda` feature.
    #[cfg(feature = "cuda")]
    {
        m.add_class::<JsaDeviceArrayF32Py>()?;
        m.add_function(wrap_pyfunction!(jsa_cuda_batch_dev_py, m)?)?;
        m.add_function(wrap_pyfunction!(jsa_cuda_many_series_one_param_dev_py, m)?)?;
    }
    Ok(())
}
1372
/// Python-visible handle to a 2-D `f32` result held in a CUDA device buffer.
/// Its methods expose the buffer through `__cuda_array_interface__` and
/// `__dlpack__`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct JsaDeviceArrayF32Py {
    // Device buffer; becomes `None` once `__dlpack__` has taken it.
    buf: Option<DeviceBuffer<f32>>,
    // Logical shape reported to consumers as `(rows, cols)`.
    rows: usize,
    cols: usize,
    // Not read in the visible code.
    // NOTE(review): presumably held to keep the CUDA context alive for as long
    // as the buffer exists - confirm.
    _ctx: Arc<Context>,
    // Device index reported by `__dlpack_device__`.
    device_id: u32,
}
1382
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl JsaDeviceArrayF32Py {
    // Builds the `__cuda_array_interface__` dict: shape `(rows, cols)`, typestr
    // `"<f4"`, byte strides for a densely packed row-major layout, the raw
    // device pointer, and interface version 3.
    //
    // Raises `ValueError` when the array is non-empty but the buffer has
    // already been handed over by `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // An empty array reports a null pointer instead of touching the buffer.
        let ptr = if self.rows == 0 || self.cols == 0 {
            0usize
        } else {
            self.buf
                .as_ref()
                .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
                .as_device_ptr()
                .as_raw() as usize
        };
        // NOTE(review): the second tuple element is presumably the interface's
        // read-only flag (false = writable) - confirm against the spec.
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns `(device_type, device_id)`.
    // NOTE(review): the constant 2 is presumably DLPack's `kDLCUDA` - confirm.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    // Exports the buffer through `export_f32_cuda_dlpack_2d`. This consumes the
    // buffer, so a second call raises `ValueError`.
    //
    // `dl_device` is checked only when it extracts as an `(i32, i32)` pair; a
    // mismatch with this array's device raises `ValueError`. `stream` is
    // accepted but ignored. `copy` is consulted only to choose the error
    // message on a device mismatch.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            // A `dl_device` that is not an `(i32, i32)` pair is silently skipped.
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // A `copy` value that is missing or not a bool counts as false.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        // The stream argument is intentionally unused.
        let _ = stream;

        // Take the buffer out of `self`; it is moved into the export call.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1462
#[cfg(all(feature = "python", feature = "cuda"))]
impl JsaDeviceArrayF32Py {
    /// Wraps a `JsaDeviceHandle` by moving its buffer, shape, context and
    /// device id into a new Python-visible array object.
    pub(crate) fn from_handle_rust(handle: JsaDeviceHandle) -> Self {
        let JsaDeviceHandle {
            buf,
            rows,
            cols,
            _ctx,
            device_id,
            ..
        } = handle;

        Self {
            buf: Some(buf),
            rows,
            cols,
            _ctx,
            device_id,
        }
    }
}
1475
1476#[cfg(test)]
1477mod tests {
1478    use super::*;
1479    use crate::skip_if_unsupported;
1480    use crate::utilities::data_loader::read_candles_from_csv;
1481
1482    fn check_jsa_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1483        skip_if_unsupported!(kernel, test_name);
1484        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1485        let candles = read_candles_from_csv(file_path)?;
1486
1487        let default_params = JsaParams { period: None };
1488        let input = JsaInput::from_candles(&candles, "close", default_params);
1489        let output = jsa_with_kernel(&input, kernel)?;
1490        assert_eq!(output.values.len(), candles.close.len());
1491        Ok(())
1492    }
1493
1494    fn check_jsa_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1495        skip_if_unsupported!(kernel, test_name);
1496        let expected_last_five = [61640.0, 61418.0, 61240.0, 61060.5, 60889.5];
1497        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1498        let candles = read_candles_from_csv(file_path)?;
1499        let default_params = JsaParams::default();
1500        let input = JsaInput::from_candles(&candles, "close", default_params);
1501        let result = jsa_with_kernel(&input, kernel)?;
1502        let start_idx = result.values.len() - 5;
1503        for (i, &val) in result.values[start_idx..].iter().enumerate() {
1504            let expected = expected_last_five[i];
1505            assert!(
1506                (val - expected).abs() < 1e-5,
1507                "[{}] mismatch idx {}: got {}, expected {}",
1508                test_name,
1509                i,
1510                val,
1511                expected
1512            );
1513        }
1514        Ok(())
1515    }
1516
1517    fn check_jsa_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1518        skip_if_unsupported!(kernel, test_name);
1519        let input_data = [10.0, 20.0, 30.0];
1520        let params = JsaParams { period: Some(0) };
1521        let input = JsaInput::from_slice(&input_data, params);
1522        let res = jsa_with_kernel(&input, kernel);
1523        assert!(
1524            res.is_err(),
1525            "[{}] JSA should fail with zero period",
1526            test_name
1527        );
1528        Ok(())
1529    }
1530
1531    fn check_jsa_period_exceeds_length(
1532        test_name: &str,
1533        kernel: Kernel,
1534    ) -> Result<(), Box<dyn Error>> {
1535        skip_if_unsupported!(kernel, test_name);
1536        let data_small = [10.0, 20.0, 30.0];
1537        let params = JsaParams { period: Some(10) };
1538        let input = JsaInput::from_slice(&data_small, params);
1539        let res = jsa_with_kernel(&input, kernel);
1540        assert!(
1541            res.is_err(),
1542            "[{}] JSA should fail with period exceeding length",
1543            test_name
1544        );
1545        Ok(())
1546    }
1547
1548    fn check_jsa_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1549        skip_if_unsupported!(kernel, test_name);
1550        let single_point = [42.0];
1551        let params = JsaParams { period: Some(5) };
1552        let input = JsaInput::from_slice(&single_point, params);
1553        let res = jsa_with_kernel(&input, kernel);
1554        assert!(
1555            res.is_err(),
1556            "[{}] JSA should fail with insufficient data",
1557            test_name
1558        );
1559        Ok(())
1560    }
1561
1562    fn check_jsa_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1563        skip_if_unsupported!(kernel, test_name);
1564        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1565        let candles = read_candles_from_csv(file_path)?;
1566
1567        let first_params = JsaParams { period: Some(10) };
1568        let first_input = JsaInput::from_candles(&candles, "close", first_params);
1569        let first_result = jsa_with_kernel(&first_input, kernel)?;
1570
1571        let second_params = JsaParams { period: Some(5) };
1572        let second_input = JsaInput::from_slice(&first_result.values, second_params);
1573        let second_result = jsa_with_kernel(&second_input, kernel)?;
1574
1575        assert_eq!(second_result.values.len(), first_result.values.len());
1576        for i in 30..second_result.values.len() {
1577            assert!(
1578                second_result.values[i].is_finite(),
1579                "[{}] NaN at idx {}",
1580                test_name,
1581                i
1582            );
1583        }
1584        Ok(())
1585    }
1586
1587    fn check_jsa_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1588        skip_if_unsupported!(kernel, test_name);
1589        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1590        let candles = read_candles_from_csv(file_path)?;
1591        let period = 30;
1592        let input = JsaInput::from_candles(
1593            &candles,
1594            "close",
1595            JsaParams {
1596                period: Some(period),
1597            },
1598        );
1599        let batch_output = jsa_with_kernel(&input, kernel)?.values;
1600
1601        let mut stream = JsaStream::try_new(JsaParams {
1602            period: Some(period),
1603        })?;
1604        let mut stream_values = Vec::with_capacity(candles.close.len());
1605        for &price in &candles.close {
1606            match stream.update(price) {
1607                Some(val) => stream_values.push(val),
1608                None => stream_values.push(f64::NAN),
1609            }
1610        }
1611        assert_eq!(batch_output.len(), stream_values.len());
1612        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1613            if b.is_nan() && s.is_nan() {
1614                continue;
1615            }
1616            let diff = (b - s).abs();
1617            assert!(
1618                diff < 1e-9,
1619                "[{}] streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1620                test_name,
1621                i,
1622                b,
1623                s,
1624                diff
1625            );
1626        }
1627        Ok(())
1628    }
1629
1630    macro_rules! generate_all_jsa_tests {
1631        ($($test_fn:ident),*) => {
1632            paste::paste! {
1633                $( #[test]
1634                    fn [<$test_fn _scalar_f64>]() {
1635                        let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1636                    }
1637                )*
1638                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1639                $( #[test]
1640                    fn [<$test_fn _avx2_f64>]() {
1641                        let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1642                    }
1643                    #[test]
1644                    fn [<$test_fn _avx512_f64>]() {
1645                        let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1646                    }
1647                )*
1648            }
1649        }
1650    }
1651
1652    #[cfg(debug_assertions)]
1653    fn check_jsa_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1654        skip_if_unsupported!(kernel, test_name);
1655
1656        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1657        let candles = read_candles_from_csv(file_path)?;
1658
1659        let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1660        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1661
1662        for period in &test_periods {
1663            for source in &test_sources {
1664                let input = JsaInput::from_candles(
1665                    &candles,
1666                    source,
1667                    JsaParams {
1668                        period: Some(*period),
1669                    },
1670                );
1671                let output = jsa_with_kernel(&input, kernel)?;
1672
1673                for (i, &val) in output.values.iter().enumerate() {
1674                    if val.is_nan() {
1675                        continue;
1676                    }
1677
1678                    let bits = val.to_bits();
1679
1680                    if bits == 0x11111111_11111111 {
1681                        panic!(
1682                            "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1683                            test_name, val, bits, i, period, source
1684                        );
1685                    }
1686
1687                    if bits == 0x22222222_22222222 {
1688                        panic!(
1689                            "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1690                            test_name, val, bits, i, period, source
1691                        );
1692                    }
1693
1694                    if bits == 0x33333333_33333333 {
1695                        panic!(
1696                            "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1697                            test_name, val, bits, i, period, source
1698                        );
1699                    }
1700                }
1701            }
1702        }
1703
1704        Ok(())
1705    }
1706
    // Stand-in compiled when debug assertions are off: performs no checks and
    // always returns `Ok(())`, so the generated tests still have a function to
    // call.
    // NOTE(review): presumably the poison patterns are only written in debug
    // builds - confirm against the allocation helpers.
    #[cfg(not(debug_assertions))]
    fn check_jsa_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1711
    // Generate the per-kernel `#[test]` wrappers for every single-series check.
    generate_all_jsa_tests!(
        check_jsa_partial_params,
        check_jsa_accuracy,
        check_jsa_zero_period,
        check_jsa_period_exceeds_length,
        check_jsa_very_small_dataset,
        check_jsa_reinput,
        check_jsa_streaming,
        check_jsa_no_poison
    );
1722
1723    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1724        skip_if_unsupported!(kernel, test);
1725
1726        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1727        let c = read_candles_from_csv(file)?;
1728
1729        let output = JsaBatchBuilder::new()
1730            .kernel(kernel)
1731            .apply_candles(&c, "close")?;
1732
1733        let def = JsaParams::default();
1734        let row = output.values_for(&def).expect("default row missing");
1735        assert_eq!(row.len(), c.close.len());
1736
1737        let expected = [61640.0, 61418.0, 61240.0, 61060.5, 60889.5];
1738        let start = row.len() - 5;
1739        for (i, &v) in row[start..].iter().enumerate() {
1740            assert!(
1741                (v - expected[i]).abs() < 1e-5,
1742                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1743            );
1744        }
1745        Ok(())
1746    }
1747
    // Expands one batch check function into `#[test]` wrappers: scalar-batch
    // and auto-detect always, plus AVX2-batch and AVX512-batch when the
    // `nightly-avx` feature is enabled on x86_64.
    //
    // NOTE(review): each wrapper discards the `Result` of the check, so an
    // `Err` returned through `?` does not fail the test - confirm intent.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
                }
            }
        };
    }
1768
1769    #[cfg(debug_assertions)]
1770    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1771        skip_if_unsupported!(kernel, test);
1772
1773        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1774        let c = read_candles_from_csv(file)?;
1775
1776        let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1777
1778        for source in &test_sources {
1779            let output = JsaBatchBuilder::new()
1780                .kernel(kernel)
1781                .period_range(2, 200, 3)
1782                .apply_candles(&c, source)?;
1783
1784            for (idx, &val) in output.values.iter().enumerate() {
1785                if val.is_nan() {
1786                    continue;
1787                }
1788
1789                let bits = val.to_bits();
1790                let row = idx / output.cols;
1791                let col = idx % output.cols;
1792
1793                if bits == 0x11111111_11111111 {
1794                    panic!(
1795                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1796                        test, val, bits, row, col, idx, source
1797                    );
1798                }
1799
1800                if bits == 0x22222222_22222222 {
1801                    panic!(
1802                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1803                        test, val, bits, row, col, idx, source
1804                    );
1805                }
1806
1807                if bits == 0x33333333_33333333 {
1808                    panic!(
1809                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1810                        test, val, bits, row, col, idx, source
1811                    );
1812                }
1813            }
1814        }
1815
1816        let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1817        for (start, end, step) in edge_case_ranges {
1818            let output = JsaBatchBuilder::new()
1819                .kernel(kernel)
1820                .period_range(start, end, step)
1821                .apply_candles(&c, "close")?;
1822
1823            for (idx, &val) in output.values.iter().enumerate() {
1824                if val.is_nan() {
1825                    continue;
1826                }
1827
1828                let bits = val.to_bits();
1829                let row = idx / output.cols;
1830                let col = idx % output.cols;
1831
1832                if bits == 0x11111111_11111111
1833                    || bits == 0x22222222_22222222
1834                    || bits == 0x33333333_33333333
1835                {
1836                    panic!(
1837						"[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1838						test, val, bits, row, col, start, end, step
1839					);
1840                }
1841            }
1842        }
1843
1844        Ok(())
1845    }
1846
    // Stand-in compiled when debug assertions are off: performs no checks and
    // always returns `Ok(())`, so the generated batch tests still have a
    // function to call.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1851
    // Generate the per-kernel `#[test]` wrappers for both batch checks.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1854
1855    #[cfg(feature = "proptest")]
1856    #[allow(clippy::float_cmp)]
1857    fn check_jsa_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1858        use proptest::prelude::*;
1859        skip_if_unsupported!(kernel, test_name);
1860
1861        let strat = (1usize..=100).prop_flat_map(|period| {
1862            (
1863                prop::collection::vec(
1864                    (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1865                    period..400,
1866                ),
1867                Just(period),
1868            )
1869        });
1870
1871        proptest::test_runner::TestRunner::default().run(&strat, |(data, period)| {
1872            let params = JsaParams {
1873                period: Some(period),
1874            };
1875            let input = JsaInput::from_slice(&data, params);
1876
1877            let JsaOutput { values: out } = jsa_with_kernel(&input, kernel).unwrap();
1878
1879            let JsaOutput { values: ref_out } = jsa_with_kernel(&input, Kernel::Scalar).unwrap();
1880
1881            let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1882            let warmup_period = first_valid + period;
1883
1884            for i in 0..warmup_period.min(data.len()) {
1885                prop_assert!(
1886                    out[i].is_nan(),
1887                    "[{}] Expected NaN during warmup at index {}, got {}",
1888                    test_name,
1889                    i,
1890                    out[i]
1891                );
1892            }
1893
1894            for i in warmup_period..data.len() {
1895                let expected = (data[i] + data[i - period]) * 0.5;
1896                let actual = out[i];
1897
1898                prop_assert!(
1899                    (actual - expected).abs() < 1e-9,
1900                    "[{}] Formula verification failed at index {}: expected {}, got {}, diff = {}",
1901                    test_name,
1902                    i,
1903                    expected,
1904                    actual,
1905                    (actual - expected).abs()
1906                );
1907            }
1908
1909            for i in warmup_period..data.len() {
1910                let val1 = data[i];
1911                let val2 = data[i - period];
1912                let min_val = val1.min(val2);
1913                let max_val = val1.max(val2);
1914                let actual = out[i];
1915
1916                prop_assert!(
1917                    actual >= min_val - 1e-9 && actual <= max_val + 1e-9,
1918                    "[{}] Output bounds check failed at index {}: {} not in [{}, {}]",
1919                    test_name,
1920                    i,
1921                    actual,
1922                    min_val,
1923                    max_val
1924                );
1925            }
1926
1927            if kernel != Kernel::Scalar {
1928                for i in 0..data.len() {
1929                    let y = out[i];
1930                    let r = ref_out[i];
1931
1932                    if y.is_nan() && r.is_nan() {
1933                        continue;
1934                    }
1935
1936                    let y_bits = y.to_bits();
1937                    let r_bits = r.to_bits();
1938                    prop_assert!(
1939                        y_bits == r_bits,
1940                        "[{}] Cross-kernel mismatch at index {}: {} ({:016X}) != {} ({:016X})",
1941                        test_name,
1942                        i,
1943                        y,
1944                        y_bits,
1945                        r,
1946                        r_bits
1947                    );
1948                }
1949            }
1950
1951            if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-9) && !data.is_empty() {
1952                let constant = data[first_valid];
1953                for i in warmup_period..data.len() {
1954                    prop_assert!(
1955                        (out[i] - constant).abs() < 1e-9,
1956                        "[{}] Constant data test failed at index {}: expected {}, got {}",
1957                        test_name,
1958                        i,
1959                        constant,
1960                        out[i]
1961                    );
1962                }
1963            }
1964
1965            let is_monotonic_inc = data.windows(2).all(|w| w[1] >= w[0] - 1e-12);
1966            if is_monotonic_inc && warmup_period + 1 < data.len() {
1967                for i in (warmup_period + 1)..data.len() {
1968                    prop_assert!(
1969                        out[i] >= out[i - 1] - 1e-9,
1970                        "[{}] Monotonic test failed at index {}: {} < {}",
1971                        test_name,
1972                        i,
1973                        out[i],
1974                        out[i - 1]
1975                    );
1976                }
1977            }
1978
1979            if period == 1 && warmup_period < data.len() {
1980                for i in warmup_period..data.len() {
1981                    let expected = (data[i] + data[i - 1]) * 0.5;
1982                    let actual = out[i];
1983                    prop_assert!(
1984                        (actual - expected).abs() < 1e-9,
1985                        "[{}] Period=1 test failed at index {}: expected {}, got {}",
1986                        test_name,
1987                        i,
1988                        expected,
1989                        actual
1990                    );
1991                }
1992            }
1993
1994            #[cfg(debug_assertions)]
1995            {
1996                for (i, &val) in out.iter().enumerate() {
1997                    if val.is_nan() {
1998                        continue;
1999                    }
2000
2001                    let bits = val.to_bits();
2002
2003                    prop_assert!(
2004                        bits != 0x11111111_11111111,
2005                        "[{}] Found alloc_with_nan_prefix poison at index {}",
2006                        test_name,
2007                        i
2008                    );
2009                    prop_assert!(
2010                        bits != 0x22222222_22222222,
2011                        "[{}] Found init_matrix_prefixes poison at index {}",
2012                        test_name,
2013                        i
2014                    );
2015                    prop_assert!(
2016                        bits != 0x33333333_33333333,
2017                        "[{}] Found make_uninit_matrix poison at index {}",
2018                        test_name,
2019                        i
2020                    );
2021                }
2022            }
2023
2024            Ok(())
2025        })?;
2026
2027        Ok(())
2028    }
2029
2030    #[cfg(feature = "proptest")]
2031    generate_all_jsa_tests!(check_jsa_property);
2032
2033    #[test]
2034    fn test_jsa_into_matches_api() -> Result<(), Box<dyn Error>> {
2035        let mut data = Vec::with_capacity(256);
2036        for _ in 0..8 {
2037            data.push(f64::NAN);
2038        }
2039        for i in 0..248u64 {
2040            let x = (i as f64).sin() * 3.14159 + (i as f64) * 0.01;
2041            data.push(x);
2042        }
2043
2044        let input = JsaInput::from_slice(&data, JsaParams::default());
2045
2046        let base = jsa(&input)?.values;
2047
2048        let mut out = vec![0.0; data.len()];
2049        jsa_into(&input, &mut out)?;
2050
2051        assert_eq!(base.len(), out.len(), "lengths must match");
2052
2053        for (i, (&a, &b)) in base.iter().zip(out.iter()).enumerate() {
2054            let both_nan = a.is_nan() && b.is_nan();
2055            assert!(both_nan || a == b, "mismatch at idx {}: {} vs {}", i, a, b);
2056        }
2057        Ok(())
2058    }
2059}