Skip to main content

vector_ta/indicators/
cora_wave.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23
24#[cfg(not(target_arch = "wasm32"))]
25use rayon::prelude::*;
26
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31
32#[cfg(all(feature = "python", feature = "cuda"))]
33use crate::cuda::cuda_available;
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::moving_averages::CudaCoraWave;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
40
41impl<'a> AsRef<[f64]> for CoraWaveInput<'a> {
42    #[inline(always)]
43    fn as_ref(&self) -> &[f64] {
44        match &self.data {
45            CoraWaveData::Slice(slice) => slice,
46            CoraWaveData::Candles { candles, source } => source_type(candles, source),
47        }
48    }
49}
50
/// Where a CoRa Wave computation reads its input series from.
#[derive(Debug, Clone)]
pub enum CoraWaveData<'a> {
    /// A named column (for example `"close"`) resolved from a candle set via `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series, used as-is.
    Slice(&'a [f64]),
}
59
/// Result of a single CoRa Wave run.
#[derive(Debug, Clone)]
pub struct CoraWaveOutput {
    /// One value per input element; positions before the first fully-formed output are NaN.
    pub values: Vec<f64>,
}
64
/// Tunable CoRa Wave parameters; any `None` field falls back to its default.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CoraWaveParams {
    /// Length of the geometrically weighted window (default 20; must be at least 1).
    pub period: Option<usize>,
    /// Multiplier on the per-bar growth rate of the weights (default 2.0; must be finite and
    /// non-negative). A value of 0.0 makes all weights equal.
    pub r_multi: Option<f64>,
    /// Whether to apply a final WMA of length `round(sqrt(period))` (default `true`).
    pub smooth: Option<bool>,
}
75
76impl Default for CoraWaveParams {
77    fn default() -> Self {
78        Self {
79            period: Some(20),
80            r_multi: Some(2.0),
81            smooth: Some(true),
82        }
83    }
84}
85
/// Input bundle for a CoRa Wave run: the data source plus its parameters.
#[derive(Debug, Clone)]
pub struct CoraWaveInput<'a> {
    /// Where the series is read from.
    pub data: CoraWaveData<'a>,
    /// Parameters; unset fields use the defaults returned by the `get_*` accessors.
    pub params: CoraWaveParams,
}
91
92impl<'a> CoraWaveInput<'a> {
93    #[inline]
94    pub fn from_candles(c: &'a Candles, s: &'a str, p: CoraWaveParams) -> Self {
95        Self {
96            data: CoraWaveData::Candles {
97                candles: c,
98                source: s,
99            },
100            params: p,
101        }
102    }
103
104    #[inline]
105    pub fn from_slice(sl: &'a [f64], p: CoraWaveParams) -> Self {
106        Self {
107            data: CoraWaveData::Slice(sl),
108            params: p,
109        }
110    }
111
112    #[inline]
113    pub fn with_default_candles(c: &'a Candles) -> Self {
114        Self::from_candles(c, "close", CoraWaveParams::default())
115    }
116
117    #[inline]
118    pub fn get_period(&self) -> usize {
119        self.params.period.unwrap_or(20)
120    }
121
122    #[inline]
123    pub fn get_r_multi(&self) -> f64 {
124        self.params.r_multi.unwrap_or(2.0)
125    }
126
127    #[inline]
128    pub fn get_smooth(&self) -> bool {
129        self.params.smooth.unwrap_or(true)
130    }
131}
132
/// Fluent builder for one-shot and streaming CoRa Wave calculations.
///
/// Unset parameters are passed through as `None` and resolved to their defaults later.
#[derive(Copy, Clone, Debug)]
pub struct CoraWaveBuilder {
    period: Option<usize>,
    r_multi: Option<f64>,
    smooth: Option<bool>,
    // Kernel requested for computation; `Kernel::Auto` by default.
    kernel: Kernel,
}
140
141impl Default for CoraWaveBuilder {
142    fn default() -> Self {
143        Self {
144            period: None,
145            r_multi: None,
146            smooth: None,
147            kernel: Kernel::Auto,
148        }
149    }
150}
151
152impl CoraWaveBuilder {
153    #[inline(always)]
154    pub fn new() -> Self {
155        Self::default()
156    }
157
158    #[inline(always)]
159    pub fn period(mut self, val: usize) -> Self {
160        self.period = Some(val);
161        self
162    }
163
164    #[inline(always)]
165    pub fn r_multi(mut self, val: f64) -> Self {
166        self.r_multi = Some(val);
167        self
168    }
169
170    #[inline(always)]
171    pub fn smooth(mut self, val: bool) -> Self {
172        self.smooth = Some(val);
173        self
174    }
175
176    #[inline(always)]
177    pub fn kernel(mut self, k: Kernel) -> Self {
178        self.kernel = k;
179        self
180    }
181
182    #[inline(always)]
183    pub fn apply(self, c: &Candles) -> Result<CoraWaveOutput, CoraWaveError> {
184        let p = CoraWaveParams {
185            period: self.period,
186            r_multi: self.r_multi,
187            smooth: self.smooth,
188        };
189        let i = CoraWaveInput::from_candles(c, "close", p);
190        cora_wave_with_kernel(&i, self.kernel)
191    }
192
193    #[inline(always)]
194    pub fn apply_slice(self, d: &[f64]) -> Result<CoraWaveOutput, CoraWaveError> {
195        let p = CoraWaveParams {
196            period: self.period,
197            r_multi: self.r_multi,
198            smooth: self.smooth,
199        };
200        let i = CoraWaveInput::from_slice(d, p);
201        cora_wave_with_kernel(&i, self.kernel)
202    }
203
204    #[inline(always)]
205    pub fn into_stream(self) -> Result<CoraWaveStream, CoraWaveError> {
206        let p = CoraWaveParams {
207            period: self.period,
208            r_multi: self.r_multi,
209            smooth: self.smooth,
210        };
211        CoraWaveStream::try_new(p)
212    }
213}
214
/// Errors reported by the CoRa Wave indicator.
#[derive(Debug, Error)]
pub enum CoraWaveError {
    /// The input series has no elements.
    #[error("cora_wave: Input data slice is empty.")]
    EmptyInputData,

    /// Every input value is NaN.
    #[error("cora_wave: All values are NaN.")]
    AllValuesNaN,

    /// `period` is zero or longer than the input. The streaming constructor reports
    /// `data_len: 0` because it has no input yet.
    #[error("cora_wave: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `needed` values remain after the leading NaNs.
    #[error("cora_wave: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// `r_multi` is negative, NaN or infinite.
    #[error("cora_wave: Invalid r_multi: {value}")]
    InvalidRMulti { value: f64 },

    /// A caller-provided output buffer does not match the input length.
    #[error("cora_wave: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A parameter sweep `(start, end, step)` could not be expanded.
    #[error("cora_wave: Invalid range: start = {start}, end = {end}, step = {step}")]
    InvalidRange { start: f64, end: f64, step: f64 },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("cora_wave: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// Any other invalid input, described by the message (e.g. sweep-size overflow).
    #[error("cora_wave: invalid input: {0}")]
    InvalidInput(String),
}
244
245#[inline]
246pub fn cora_wave(input: &CoraWaveInput) -> Result<CoraWaveOutput, CoraWaveError> {
247    cora_wave_with_kernel(input, Kernel::Auto)
248}
249
250pub fn cora_wave_with_kernel(
251    input: &CoraWaveInput,
252    kernel: Kernel,
253) -> Result<CoraWaveOutput, CoraWaveError> {
254    let (data, weights, inv_wsum, smooth_period, first, chosen) = cora_wave_prepare(input, kernel)?;
255    let period = weights.len();
256    let warm = first + period - 1 + smooth_period.saturating_sub(1);
257
258    let mut out = alloc_with_nan_prefix(data.len(), warm);
259    cora_wave_compute_into(
260        data,
261        &weights,
262        inv_wsum,
263        smooth_period,
264        first,
265        chosen,
266        &mut out,
267    );
268    Ok(CoraWaveOutput { values: out })
269}
270
271#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
272#[inline]
273pub fn cora_wave_into(input: &CoraWaveInput, out: &mut [f64]) -> Result<(), CoraWaveError> {
274    let (data, weights, inv_wsum, smooth_period, first, chosen) =
275        cora_wave_prepare(input, Kernel::Auto)?;
276
277    if out.len() != data.len() {
278        return Err(CoraWaveError::OutputLengthMismatch {
279            expected: data.len(),
280            got: out.len(),
281        });
282    }
283
284    let warm = first + weights.len() - 1 + smooth_period.saturating_sub(1);
285    let warm = warm.min(out.len());
286    if warm > 0 {
287        let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
288        for v in &mut out[..warm] {
289            *v = qnan;
290        }
291    }
292
293    cora_wave_compute_into(data, &weights, inv_wsum, smooth_period, first, chosen, out);
294    Ok(())
295}
296
297#[inline]
298pub fn cora_wave_into_slice(
299    dst: &mut [f64],
300    input: &CoraWaveInput,
301    kern: Kernel,
302) -> Result<(), CoraWaveError> {
303    let (data, weights, inv_wsum, smooth_period, first, chosen) = cora_wave_prepare(input, kern)?;
304    if dst.len() != data.len() {
305        return Err(CoraWaveError::OutputLengthMismatch {
306            expected: data.len(),
307            got: dst.len(),
308        });
309    }
310    cora_wave_compute_into(data, &weights, inv_wsum, smooth_period, first, chosen, dst);
311
312    let warm = first + weights.len() - 1 + smooth_period.saturating_sub(1);
313    for v in &mut dst[..warm] {
314        *v = f64::NAN;
315    }
316    Ok(())
317}
318
319#[inline(always)]
320fn cora_wave_prepare<'a>(
321    input: &'a CoraWaveInput,
322    kernel: Kernel,
323) -> Result<(&'a [f64], Vec<f64>, f64, usize, usize, Kernel), CoraWaveError> {
324    let data: &[f64] = input.as_ref();
325    let len = data.len();
326    if len == 0 {
327        return Err(CoraWaveError::EmptyInputData);
328    }
329
330    let first = data
331        .iter()
332        .position(|x| !x.is_nan())
333        .ok_or(CoraWaveError::AllValuesNaN)?;
334
335    let period = input.get_period();
336    let r_multi = input.get_r_multi();
337    let smooth = input.get_smooth();
338
339    if period == 0 || period > len {
340        return Err(CoraWaveError::InvalidPeriod {
341            period,
342            data_len: len,
343        });
344    }
345    if len - first < period {
346        return Err(CoraWaveError::NotEnoughValidData {
347            needed: period,
348            valid: len - first,
349        });
350    }
351    if r_multi < 0.0 || !r_multi.is_finite() {
352        return Err(CoraWaveError::InvalidRMulti { value: r_multi });
353    }
354
355    let mut weights = Vec::with_capacity(period);
356    let inv_sum: f64;
357    if period == 1 {
358        weights.push(1.0);
359        inv_sum = 1.0;
360    } else {
361        let start_wt = 0.01;
362        let end_wt = period as f64;
363        let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
364        let base = 1.0 + r * r_multi;
365
366        let mut sum = 0.0;
367
368        let mut w = start_wt * base;
369        for _ in 0..period {
370            weights.push(w);
371            sum += w;
372            w *= base;
373        }
374        inv_sum = 1.0 / sum;
375    }
376
377    let smooth_period = if smooth {
378        ((period as f64).sqrt().round() as usize).max(1)
379    } else {
380        1
381    };
382    let chosen = match kernel {
383        Kernel::Auto => detect_best_kernel(),
384        k => k,
385    };
386
387    Ok((data, weights, inv_sum, smooth_period, first, chosen))
388}
389
390#[inline(always)]
391fn cora_wave_compute_into(
392    data: &[f64],
393    weights: &[f64],
394    inv_wsum: f64,
395    smooth_period: usize,
396    first: usize,
397    kernel: Kernel,
398    out: &mut [f64],
399) {
400    unsafe {
401        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
402        {
403            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
404                cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out);
405                return;
406            }
407        }
408        match kernel {
409            Kernel::Scalar | Kernel::ScalarBatch => {
410                cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out)
411            }
412            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
413            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
414                cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out)
415            }
416            _ => unreachable!(),
417        }
418    }
419}
420
/// Computes CoRa Wave output for `data` into `out` using precomputed window `weights`.
///
/// Each raw value is the weighted sum of the last `weights.len()` inputs multiplied by
/// `inv_wsum`; the oldest bar is paired with `weights[0]`. For windows of two or more bars,
/// the sum is advanced in O(1) per bar with a recurrence. That recurrence is only exact when
/// the weights form a geometric progression (constant `weights[j + 1] / weights[j]`), which
/// is what this module builds.
///
/// When `smooth_period > 1`, the raw values also pass through a linearly weighted moving
/// average (newest value weight `smooth_period`, oldest weight 1). A `smooth_period` of 0 or 1
/// disables smoothing.
///
/// The calculation starts at index `first_val`. Only indices from the first fully-formed output
/// (`first_val + weights.len() - 1 + smooth_period - 1`) onward are written. Earlier entries of
/// `out` are left untouched, so callers can fill their own warm-up prefix. Nothing is written
/// when `data` is too short to produce an output.
///
/// # Panics
/// Panics if `data` and `weights` are non-empty and `out.len() < data.len()`.
#[inline]
pub fn cora_wave_scalar_with_weights(
    data: &[f64],
    weights: &[f64],
    inv_wsum: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let p = weights.len();
    if p == 0 || n == 0 {
        return;
    }
    assert!(
        out.len() >= n,
        "cora_wave: output buffer too short: need {}, got {}",
        n,
        out.len()
    );
    // Reslice so every index below is checked against the same length `n`; this lets the
    // compiler drop most per-iteration bounds checks without resorting to unchecked access.
    let out = &mut out[..n];

    let mut smoother = (smooth_period > 1).then(|| CoraWaveWmaRing::new(smooth_period));
    // Stores raw value `y` for bar `i`, routed through the smoother when one is active.
    let mut emit = |i: usize, y: f64| {
        let value = match smoother.as_mut() {
            Some(wma) => wma.push(y),
            None => Some(y),
        };
        if let Some(v) = value {
            out[i] = v;
        }
    };

    if p == 1 {
        // Single-bar window: the raw value is the input scaled by `inv_wsum`, with or without
        // smoothing.
        for i in first_val..n {
            emit(i, data[i] * inv_wsum);
        }
        return;
    }

    // Index of the first complete weighted window.
    let warm0 = first_val + p - 1;
    if warm0 >= n {
        return;
    }

    // For geometric weights w_j = w_0 * R^j, sliding the window by one bar is
    //   S' = S / R - (w_0 / R) * x_leaving + w_{p-1} * x_entering.
    let w0 = weights[0];
    let inv_r = w0 / weights[1];
    let a_old = w0 * inv_r;
    let w_last = weights[p - 1];

    let mut s = cora_wave_window_dot(&data[first_val..=warm0], weights);
    emit(warm0, s * inv_wsum);
    for i in warm0 + 1..n {
        // i - p >= first_val because i > warm0 = first_val + p - 1.
        s = (s * inv_r) - a_old * data[i - p] + w_last * data[i];
        emit(i, s * inv_wsum);
    }
}

/// Dot product of equal-length `window` and `weights` using four independent FMA lanes.
/// The lanes are combined as `(l0 + l1) + (l2 + l3)` before the tail elements are folded in.
#[inline(always)]
fn cora_wave_window_dot(window: &[f64], weights: &[f64]) -> f64 {
    debug_assert_eq!(window.len(), weights.len());
    let xs = window.chunks_exact(4);
    let ws = weights.chunks_exact(4);
    let (x_tail, w_tail) = (xs.remainder(), ws.remainder());

    let mut lanes = [0.0f64; 4];
    for (x, w) in xs.zip(ws) {
        for ((lane, &xv), &wv) in lanes.iter_mut().zip(x).zip(w) {
            *lane = xv.mul_add(wv, *lane);
        }
    }
    let mut sum = (lanes[0] + lanes[1]) + (lanes[2] + lanes[3]);
    for (&x, &w) in x_tail.iter().zip(w_tail) {
        sum = x.mul_add(w, sum);
    }
    sum
}

/// Fixed-length buffer that yields the linearly weighted moving average of the values pushed
/// into it (newest value weighted by the buffer length, oldest by 1).
struct CoraWaveWmaRing {
    ring: Vec<f64>,
    // Slot the next push overwrites; once the buffer is full, this is the oldest value.
    head: usize,
    // Number of values pushed so far, saturating at `ring.len()`.
    filled: usize,
    // Sum of the linear weights, len * (len + 1) / 2.
    weight_sum: f64,
}

impl CoraWaveWmaRing {
    /// Creates an empty averager of length `len` (must be at least 1).
    fn new(len: usize) -> Self {
        let m = len as f64;
        Self {
            ring: vec![0.0; len],
            head: 0,
            filled: 0,
            weight_sum: m * (m + 1.0) * 0.5,
        }
    }

    /// Pushes `y` and returns the weighted average once `len` values have been seen.
    fn push(&mut self, y: f64) -> Option<f64> {
        let m = self.ring.len();
        self.ring[self.head] = y;
        self.head = (self.head + 1) % m;
        if self.filled < m {
            self.filled += 1;
            if self.filled < m {
                return None;
            }
        }
        let mut acc = 0.0;
        for k in 0..m {
            acc += self.ring[(self.head + k) % m] * ((k + 1) as f64);
        }
        Some(acc / self.weight_sum)
    }
}
647
648#[inline]
649pub fn cora_wave_scalar(
650    data: &[f64],
651    period: usize,
652    r_multi: f64,
653    smooth_period: usize,
654    first_val: usize,
655    out: &mut [f64],
656) {
657    if period == 1 {
658        cora_wave_scalar_with_weights(data, &[1.0], 1.0, smooth_period, first_val, out);
659        return;
660    }
661    let start_wt = 0.01;
662    let end_wt = period as f64;
663    let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
664    let base = 1.0 + r * r_multi;
665
666    let mut weights = Vec::with_capacity(period);
667    let mut weight_sum = 0.0;
668    for j in 0..period {
669        let w = start_wt * base.powi((j + 1) as i32);
670        weights.push(w);
671        weight_sum += w;
672    }
673
674    cora_wave_scalar_with_weights(
675        data,
676        &weights,
677        1.0 / weight_sum,
678        smooth_period,
679        first_val,
680        out,
681    );
682}
683
/// WebAssembly SIMD128 entry point for CoRa Wave; currently forwards to [`cora_wave_scalar`].
///
/// # Safety
/// Performs no unsafe operations itself.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn cora_wave_simd128(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    // No wasm32 intrinsics are used yet, so the former `core::arch::wasm32::*` glob import
    // (an unused-import warning) is omitted.
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out);
}
698
/// AVX2+FMA entry point for CoRa Wave; currently forwards to [`cora_wave_scalar`].
///
/// # Safety
/// Compiled with `avx2,fma` enabled, so it must only be called on CPUs that support those
/// features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn cora_wave_avx2(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out);
}
711
/// AVX-512F+FMA entry point for CoRa Wave; currently forwards to [`cora_wave_scalar`].
///
/// # Safety
/// Compiled with `avx512f,fma` enabled, so it must only be called on CPUs that support those
/// features.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn cora_wave_avx512(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out);
}
724
/// Incremental CoRa Wave calculator that consumes one value at a time.
#[derive(Debug, Clone)]
pub struct CoraWaveStream {
    // Window length as configured.
    period: usize,
    // Weight-growth multiplier as configured.
    r_multi: f64,
    // Whether smoothing was requested.
    smooth: bool,
    // WMA length (1 when smoothing is off).
    smooth_period: usize,

    // Common ratio between consecutive weights.
    base: f64,
    // 1 / base; rescales the running sum when the window slides.
    inv_R: f64,
    // Coefficient removed with the bar leaving the window after rescaling (the start weight).
    a_old: f64,
    // Weight applied to the newest bar.
    w_last: f64,
    // 1 / (sum of all window weights).
    inv_wsum: f64,

    // Circular buffer of the last `period` inputs.
    ring_x: Vec<f64>,
    // Next slot of `ring_x` to overwrite (the oldest input).
    head_x: usize,
    // Number of values passed to `update` so far.
    idx: usize,
    // Whether `S` has been seeded from a full window.
    have_S: bool,
    // Running weighted window sum.
    S: f64,

    // Same value as `smooth_period`; used by the smoothing code.
    m: usize,
    // Total WMA weight, m * (m + 1) / 2.
    wma_sum: f64,
    // Buffered raw values awaiting smoothing.
    ring_y: Vec<f64>,
    // Next slot of `ring_y` to write.
    head_y: usize,
    // Raw values buffered so far (stops increasing once it reaches `m`).
    y_count: usize,

    // Plain sum of buffered raw values (only maintained when `fast_smooth` is set).
    Ssum_y: f64,
    // Linearly weighted sum of buffered raw values (only maintained when `fast_smooth` is set).
    Wsum_y: f64,

    // Selects the O(1) incremental WMA update; `try_new` sets this to false.
    fast_smooth: bool,
}
755
756impl CoraWaveStream {
757    #[inline]
758    pub fn try_new(params: CoraWaveParams) -> Result<Self, CoraWaveError> {
759        let period = params.period.unwrap_or(20);
760        let r_multi = params.r_multi.unwrap_or(2.0);
761        let smooth = params.smooth.unwrap_or(true);
762
763        if period == 0 {
764            return Err(CoraWaveError::InvalidPeriod {
765                period,
766                data_len: 0,
767            });
768        }
769        if r_multi < 0.0 || !r_multi.is_finite() {
770            return Err(CoraWaveError::InvalidRMulti { value: r_multi });
771        }
772
773        let m = if smooth {
774            ((period as f64).sqrt().round() as usize).max(1)
775        } else {
776            1
777        };
778
779        let p = period;
780        let start_wt = 0.01_f64;
781
782        let end_wt = p as f64;
783        let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
784        let base = 1.0 + r * r_multi;
785        let inv_R = 1.0 / base;
786
787        let a_old = start_wt;
788
789        let base_pow_p = if (base - 1.0).abs() < 1e-16 {
790            1.0
791        } else {
792            base.powi(p as i32)
793        };
794        let w_last = a_old * base_pow_p;
795
796        let weight_sum = if (base - 1.0).abs() < 1e-16 {
797            a_old * (p as f64)
798        } else {
799            a_old * base * (base_pow_p - 1.0) / (base - 1.0)
800        };
801        let inv_wsum = 1.0 / weight_sum;
802
803        let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;
804
805        const FAST_WMA_O1_DEFAULT: bool = false;
806
807        Ok(Self {
808            period: p,
809            r_multi,
810            smooth,
811            smooth_period: m,
812            base,
813            inv_R,
814            a_old,
815            w_last,
816            inv_wsum,
817            ring_x: vec![0.0; p],
818            head_x: 0,
819            idx: 0,
820            have_S: false,
821            S: 0.0,
822            m,
823            wma_sum,
824            ring_y: vec![0.0; m.max(1)],
825            head_y: 0,
826            y_count: 0,
827            Ssum_y: 0.0,
828            Wsum_y: 0.0,
829            fast_smooth: FAST_WMA_O1_DEFAULT,
830        })
831    }
832
833    #[inline]
834    pub fn update(&mut self, x_new: f64) -> Option<f64> {
835        let pos = self.head_x;
836        let x_old = self.ring_x[pos];
837        self.ring_x[pos] = x_new;
838        self.head_x = (pos + 1) % self.period;
839        self.idx += 1;
840
841        if !self.have_S {
842            if self.idx < self.period {
843                return None;
844            }
845
846            let mut S = 0.0;
847            let mut w = self.a_old * self.base;
848            let mut i = 0usize;
849            while i < self.period {
850                let xi = self.ring_x[(self.head_x + i) % self.period];
851                S = xi.mul_add(w, S);
852                w *= self.base;
853                i += 1;
854            }
855            self.S = S;
856            self.have_S = true;
857
858            let y = self.S * self.inv_wsum;
859            if self.m == 1 {
860                return Some(y);
861            }
862
863            self.ring_y[self.y_count] = y;
864            self.y_count += 1;
865
866            self.head_y = self.y_count % self.m;
867            if self.fast_smooth {
868                self.Ssum_y += y;
869                self.Wsum_y += (self.y_count as f64) * y;
870            }
871            return None;
872        }
873
874        self.S = (self.S * self.inv_R) - self.a_old * x_old + self.w_last * x_new;
875        let y = self.S * self.inv_wsum;
876
877        if self.m == 1 {
878            return Some(y);
879        }
880
881        if !self.fast_smooth {
882            if self.y_count < self.m {
883                self.ring_y[self.head_y] = y;
884                self.head_y = (self.head_y + 1) % self.m;
885                self.y_count += 1;
886                if self.y_count < self.m {
887                    return None;
888                }
889
890                let mut acc = 0.0;
891                let mut k = 0usize;
892                while k < self.m {
893                    let idx = (self.head_y + k) % self.m;
894                    let v = self.ring_y[idx];
895                    acc = v.mul_add((k + 1) as f64, acc);
896                    k += 1;
897                }
898                return Some(acc / self.wma_sum);
899            } else {
900                self.ring_y[self.head_y] = y;
901                self.head_y = (self.head_y + 1) % self.m;
902                let mut acc = 0.0;
903                let mut k = 0usize;
904                while k < self.m {
905                    let idx = (self.head_y + k) % self.m;
906                    let v = self.ring_y[idx];
907                    acc = v.mul_add((k + 1) as f64, acc);
908                    k += 1;
909                }
910                return Some(acc / self.wma_sum);
911            }
912        } else {
913            if self.y_count < self.m {
914                self.ring_y[self.y_count] = y;
915                self.y_count += 1;
916                self.Ssum_y += y;
917                self.Wsum_y += (self.y_count as f64) * y;
918                if self.y_count < self.m {
919                    return None;
920                }
921
922                self.head_y = 0;
923                return Some(self.Wsum_y / self.wma_sum);
924            }
925
926            let y_old = self.ring_y[self.head_y];
927
928            self.Wsum_y = self.Wsum_y - self.Ssum_y + (self.m as f64) * y;
929            self.ring_y[self.head_y] = y;
930            self.Ssum_y = self.Ssum_y + y - y_old;
931            self.head_y = (self.head_y + 1) % self.m;
932
933            Some(self.Wsum_y / self.wma_sum)
934        }
935    }
936}
937
/// Parameter sweep for the batch API.
#[derive(Clone, Debug)]
pub struct CoraWaveBatchRange {
    /// `(start, end, step)` for the window length; a zero step or `start == end` yields only `start`.
    pub period: (usize, usize, usize),
    /// `(start, end, step)` for `r_multi`; the sign of the step is ignored.
    pub r_multi: (f64, f64, f64),
    /// Smoothing flag applied to every combination.
    pub smooth: bool,
}
944
945impl Default for CoraWaveBatchRange {
946    fn default() -> Self {
947        Self {
948            period: (20, 20, 0),
949            r_multi: (2.0, 2.249, 0.001),
950            smooth: true,
951        }
952    }
953}
954
/// Results of a batch sweep: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct CoraWaveBatchOutput {
    /// Row-major matrix; row `r` holds the `cols` outputs for `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<CoraWaveParams>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Length of each row (the input length).
    pub cols: usize,
}
962
963impl CoraWaveBatchOutput {
964    pub fn row_for_params(&self, p: &CoraWaveParams) -> Option<usize> {
965        self.combos.iter().position(|c| {
966            c.period.unwrap_or(20) == p.period.unwrap_or(20)
967                && (c.r_multi.unwrap_or(2.0) - p.r_multi.unwrap_or(2.0)).abs() < 1e-12
968                && c.smooth.unwrap_or(true) == p.smooth.unwrap_or(true)
969        })
970    }
971
972    pub fn values_for(&self, p: &CoraWaveParams) -> Option<&[f64]> {
973        self.row_for_params(p).map(|row| {
974            let start = row * self.cols;
975            &self.values[start..start + self.cols]
976        })
977    }
978}
979
980#[inline(always)]
981fn axis_usize((s, e, t): (usize, usize, usize)) -> Result<Vec<usize>, CoraWaveError> {
982    if t == 0 || s == e {
983        return Ok(vec![s]);
984    }
985    let mut v = Vec::new();
986    if s < e {
987        v = (s..=e).step_by(t).collect();
988    } else if s > e {
989        let mut x = s;
990        loop {
991            v.push(x);
992            if x <= e {
993                break;
994            }
995            if x < t {
996                break;
997            }
998            let next = x - t;
999            if next < e {
1000                break;
1001            }
1002            x = next;
1003        }
1004
1005        if *v.last().unwrap_or(&s) != e && s >= e {}
1006    }
1007    if v.is_empty() {
1008        return Err(CoraWaveError::InvalidRange {
1009            start: s as f64,
1010            end: e as f64,
1011            step: t as f64,
1012        });
1013    }
1014    Ok(v)
1015}
1016#[inline(always)]
1017fn axis_f64((s, e, t): (f64, f64, f64)) -> Result<Vec<f64>, CoraWaveError> {
1018    if t.abs() < 1e-12 || (s - e).abs() < 1e-12 {
1019        return Ok(vec![s]);
1020    }
1021    let step = t.abs();
1022    let mut v = Vec::new();
1023    if s <= e {
1024        let mut x = s;
1025        while x <= e + 1e-12 {
1026            v.push(x);
1027            x += step;
1028        }
1029    } else {
1030        let mut x = s;
1031        while x >= e - 1e-12 {
1032            v.push(x);
1033            x -= step;
1034        }
1035    }
1036    if v.is_empty() {
1037        return Err(CoraWaveError::InvalidRange {
1038            start: s,
1039            end: e,
1040            step: t,
1041        });
1042    }
1043    Ok(v)
1044}
1045#[inline(always)]
1046fn expand_grid_cw(r: &CoraWaveBatchRange) -> Result<Vec<CoraWaveParams>, CoraWaveError> {
1047    let periods = axis_usize(r.period)?;
1048    let mults = axis_f64(r.r_multi)?;
1049    if periods.is_empty() || mults.is_empty() {
1050        return Err(CoraWaveError::InvalidRange {
1051            start: r.period.0 as f64,
1052            end: r.period.1 as f64,
1053            step: r.period.2 as f64,
1054        });
1055    }
1056    let cap = periods
1057        .len()
1058        .checked_mul(mults.len())
1059        .ok_or_else(|| CoraWaveError::InvalidInput("periods*mults overflow".into()))?;
1060    let mut out = Vec::with_capacity(cap);
1061    for &p in &periods {
1062        for &m in &mults {
1063            out.push(CoraWaveParams {
1064                period: Some(p),
1065                r_multi: Some(m),
1066                smooth: Some(r.smooth),
1067            });
1068        }
1069    }
1070    Ok(out)
1071}
1072
1073#[inline(always)]
1074pub fn cora_wave_batch_slice(
1075    data: &[f64],
1076    sweep: &CoraWaveBatchRange,
1077    kern: Kernel,
1078) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1079    cora_wave_batch_inner(data, sweep, kern, false)
1080}
1081
1082#[inline(always)]
1083pub fn cora_wave_batch_par_slice(
1084    data: &[f64],
1085    sweep: &CoraWaveBatchRange,
1086    kern: Kernel,
1087) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1088    cora_wave_batch_inner(data, sweep, kern, true)
1089}
1090
1091pub fn cora_wave_batch_with_kernel(
1092    data: &[f64],
1093    sweep: &CoraWaveBatchRange,
1094    k: Kernel,
1095) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1096    let kernel = match k {
1097        Kernel::Auto => detect_best_batch_kernel(),
1098        other if other.is_batch() => other,
1099        _ => return Err(CoraWaveError::InvalidKernelForBatch(k)),
1100    };
1101    let simd = match kernel {
1102        Kernel::Avx512Batch => Kernel::Avx512,
1103        Kernel::Avx2Batch => Kernel::Avx2,
1104        Kernel::ScalarBatch => Kernel::Scalar,
1105        _ => unreachable!(),
1106    };
1107    cora_wave_batch_par_slice(data, sweep, simd)
1108}
1109
1110#[inline(always)]
1111fn cora_wave_batch_inner(
1112    data: &[f64],
1113    sweep: &CoraWaveBatchRange,
1114    kern: Kernel,
1115    parallel: bool,
1116) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1117    let combos = expand_grid_cw(sweep)?;
1118    if combos.is_empty() {
1119        return Err(CoraWaveError::InvalidRange {
1120            start: sweep.period.0 as f64,
1121            end: sweep.period.1 as f64,
1122            step: sweep.period.2 as f64,
1123        });
1124    }
1125
1126    let cols = data.len();
1127    if cols == 0 {
1128        return Err(CoraWaveError::AllValuesNaN);
1129    }
1130
1131    let first = data
1132        .iter()
1133        .position(|x| !x.is_nan())
1134        .ok_or(CoraWaveError::AllValuesNaN)?;
1135    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1136    if cols - first < max_p {
1137        return Err(CoraWaveError::NotEnoughValidData {
1138            needed: max_p,
1139            valid: cols - first,
1140        });
1141    }
1142
1143    let rows = combos.len();
1144    let _total = rows
1145        .checked_mul(cols)
1146        .ok_or_else(|| CoraWaveError::InvalidInput("rows*cols overflow".into()))?;
1147    let mut buf_mu = make_uninit_matrix(rows, cols);
1148
1149    let warms: Vec<usize> = combos
1150        .iter()
1151        .map(|c| {
1152            let p = c.period.unwrap();
1153            let sp = if c.smooth.unwrap_or(true) {
1154                ((p as f64).sqrt().round() as usize).max(1)
1155            } else {
1156                1
1157            };
1158            first + p - 1 + sp.saturating_sub(1)
1159        })
1160        .collect();
1161    init_matrix_prefixes(&mut buf_mu, cols, &warms);
1162
1163    let flat_len = rows
1164        .checked_mul(max_p)
1165        .ok_or_else(|| CoraWaveError::InvalidInput("rows*max_period overflow".into()))?;
1166    let mut flat_w = vec![0.0f64; flat_len];
1167    let mut inv_sums = vec![0.0f64; rows];
1168
1169    for (row, prm) in combos.iter().enumerate() {
1170        let p = prm.period.unwrap();
1171        let r_multi = prm.r_multi.unwrap();
1172
1173        if p == 1 {
1174            flat_w[row * max_p] = 1.0;
1175            inv_sums[row] = 1.0;
1176        } else {
1177            let start_wt = 0.01;
1178            let end_wt = p as f64;
1179            let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
1180            let base = 1.0 + r * r_multi;
1181
1182            let mut sum = 0.0;
1183            for j in 0..p {
1184                let w = start_wt * base.powi((j + 1) as i32);
1185                flat_w[row * max_p + j] = w;
1186                sum += w;
1187            }
1188            inv_sums[row] = 1.0 / sum;
1189        }
1190    }
1191
1192    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1193    let out_uninit: &mut [MaybeUninit<f64>] =
1194        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
1195
1196    let actual = match kern {
1197        Kernel::Auto => detect_best_batch_kernel(),
1198        k => k,
1199    };
1200
1201    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1202        let p = combos[row].period.unwrap();
1203        let sp = if combos[row].smooth.unwrap_or(true) {
1204            ((p as f64).sqrt().round() as usize).max(1)
1205        } else {
1206            1
1207        };
1208        let w_ptr = flat_w[row * max_p..].as_ptr();
1209        let inv = inv_sums[row];
1210
1211        let dst = unsafe { core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols) };
1212
1213        match actual {
1214            Kernel::Scalar
1215            | Kernel::ScalarBatch
1216            | Kernel::Avx2
1217            | Kernel::Avx2Batch
1218            | Kernel::Avx512
1219            | Kernel::Avx512Batch => unsafe {
1220                cora_wave_row_scalar_with_weights(data, first, p, w_ptr, inv, sp, dst)
1221            },
1222            _ => unreachable!(),
1223        }
1224    };
1225
1226    #[cfg(not(target_arch = "wasm32"))]
1227    {
1228        use rayon::prelude::*;
1229        if parallel {
1230            out_uninit
1231                .par_chunks_mut(cols)
1232                .enumerate()
1233                .for_each(|(row, slice)| do_row(row, slice));
1234        } else {
1235            for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1236                do_row(row, slice);
1237            }
1238        }
1239    }
1240    #[cfg(target_arch = "wasm32")]
1241    {
1242        for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1243            do_row(row, slice);
1244        }
1245    }
1246
1247    let values = unsafe {
1248        Vec::from_raw_parts(
1249            guard.as_mut_ptr() as *mut f64,
1250            guard.len(),
1251            guard.capacity(),
1252        )
1253    };
1254
1255    Ok(CoraWaveBatchOutput {
1256        values,
1257        combos,
1258        rows,
1259        cols,
1260    })
1261}
1262
1263#[inline(always)]
1264pub fn cora_wave_batch_inner_into(
1265    data: &[f64],
1266    sweep: &CoraWaveBatchRange,
1267    kern: Kernel,
1268    parallel: bool,
1269    out: &mut [f64],
1270) -> Result<Vec<CoraWaveParams>, CoraWaveError> {
1271    let combos = expand_grid_cw(sweep)?;
1272    if combos.is_empty() {
1273        return Err(CoraWaveError::InvalidRange {
1274            start: sweep.period.0 as f64,
1275            end: sweep.period.1 as f64,
1276            step: sweep.period.2 as f64,
1277        });
1278    }
1279
1280    let cols = data.len();
1281    let rows = combos.len();
1282    if cols == 0 {
1283        return Err(CoraWaveError::AllValuesNaN);
1284    }
1285    let expected = rows
1286        .checked_mul(cols)
1287        .ok_or_else(|| CoraWaveError::InvalidInput("rows*cols overflow".into()))?;
1288    if out.len() != expected {
1289        return Err(CoraWaveError::OutputLengthMismatch {
1290            expected,
1291            got: out.len(),
1292        });
1293    }
1294
1295    let first = data
1296        .iter()
1297        .position(|x| !x.is_nan())
1298        .ok_or(CoraWaveError::AllValuesNaN)?;
1299    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1300    if cols - first < max_p {
1301        return Err(CoraWaveError::NotEnoughValidData {
1302            needed: max_p,
1303            valid: cols - first,
1304        });
1305    }
1306
1307    let warms: Vec<usize> = combos
1308        .iter()
1309        .map(|c| {
1310            let p = c.period.unwrap();
1311            let sp = if c.smooth.unwrap_or(true) {
1312                ((p as f64).sqrt().round() as usize).max(1)
1313            } else {
1314                1
1315            };
1316            first + p - 1 + sp.saturating_sub(1)
1317        })
1318        .collect();
1319
1320    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1321        core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1322    };
1323    init_matrix_prefixes(out_mu, cols, &warms);
1324
1325    let flat_len = rows
1326        .checked_mul(max_p)
1327        .ok_or_else(|| CoraWaveError::InvalidInput("rows*max_period overflow".into()))?;
1328    let mut flat_w = vec![0.0f64; flat_len];
1329    let mut inv_sums = vec![0.0f64; rows];
1330    for (row, prm) in combos.iter().enumerate() {
1331        let p = prm.period.unwrap();
1332        let r_multi = prm.r_multi.unwrap();
1333
1334        if p == 1 {
1335            flat_w[row * max_p] = 1.0;
1336            inv_sums[row] = 1.0;
1337        } else {
1338            let start_wt = 0.01;
1339            let end_wt = p as f64;
1340            let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
1341            let base = 1.0 + r * r_multi;
1342
1343            let mut sum = 0.0;
1344            for j in 0..p {
1345                let w = start_wt * base.powi((j + 1) as i32);
1346                flat_w[row * max_p + j] = w;
1347                sum += w;
1348            }
1349            inv_sums[row] = 1.0 / sum;
1350        }
1351    }
1352
1353    let actual = match kern {
1354        Kernel::Auto => detect_best_batch_kernel(),
1355        k => k,
1356    };
1357
1358    let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1359        let p = combos[row].period.unwrap();
1360        let sp = if combos[row].smooth.unwrap_or(true) {
1361            ((p as f64).sqrt().round() as usize).max(1)
1362        } else {
1363            1
1364        };
1365        let w_ptr = flat_w[row * max_p..].as_ptr();
1366        let inv = inv_sums[row];
1367
1368        let dst: &mut [f64] =
1369            unsafe { core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols) };
1370        match actual {
1371            Kernel::Scalar
1372            | Kernel::ScalarBatch
1373            | Kernel::Avx2
1374            | Kernel::Avx2Batch
1375            | Kernel::Avx512
1376            | Kernel::Avx512Batch => unsafe {
1377                cora_wave_row_scalar_with_weights(data, first, p, w_ptr, inv, sp, dst)
1378            },
1379            _ => unreachable!(),
1380        }
1381    };
1382
1383    #[cfg(not(target_arch = "wasm32"))]
1384    {
1385        if parallel {
1386            use rayon::prelude::*;
1387            out_mu
1388                .par_chunks_mut(cols)
1389                .enumerate()
1390                .for_each(|(row, slice)| do_row(row, slice));
1391        } else {
1392            for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1393                do_row(row, slice);
1394            }
1395        }
1396    }
1397    #[cfg(target_arch = "wasm32")]
1398    {
1399        for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1400            do_row(row, slice);
1401        }
1402    }
1403
1404    Ok(combos)
1405}
1406
1407#[inline(always)]
1408unsafe fn cora_wave_row_scalar_with_weights(
1409    data: &[f64],
1410    first: usize,
1411    period: usize,
1412    w_ptr: *const f64,
1413    inv_wsum: f64,
1414    smooth_period: usize,
1415    out: &mut [f64],
1416) {
1417    let n = data.len();
1418    let p = period;
1419    if p == 0 || n == 0 {
1420        return;
1421    }
1422
1423    if smooth_period == 1 {
1424        if p == 1 {
1425            let warm0 = first;
1426            let mut i = warm0;
1427            while i < n {
1428                *out.get_unchecked_mut(i) = *data.get_unchecked(i) * inv_wsum;
1429                i += 1;
1430            }
1431            return;
1432        }
1433
1434        let w0 = *w_ptr.add(0);
1435        let w1 = *w_ptr.add(1);
1436        let inv_R = w0 / w1;
1437        let a_old = w0 * inv_R;
1438        let w_last = *w_ptr.add(p - 1);
1439
1440        let warm0 = first + p - 1;
1441        if warm0 >= n {
1442            return;
1443        }
1444        let start0 = warm0 + 1 - p;
1445
1446        let mut acc0 = 0.0;
1447        let mut acc1 = 0.0;
1448        let mut acc2 = 0.0;
1449        let mut acc3 = 0.0;
1450        let mut j = 0usize;
1451        let end4 = p & !3usize;
1452        let xptr = data.as_ptr().add(start0);
1453
1454        while j < end4 {
1455            let x0 = *xptr.add(j);
1456            let x1 = *xptr.add(j + 1);
1457            let x2 = *xptr.add(j + 2);
1458            let x3 = *xptr.add(j + 3);
1459
1460            let y0 = *w_ptr.add(j);
1461            let y1 = *w_ptr.add(j + 1);
1462            let y2 = *w_ptr.add(j + 2);
1463            let y3 = *w_ptr.add(j + 3);
1464
1465            acc0 = x0.mul_add(y0, acc0);
1466            acc1 = x1.mul_add(y1, acc1);
1467            acc2 = x2.mul_add(y2, acc2);
1468            acc3 = x3.mul_add(y3, acc3);
1469
1470            j += 4;
1471        }
1472        let mut S = (acc0 + acc1) + (acc2 + acc3);
1473        while j < p {
1474            let x = *xptr.add(j);
1475            let y = *w_ptr.add(j);
1476            S = x.mul_add(y, S);
1477            j += 1;
1478        }
1479
1480        *out.get_unchecked_mut(warm0) = S * inv_wsum;
1481
1482        let mut i = warm0;
1483        while i + 1 < n {
1484            let x_old = *data.get_unchecked(i + 1 - p);
1485            let x_new = *data.get_unchecked(i + 1);
1486            S = (S * inv_R) - a_old * x_old + w_last * x_new;
1487            *out.get_unchecked_mut(i + 1) = S * inv_wsum;
1488            i += 1;
1489        }
1490        return;
1491    }
1492
1493    let m = smooth_period;
1494    let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;
1495
1496    if p == 1 {
1497        let warm0 = first;
1498        if warm0 >= n {
1499            return;
1500        }
1501
1502        let mut ring_mu: Vec<MaybeUninit<f64>> = make_uninit_matrix(1, m);
1503        let mut fill = 0usize;
1504
1505        let warm_total = warm0 + m - 1;
1506        let mut i = warm0;
1507        while i <= warm_total && i < n {
1508            ring_mu
1509                .get_unchecked_mut(fill)
1510                .write(*data.get_unchecked(i));
1511            fill += 1;
1512            i += 1;
1513        }
1514        if warm_total >= n {
1515            return;
1516        }
1517
1518        let mut Ssum = 0.0;
1519        let mut Wsum = 0.0;
1520        for k in 0..m {
1521            let v = *ring_mu.get_unchecked(k).assume_init_ref();
1522            Ssum += v;
1523            Wsum += v * ((k + 1) as f64);
1524        }
1525        let mut head = 0usize;
1526        let mut t = warm_total;
1527        *out.get_unchecked_mut(t) = Wsum / wma_sum;
1528
1529        while t + 1 < n {
1530            let y_old = *ring_mu.get_unchecked(head).assume_init_ref();
1531            let y_new = *data.get_unchecked(t + 1);
1532
1533            Wsum = Wsum - Ssum + (m as f64) * y_new;
1534
1535            ring_mu.get_unchecked_mut(head).write(y_new);
1536            Ssum = Ssum + y_new - y_old;
1537            head = (head + 1) % m;
1538
1539            *out.get_unchecked_mut(t + 1) = Wsum / wma_sum;
1540            t += 1;
1541        }
1542        return;
1543    }
1544
1545    let w0 = *w_ptr.add(0);
1546    let w1 = *w_ptr.add(1);
1547    let inv_R = w0 / w1;
1548    let a_old = w0 * inv_R;
1549    let w_last = *w_ptr.add(p - 1);
1550
1551    let warm0 = first + p - 1;
1552    if warm0 >= n {
1553        return;
1554    }
1555    let start0 = warm0 + 1 - p;
1556
1557    let mut acc0 = 0.0;
1558    let mut acc1 = 0.0;
1559    let mut acc2 = 0.0;
1560    let mut acc3 = 0.0;
1561    let mut j = 0usize;
1562    let end4 = p & !3usize;
1563    let xptr = data.as_ptr().add(start0);
1564    while j < end4 {
1565        let x0 = *xptr.add(j);
1566        let x1 = *xptr.add(j + 1);
1567        let x2 = *xptr.add(j + 2);
1568        let x3 = *xptr.add(j + 3);
1569
1570        let y0 = *w_ptr.add(j);
1571        let y1 = *w_ptr.add(j + 1);
1572        let y2 = *w_ptr.add(j + 2);
1573        let y3 = *w_ptr.add(j + 3);
1574
1575        acc0 = x0.mul_add(y0, acc0);
1576        acc1 = x1.mul_add(y1, acc1);
1577        acc2 = x2.mul_add(y2, acc2);
1578        acc3 = x3.mul_add(y3, acc3);
1579
1580        j += 4;
1581    }
1582    let mut S = (acc0 + acc1) + (acc2 + acc3);
1583    while j < p {
1584        let x = *xptr.add(j);
1585        let y = *w_ptr.add(j);
1586        S = x.mul_add(y, S);
1587        j += 1;
1588    }
1589
1590    let mut ring_mu: Vec<MaybeUninit<f64>> = make_uninit_matrix(1, m);
1591    let mut fill = 0usize;
1592
1593    let mut y = S * inv_wsum;
1594    ring_mu.get_unchecked_mut(fill).write(y);
1595    fill += 1;
1596
1597    let warm_total = warm0 + m - 1;
1598    let mut i = warm0;
1599    while i + 1 <= warm_total && i + 1 < n {
1600        let x_old = *data.get_unchecked(i + 1 - p);
1601        let x_new = *data.get_unchecked(i + 1);
1602        S = (S * inv_R) - a_old * x_old + w_last * x_new;
1603        y = S * inv_wsum;
1604        ring_mu.get_unchecked_mut(fill).write(y);
1605        fill += 1;
1606        i += 1;
1607    }
1608    if warm_total >= n {
1609        return;
1610    }
1611
1612    let mut Ssum = 0.0;
1613    let mut Wsum = 0.0;
1614    for k in 0..m {
1615        let v = *ring_mu.get_unchecked(k).assume_init_ref();
1616        Ssum += v;
1617        Wsum += v * ((k + 1) as f64);
1618    }
1619    let mut head = 0usize;
1620    *out.get_unchecked_mut(warm_total) = Wsum / wma_sum;
1621
1622    while i + 1 < n {
1623        let x_old = *data.get_unchecked(i + 1 - p);
1624        let x_new = *data.get_unchecked(i + 1);
1625        S = (S * inv_R) - a_old * x_old + w_last * x_new;
1626        let y_new = S * inv_wsum;
1627
1628        Wsum = Wsum - Ssum + (m as f64) * y_new;
1629
1630        let y_old = *ring_mu.get_unchecked(head).assume_init_ref();
1631        ring_mu.get_unchecked_mut(head).write(y_new);
1632        Ssum = Ssum + y_new - y_old;
1633        head = (head + 1) % m;
1634
1635        *out.get_unchecked_mut(i + 1) = Wsum / wma_sum;
1636        i += 1;
1637    }
1638}
1639
1640#[derive(Clone, Debug, Default)]
1641pub struct CoraWaveBatchBuilder {
1642    range: CoraWaveBatchRange,
1643    kernel: Kernel,
1644}
1645
1646impl CoraWaveBatchBuilder {
1647    pub fn new() -> Self {
1648        Self::default()
1649    }
1650
1651    pub fn kernel(mut self, k: Kernel) -> Self {
1652        self.kernel = k;
1653        self
1654    }
1655
1656    #[inline]
1657    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1658        self.range.period = (start, end, step);
1659        self
1660    }
1661
1662    #[inline]
1663    pub fn period_static(mut self, val: usize) -> Self {
1664        self.range.period = (val, val, 0);
1665        self
1666    }
1667
1668    #[inline]
1669    pub fn r_multi_range(mut self, start: f64, end: f64, step: f64) -> Self {
1670        self.range.r_multi = (start, end, step);
1671        self
1672    }
1673
1674    #[inline]
1675    pub fn r_multi_static(mut self, val: f64) -> Self {
1676        self.range.r_multi = (val, val, 0.0);
1677        self
1678    }
1679
1680    #[inline]
1681    pub fn smooth(mut self, val: bool) -> Self {
1682        self.range.smooth = val;
1683        self
1684    }
1685
1686    pub fn apply_slice(self, data: &[f64]) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1687        cora_wave_batch_with_kernel(data, &self.range, self.kernel)
1688    }
1689
1690    pub fn apply_candles(
1691        self,
1692        c: &Candles,
1693        src: &str,
1694    ) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1695        let slice = source_type(c, src);
1696        self.apply_slice(slice)
1697    }
1698
1699    pub fn with_default_candles(c: &Candles) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1700        CoraWaveBatchBuilder::new()
1701            .kernel(Kernel::Auto)
1702            .apply_candles(c, "close")
1703    }
1704
1705    pub fn with_default_slice(
1706        data: &[f64],
1707        k: Kernel,
1708    ) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1709        CoraWaveBatchBuilder::new().kernel(k).apply_slice(data)
1710    }
1711}
1712
1713#[cfg(feature = "python")]
1714#[pyfunction(name = "cora_wave")]
1715#[pyo3(signature = (data, period, r_multi, smooth, kernel=None))]
1716pub fn cora_wave_py<'py>(
1717    py: Python<'py>,
1718    data: PyReadonlyArray1<'py, f64>,
1719    period: usize,
1720    r_multi: f64,
1721    smooth: bool,
1722    kernel: Option<&str>,
1723) -> PyResult<Bound<'py, PyArray1<f64>>> {
1724    let slice_in = data.as_slice()?;
1725    let kern = validate_kernel(kernel, false)?;
1726    let params = CoraWaveParams {
1727        period: Some(period),
1728        r_multi: Some(r_multi),
1729        smooth: Some(smooth),
1730    };
1731    let input = CoraWaveInput::from_slice(slice_in, params);
1732
1733    let result_vec: Vec<f64> = py
1734        .allow_threads(|| cora_wave_with_kernel(&input, kern).map(|o| o.values))
1735        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1736
1737    Ok(result_vec.into_pyarray(py))
1738}
1739
1740#[cfg(feature = "python")]
1741#[pyclass(name = "CoraWaveStream")]
1742pub struct CoraWaveStreamPy {
1743    stream: CoraWaveStream,
1744}
1745
1746#[cfg(feature = "python")]
1747#[pymethods]
1748impl CoraWaveStreamPy {
1749    #[new]
1750    fn new(period: usize, r_multi: f64, smooth: bool) -> PyResult<Self> {
1751        let params = CoraWaveParams {
1752            period: Some(period),
1753            r_multi: Some(r_multi),
1754            smooth: Some(smooth),
1755        };
1756        let stream =
1757            CoraWaveStream::try_new(params).map_err(|e| PyValueError::new_err(e.to_string()))?;
1758        Ok(CoraWaveStreamPy { stream })
1759    }
1760
1761    fn update(&mut self, value: f64) -> Option<f64> {
1762        self.stream.update(value)
1763    }
1764}
1765
1766#[cfg(feature = "python")]
1767#[pyfunction(name = "cora_wave_batch")]
1768#[pyo3(signature = (data, period_range, r_multi_range, smooth=true, kernel=None))]
1769pub fn cora_wave_batch_py<'py>(
1770    py: Python<'py>,
1771    data: PyReadonlyArray1<'py, f64>,
1772    period_range: (usize, usize, usize),
1773    r_multi_range: (f64, f64, f64),
1774    smooth: bool,
1775    kernel: Option<&str>,
1776) -> PyResult<Bound<'py, PyDict>> {
1777    use numpy::{IntoPyArray, PyArray1};
1778    let slice_in = data.as_slice()?;
1779    let sweep = CoraWaveBatchRange {
1780        period: period_range,
1781        r_multi: r_multi_range,
1782        smooth,
1783    };
1784
1785    let combos = expand_grid_cw(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
1786    let rows = combos.len();
1787    let cols = slice_in.len();
1788
1789    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
1790    let out_slice = unsafe { out_arr.as_slice_mut()? };
1791
1792    let kern = validate_kernel(kernel, true)?;
1793    let combos = py
1794        .allow_threads(|| {
1795            let kernel = match kern {
1796                Kernel::Auto => detect_best_batch_kernel(),
1797                k => k,
1798            };
1799            let simd = match kernel {
1800                Kernel::Avx512Batch => Kernel::Avx512,
1801                Kernel::Avx2Batch => Kernel::Avx2,
1802                Kernel::ScalarBatch => Kernel::Scalar,
1803                _ => unreachable!(),
1804            };
1805            cora_wave_batch_inner_into(slice_in, &sweep, simd, true, out_slice)
1806        })
1807        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1808
1809    let dict = PyDict::new(py);
1810    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
1811    dict.set_item(
1812        "periods",
1813        combos
1814            .iter()
1815            .map(|p| p.period.unwrap() as u64)
1816            .collect::<Vec<_>>()
1817            .into_pyarray(py),
1818    )?;
1819    dict.set_item(
1820        "r_multis",
1821        combos
1822            .iter()
1823            .map(|p| p.r_multi.unwrap())
1824            .collect::<Vec<_>>()
1825            .into_pyarray(py),
1826    )?;
1827    dict.set_item("smooth", smooth)?;
1828    Ok(dict)
1829}
1830
1831#[cfg(all(feature = "python", feature = "cuda"))]
1832#[pyfunction(name = "cora_wave_cuda_batch_dev")]
1833#[pyo3(signature = (data_f32, period_range, r_multi_range=(2.0,2.0,0.0), smooth=true, device_id=0))]
1834pub fn cora_wave_cuda_batch_dev_py<'py>(
1835    py: Python<'py>,
1836    data_f32: PyReadonlyArray1<'py, f32>,
1837    period_range: (usize, usize, usize),
1838    r_multi_range: (f64, f64, f64),
1839    smooth: bool,
1840    device_id: usize,
1841) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
1842    if !cuda_available() {
1843        return Err(PyValueError::new_err("CUDA not available"));
1844    }
1845    let slice_in = data_f32.as_slice()?;
1846    let sweep = CoraWaveBatchRange {
1847        period: period_range,
1848        r_multi: r_multi_range,
1849        smooth,
1850    };
1851
1852    fn combos_for_py(sweep: &CoraWaveBatchRange) -> Vec<CoraWaveParams> {
1853        let (ps, pe, pt) = sweep.period;
1854        let periods: Vec<usize> = if pt == 0 || ps == pe {
1855            vec![ps]
1856        } else if ps <= pe {
1857            (ps..=pe).step_by(pt).collect()
1858        } else {
1859            let mut v = Vec::new();
1860            let mut x = ps;
1861            loop {
1862                v.push(x);
1863                if x <= pe {
1864                    break;
1865                }
1866                if x < pt {
1867                    break;
1868                }
1869                let next = x - pt;
1870                if next < pe {
1871                    break;
1872                }
1873                x = next;
1874            }
1875            v
1876        };
1877        let (ms, me, mt) = sweep.r_multi;
1878        let mut mults: Vec<f64> = vec![];
1879        if mt.abs() < 1e-12 || (ms - me).abs() < 1e-12 {
1880            mults.push(ms);
1881        } else if ms <= me {
1882            let mut x = ms;
1883            let step = mt.abs();
1884            while x <= me + 1e-12 {
1885                mults.push(x);
1886                x += step;
1887            }
1888        } else {
1889            let mut x = ms;
1890            let step = mt.abs();
1891            while x >= me - 1e-12 {
1892                mults.push(x);
1893                x -= step;
1894            }
1895        }
1896        let mut out = Vec::with_capacity(periods.len().saturating_mul(mults.len()));
1897        for &p in &periods {
1898            for &m in &mults {
1899                out.push(CoraWaveParams {
1900                    period: Some(p),
1901                    r_multi: Some(m),
1902                    smooth: Some(sweep.smooth),
1903                });
1904            }
1905        }
1906        out
1907    }
1908
1909    let (inner, ctx_arc, dev_id, combos) = py.allow_threads(|| {
1910        let cuda =
1911            CudaCoraWave::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1912        let ctx = cuda.context_arc();
1913        let dev = cuda.device_id();
1914        let out = cuda
1915            .cora_wave_batch_dev(slice_in, &sweep)
1916            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1917        Ok::<
1918            (
1919                DeviceArrayF32,
1920                std::sync::Arc<cust::context::Context>,
1921                u32,
1922                Vec<CoraWaveParams>,
1923            ),
1924            PyErr,
1925        >((out, ctx, dev, combos_for_py(&sweep)))
1926    })?;
1927
1928    let dict = PyDict::new(py);
1929    use numpy::PyArrayMethods;
1930    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
1931    let r_multis: Vec<f64> = combos.iter().map(|c| c.r_multi.unwrap()).collect();
1932    dict.set_item("periods", periods.into_pyarray(py))?;
1933    dict.set_item("r_multis", r_multis.into_pyarray(py))?;
1934    dict.set_item("smooth", smooth)?;
1935    Ok((
1936        DeviceArrayF32Py {
1937            inner,
1938            _ctx: Some(ctx_arc),
1939            device_id: Some(dev_id),
1940        },
1941        dict,
1942    ))
1943}
1944
1945#[cfg(all(feature = "python", feature = "cuda"))]
1946#[pyfunction(name = "cora_wave_cuda_many_series_one_param_dev")]
1947#[pyo3(signature = (data_tm_f32, cols, rows, period, r_multi=2.0, smooth=true, device_id=0))]
1948pub fn cora_wave_cuda_many_series_one_param_dev_py<'py>(
1949    py: Python<'py>,
1950    data_tm_f32: PyReadonlyArray1<'py, f32>,
1951    cols: usize,
1952    rows: usize,
1953    period: usize,
1954    r_multi: f64,
1955    smooth: bool,
1956    device_id: usize,
1957) -> PyResult<DeviceArrayF32Py> {
1958    if !cuda_available() {
1959        return Err(PyValueError::new_err("CUDA not available"));
1960    }
1961    let slice = data_tm_f32.as_slice()?;
1962    let params = CoraWaveParams {
1963        period: Some(period),
1964        r_multi: Some(r_multi),
1965        smooth: Some(smooth),
1966    };
1967    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
1968        let cuda =
1969            CudaCoraWave::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1970        let ctx = cuda.context_arc();
1971        let dev = cuda.device_id();
1972        let out = cuda
1973            .cora_wave_multi_series_one_param_time_major_dev(slice, cols, rows, &params)
1974            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1975        Ok::<(DeviceArrayF32, std::sync::Arc<cust::context::Context>, u32), PyErr>((out, ctx, dev))
1976    })?;
1977    Ok(DeviceArrayF32Py {
1978        inner,
1979        _ctx: Some(ctx_arc),
1980        device_id: Some(dev_id),
1981    })
1982}
1983
1984#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1985#[wasm_bindgen]
1986pub fn cora_wave_js(
1987    data: &[f64],
1988    period: usize,
1989    r_multi: f64,
1990    smooth: bool,
1991) -> Result<Vec<f64>, JsValue> {
1992    let params = CoraWaveParams {
1993        period: Some(period),
1994        r_multi: Some(r_multi),
1995        smooth: Some(smooth),
1996    };
1997    let input = CoraWaveInput::from_slice(data, params);
1998
1999    let mut output = vec![0.0; data.len()];
2000    cora_wave_into_slice(&mut output, &input, detect_best_kernel())
2001        .map_err(|e| JsValue::from_str(&e.to_string()))?;
2002
2003    Ok(output)
2004}
2005
2006#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2007#[wasm_bindgen]
2008pub fn cora_wave_alloc(len: usize) -> *mut f64 {
2009    let mut vec = Vec::<f64>::with_capacity(len);
2010    let ptr = vec.as_mut_ptr();
2011    std::mem::forget(vec);
2012    ptr
2013}
2014
2015#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2016#[wasm_bindgen]
2017pub fn cora_wave_free(ptr: *mut f64, len: usize) {
2018    unsafe {
2019        let _ = Vec::from_raw_parts(ptr, len, len);
2020    }
2021}
2022
/// Zero-copy JavaScript entry point: reads `len` values from `in_ptr` and
/// writes the CoRa Wave series to `out_ptr`.
///
/// `in_ptr == out_ptr` (in-place use) is supported: results go to a scratch
/// buffer first and are then copied over the input.
///
/// # Errors
/// Returns an error for null pointers, for `period == 0` or
/// `period > len`, and for any failure reported by the core routine.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    r_multi: f64,
    smooth: bool,
) -> Result<(), JsValue> {
    let has_null = in_ptr.is_null() || out_ptr.is_null();
    if has_null {
        return Err(JsValue::from_str("null pointer passed to cora_wave_into"));
    }
    let period_ok = period != 0 && period <= len;
    if !period_ok {
        return Err(JsValue::from_str("Invalid period"));
    }

    let to_js = |msg: String| JsValue::from_str(&msg);

    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = CoraWaveInput::from_slice(
            source,
            CoraWaveParams {
                period: Some(period),
                r_multi: Some(r_multi),
                smooth: Some(smooth),
            },
        );

        if in_ptr != out_ptr {
            // Distinct buffers: write straight into the caller's memory.
            let dest = std::slice::from_raw_parts_mut(out_ptr, len);
            cora_wave_into_slice(dest, &input, detect_best_kernel())
                .map_err(|e| to_js(e.to_string()))?;
        } else {
            // Same buffer: compute into scratch so inputs are not
            // overwritten while still being read.
            let mut scratch = vec![0.0; len];
            cora_wave_into_slice(&mut scratch, &input, detect_best_kernel())
                .map_err(|e| to_js(e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        }
    }

    Ok(())
}
2065
/// Parameter sweep accepted by the JavaScript `cora_wave_batch` binding,
/// deserialized from a JS object via `serde_wasm_bindgen`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CoraWaveBatchConfig {
    /// `(start, end, step)` for the period axis; becomes `CoraWaveBatchRange::period`.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the r_multi axis; becomes `CoraWaveBatchRange::r_multi`.
    pub r_multi_range: (f64, f64, f64),
    /// Smoothing flag applied to every combination in the sweep.
    pub smooth: bool,
}
2073
/// Result object returned to JavaScript by the `cora_wave_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CoraWaveBatchJsOutput {
    /// Flattened `rows * cols` matrix; row `r` holds the series for `combos[r]`
    /// (row-major, as indexed by `idx / cols` in the batch tests).
    pub values: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<CoraWaveParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
2082
/// JavaScript `cora_wave_batch`: runs a parameter sweep described by
/// `config` (a `CoraWaveBatchConfig`-shaped object) over `data`.
///
/// Returns a `CoraWaveBatchJsOutput`-shaped object. Config-parsing,
/// computation and serialization failures are reported as string
/// `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cora_wave_batch)]
pub fn cora_wave_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CoraWaveBatchConfig {
        period_range,
        r_multi_range,
        smooth,
    } = serde_wasm_bindgen::from_value::<CoraWaveBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CoraWaveBatchRange {
        period: period_range,
        r_multi: r_multi_range,
        smooth,
    };

    let batch = cora_wave_batch_with_kernel(data, &sweep, detect_best_batch_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = CoraWaveBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2104
/// Zero-copy JavaScript batch entry point.
///
/// Reads `len` values from `in_ptr`, expands the period / r_multi sweep, and
/// writes a row-major `rows * len` matrix to `out_ptr` (one row per
/// combination). Returns the number of rows written.
///
/// The caller must provide an output buffer of at least `rows * len` `f64`s.
/// If that buffer overlaps the input (e.g. the same pointer), the input is
/// copied first. Otherwise the kernel would read values it has already
/// overwritten, and a shared and a mutable slice would alias.
///
/// # Errors
/// Null pointers, an invalid sweep, a `rows * len` size that overflows
/// `usize`, or any error reported by the batch kernel.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    rmulti_start: f64,
    rmulti_end: f64,
    rmulti_step: f64,
    smooth: bool,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to cora_wave_batch_into",
        ));
    }

    let sweep = CoraWaveBatchRange {
        period: (period_start, period_end, period_step),
        r_multi: (rmulti_start, rmulti_end, rmulti_step),
        smooth,
    };
    let combos = expand_grid_cw(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // Guard the slice length: a wrapped product would describe a buffer of
    // the wrong size.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str("cora_wave_batch_into: rows * cols overflows usize")
    })?;

    // Half-open address ranges [start, end) of the input and output buffers.
    let out_start = out_ptr as *const f64;
    let overlaps = len > 0
        && total > 0
        && in_ptr < out_start.wrapping_add(total)
        && out_start < in_ptr.wrapping_add(len);

    unsafe {
        // SAFETY: the caller guarantees `in_ptr` is readable for `len` f64s.
        let raw_input = std::slice::from_raw_parts(in_ptr, len);
        let owned_input: Vec<f64>;
        let data: &[f64] = if overlaps {
            // Private copy so the output slice below never aliases the input.
            owned_input = raw_input.to_vec();
            &owned_input
        } else {
            raw_input
        };

        // SAFETY: the caller guarantees `out_ptr` is writable for
        // `rows * len` f64s, and `data` is now disjoint from it.
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        cora_wave_batch_inner_into(data, &sweep, detect_best_batch_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
2141
/// Reusable CoRa Wave state for JavaScript callers: caches the weight
/// vector so repeated `update_into` calls skip recomputing it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct CoraWaveContext {
    // One weight per lag, `period` entries, built in `CoraWaveContext::new`.
    weights: Vec<f64>,
    // Reciprocal of the weight sum (1.0 when period == 1).
    inv_norm: f64,
    // Window length of the weighted average.
    period: usize,
    // Smoothing window: round(sqrt(period)) (min 1) when smoothing is on, else 1.
    smooth_period: usize,
    // Kernel chosen by `detect_best_kernel()` at construction time.
    kernel: Kernel,
}
2155
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl CoraWaveContext {
    /// Builds a context for the given parameters.
    ///
    /// Weights form a geometric series `0.01 * base^(j + 1)` for
    /// `j in 0..period`, where `base = 1 + r * r_multi` and
    /// `r = (period / 0.01)^(1 / (period - 1)) - 1`. When `smooth` is set,
    /// the smoothing window is `round(sqrt(period))` (at least 1); otherwise
    /// it is 1.
    ///
    /// # Errors
    /// `period == 0`, or `r_multi` that is non-finite or negative.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize, r_multi: f64, smooth: bool) -> Result<CoraWaveContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if !r_multi.is_finite() || r_multi < 0.0 {
            return Err(JsValue::from_str(&format!("Invalid r_multi: {}", r_multi)));
        }
        let smooth_period = if smooth {
            ((period as f64).sqrt().round() as usize).max(1)
        } else {
            1
        };

        // A single-sample window is the identity; the growth-rate formula
        // below would divide by zero (period - 1 == 0).
        if period == 1 {
            return Ok(CoraWaveContext {
                weights: vec![1.0],
                inv_norm: 1.0,
                period,
                smooth_period,
                kernel: detect_best_kernel(),
            });
        }

        let start_wt = 0.01;
        let end_wt = period as f64;
        let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
        let base = 1.0 + r * r_multi;

        let mut weights = Vec::with_capacity(period);
        let mut norm = 0.0;
        for j in 0..period {
            let w = start_wt * base.powi((j + 1) as i32);
            weights.push(w);
            norm += w;
        }

        Ok(CoraWaveContext {
            weights,
            inv_norm: 1.0 / norm,
            period,
            smooth_period,
            kernel: detect_best_kernel(),
        })
    }

    /// Computes CoRa Wave for `len` values at `in_ptr` into `out_ptr` using
    /// the cached weights. The warm-up prefix is filled with `NaN`.
    ///
    /// The output buffer may be the input buffer (in-place update) or
    /// otherwise overlap it. The input is then copied first, so the weighted
    /// window never reads values that have already been overwritten.
    ///
    /// # Errors
    /// `len < period`, or a null pointer.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }
        if in_ptr.is_null() || out_ptr.is_null() {
            return Err(JsValue::from_str("null pointer passed to update_into"));
        }

        // len >= period >= 1 here, so both ranges are non-empty.
        let out_start = out_ptr as *const f64;
        let overlaps =
            in_ptr < out_start.wrapping_add(len) && out_start < in_ptr.wrapping_add(len);

        unsafe {
            // SAFETY: the caller guarantees both buffers are valid for `len` f64s.
            let raw_input = std::slice::from_raw_parts(in_ptr, len);
            let owned_input: Vec<f64>;
            let data: &[f64] = if overlaps {
                owned_input = raw_input.to_vec();
                &owned_input
            } else {
                raw_input
            };
            // SAFETY: `data` is either disjoint from the output buffer or a
            // private copy, so this mutable slice does not alias it.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);

            // An all-NaN input falls back to starting at index 0.
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

            cora_wave_scalar_with_weights(
                data,
                &self.weights,
                self.inv_norm,
                self.smooth_period,
                first,
                out,
            );

            // Warm-up: first valid index + (period - 1) for the window
            // + (smooth_period - 1) for the smoothing stage.
            let warm = first + self.period - 1 + self.smooth_period.saturating_sub(1);
            for slot in out.iter_mut().take(warm) {
                *slot = f64::NAN;
            }
        }
        Ok(())
    }

    /// Number of leading outputs that are `NaN` when the input has no
    /// leading `NaN`s: `(period - 1) + (smooth_period - 1)`.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1 + self.smooth_period.saturating_sub(1)
    }
}
2248
#[cfg(test)]
mod tests {
    // Tests for the CoRa Wave indicator: the single-series API, streaming,
    // batch sweeps, and (in debug builds) checks that no output cell holds
    // one of the sentinel bit patterns 0x11.., 0x22.., 0x33.. (presumably the
    // debug fill of the uninitialised-matrix helpers — i.e. unwritten cells).
    // Most checks run once per kernel via the generator macros further down.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;
    use std::error::Error;

    #[test]
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    fn test_cora_wave_into_matches_api() {
        let mut data = vec![f64::NAN; 5];
        for i in 0..256 {
            let x = (i as f64 * 0.03141592653589793).sin() * 10.0 + 50.0;
            data.push(x);
        }

        let input = CoraWaveInput::from_slice(&data, CoraWaveParams::default());

        let baseline = cora_wave(&input).expect("baseline cora_wave() failed");
        let mut out = vec![0.0; data.len()];
        cora_wave_into(&input, &mut out).expect("cora_wave_into() failed");

        assert_eq!(baseline.values.len(), out.len());
        for (i, (&a, &b)) in baseline.values.iter().zip(&out).enumerate() {
            let ok = if a.is_nan() && b.is_nan() {
                true
            } else if a == b {
                true
            } else {
                (a - b).abs() <= 1e-12
            };
            assert!(ok, "mismatch at index {}: {} vs {}", i, a, b);
        }
    }

    fn check_cora_wave_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = CoraWaveInput::from_candles(&candles, "close", CoraWaveParams::default());
        let result = cora_wave_with_kernel(&input, kernel)?;

        let expected_last_five = [
            59248.63632114,
            59251.74238978,
            59203.36944998,
            59171.14999178,
            59053.74201623,
        ];

        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 0.01,
                "[{}] CoRa Wave {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    fn check_cora_wave_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = CoraWaveParams {
            period: None,
            r_multi: None,
            smooth: None,
        };
        let input = CoraWaveInput::from_candles(&candles, "close", default_params);
        let output = cora_wave_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_cora_wave_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = CoraWaveInput::with_default_candles(&candles);
        match input.data {
            CoraWaveData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected CoraWaveData::Candles"),
        }
        let output = cora_wave_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());

        Ok(())
    }

    fn check_cora_wave_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = CoraWaveParams {
            period: Some(0),
            r_multi: None,
            smooth: None,
        };
        let input = CoraWaveInput::from_slice(&input_data, params);
        let res = cora_wave_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CoRa Wave should fail with zero period",
            test_name
        );
        Ok(())
    }

    fn check_cora_wave_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = CoraWaveParams {
            period: Some(10),
            r_multi: None,
            smooth: None,
        };
        let input = CoraWaveInput::from_slice(&data_small, params);
        let res = cora_wave_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CoRa Wave should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    fn check_cora_wave_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = CoraWaveParams::default();
        let input = CoraWaveInput::from_slice(&single_point, params);
        let res = cora_wave_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CoRa Wave should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    fn check_cora_wave_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty: [f64; 0] = [];
        let params = CoraWaveParams::default();
        let input = CoraWaveInput::from_slice(&empty, params);
        let res = cora_wave_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CoRa Wave should fail with empty input",
            test_name
        );
        Ok(())
    }

    fn check_cora_wave_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
        let params = CoraWaveParams::default();
        let input = CoraWaveInput::from_slice(&nan_data, params);
        let res = cora_wave_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CoRa Wave should fail with all NaN values",
            test_name
        );
        Ok(())
    }

    fn check_cora_wave_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let params = CoraWaveParams::default();
        let input = CoraWaveInput::from_candles(&c, "close", params.clone());
        let batch = cora_wave_with_kernel(&input, kernel)?.values;

        let mut stream = CoraWaveStream::try_new(params)?;
        let mut streamed = Vec::with_capacity(c.close.len());
        for &px in &c.close {
            match stream.update(px) {
                Some(v) => streamed.push(v),
                None => streamed.push(f64::NAN),
            }
        }

        assert_eq!(batch.len(), streamed.len());
        for (i, (&b, &s)) in batch.iter().zip(&streamed).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let d = (b - s).abs();
            assert!(
                d < 1e-9,
                "[{}] streaming mismatch at {}: {} vs {}",
                test_name,
                i,
                b,
                s
            );
        }
        Ok(())
    }

    fn check_cora_wave_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let out = cora_wave_with_kernel(&CoraWaveInput::with_default_candles(&c), kernel)?.values;
        if out.len() > 240 {
            for (i, &v) in out[240..].iter().enumerate() {
                assert!(!v.is_nan(), "[{}] unexpected NaN at {}", test_name, 240 + i);
            }
        }
        Ok(())
    }

    fn check_cora_wave_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let first = cora_wave_with_kernel(
            &CoraWaveInput::from_candles(&c, "close", CoraWaveParams::default()),
            kernel,
        )?;

        let second_in = CoraWaveInput::from_slice(&first.values, CoraWaveParams::default());
        let second = cora_wave_with_kernel(&second_in, kernel)?;
        let second_ref = cora_wave_with_kernel(&second_in, Kernel::Scalar)?;

        assert_eq!(second.values.len(), first.values.len());
        for (i, (&a, &b)) in second.values.iter().zip(&second_ref.values).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() < 1e-9,
                "[{}] reinput mismatch at {}: {} vs {}",
                test_name,
                i,
                a,
                b
            );
        }
        Ok(())
    }

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let out = CoraWaveBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = CoraWaveParams::default();
        let row = out.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());

        let expected = [
            59248.63632114,
            59251.74238978,
            59203.36944998,
            59171.14999178,
            59053.74201623,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 0.01,
                "[{test}] default-row mismatch at idx {i}: {v}"
            );
        }
        Ok(())
    }

    fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let out = CoraWaveBatchBuilder::new()
            .kernel(kernel)
            .period_range(20, 60, 1)
            .r_multi_range(1.0, 3.0, 0.25)
            .apply_candles(&c, "close")?;

        // 41 periods (20..=60) x 9 r_multi values (1.0..=3.0 step 0.25).
        let expected = 41 * 9;
        assert_eq!(out.combos.len(), expected);
        assert_eq!(out.rows, expected);
        assert_eq!(out.cols, c.close.len());
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let out = CoraWaveBatchBuilder::new()
            .kernel(kernel)
            .period_range(20, 24, 2)
            .r_multi_range(1.5, 2.5, 0.5)
            .apply_candles(&c, "close")?;

        for (idx, &v) in out.values.iter().enumerate() {
            if v.is_nan() {
                continue;
            }
            let bits = v.to_bits();
            if bits == 0x11111111_11111111
                || bits == 0x22222222_22222222
                || bits == 0x33333333_33333333
            {
                let row = idx / out.cols;
                let col = idx % out.cols;
                let combo = &out.combos[row];
                panic!(
                    "[{}] poison value at row {} col {} (idx {}) params: period={}, r_multi={}, smooth={}",
                    test, row, col, idx,
                    combo.period.unwrap_or(20),
                    combo.r_multi.unwrap_or(2.0),
                    combo.smooth.unwrap_or(true),
                );
            }
        }
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_cora_wave_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let out = cora_wave_with_kernel(&CoraWaveInput::with_default_candles(&c), kernel)?.values;
        for (i, &v) in out.iter().enumerate() {
            if v.is_nan() {
                continue;
            }
            let b = v.to_bits();
            assert!(
                !(b == 0x11111111_11111111 || b == 0x22222222_22222222 || b == 0x33333333_33333333),
                "[{test}] poison at idx {i}"
            );
        }
        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_cora_wave_no_poison(_: &str, _: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates one #[test] per (check fn, kernel). Each generated test
    // returns the check's Result, so an `Err` fails the test instead of being
    // discarded. The nightly-avx `cfg` sits INSIDE each repetition: placed in
    // front of `$( … )*`, it would gate only the first generated function.
    macro_rules! generate_all_cora_wave_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_all_cora_wave_tests!(
        check_cora_wave_accuracy,
        check_cora_wave_partial_params,
        check_cora_wave_default_candles,
        check_cora_wave_zero_period,
        check_cora_wave_period_exceeds_length,
        check_cora_wave_very_small_dataset,
        check_cora_wave_empty_input,
        check_cora_wave_all_nan,
        check_cora_wave_nan_handling,
        check_cora_wave_streaming,
        check_cora_wave_reinput,
        check_cora_wave_no_poison
    );

    // Batch variants: one #[test] per batch kernel; errors propagate as failures.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> { $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch) }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> { $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch) }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> { $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch) }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> { $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto) }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);

    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    #[test]
    fn test_cora_wave_simd128_correctness() {
        // 64 points so default params (period appears to default to 20, per
        // the fallback in `check_batch_no_poison`) plus smoothing warm-up fit.
        // The previous 10-point series made both calls return Err.
        let data: Vec<f64> = (0..64).map(|i| 1.0 + i as f64).collect();
        let params = CoraWaveParams::default();
        let input = CoraWaveInput::from_slice(&data, params);
        let scalar = cora_wave_with_kernel(&input, Kernel::Scalar).unwrap();
        // Compare against the auto-detected kernel, not Scalar against itself.
        let simd = cora_wave_with_kernel(&input, detect_best_kernel()).unwrap();
        assert_eq!(scalar.values.len(), simd.values.len());
        for (i, (a, b)) in scalar.values.iter().zip(simd.values.iter()).enumerate() {
            // Warm-up slots are NaN in both outputs; NaN - NaN would fail the check.
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!((a - b).abs() < 1e-10, "SIMD128 mismatch at {i}: {a} vs {b}");
        }
    }

    #[cfg(feature = "proptest")]
    proptest! {
        #[test]
        fn test_cora_wave_no_panic(data: Vec<f64>, period in 1usize..100) {
            let params = CoraWaveParams {
                period: Some(period),
                r_multi: Some(2.0),
                smooth: Some(true),
            };
            let input = CoraWaveInput::from_slice(&data, params);
            let _ = cora_wave(&input);
        }

        #[test]
        fn test_cora_wave_length_preservation(size in 10usize..100) {
            let data: Vec<f64> = (0..size).map(|i| i as f64).collect();
            let params = CoraWaveParams::default();
            let input = CoraWaveInput::from_slice(&data, params);

            if let Ok(output) = cora_wave(&input) {
                prop_assert_eq!(output.values.len(), size);
            }
        }
    }
}