// vector_ta/indicators/cmo.rs

1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5    make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21impl<'a> AsRef<[f64]> for CmoInput<'a> {
22    #[inline(always)]
23    fn as_ref(&self) -> &[f64] {
24        match &self.data {
25            CmoData::Slice(slice) => slice,
26            CmoData::Candles { candles, source } => source_type(candles, source),
27        }
28    }
29}
30
31#[derive(Debug, Clone)]
32pub enum CmoData<'a> {
33    Candles {
34        candles: &'a Candles,
35        source: &'a str,
36    },
37    Slice(&'a [f64]),
38}
39
40#[derive(Debug, Clone)]
41pub struct CmoOutput {
42    pub values: Vec<f64>,
43}
44
/// Tunable parameters of the Chande Momentum Oscillator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CmoParams {
    /// Look-back length in samples. `None` means "use the default of 14".
    pub period: Option<usize>,
}

impl Default for CmoParams {
    /// The default parameters: `period = Some(14)`.
    fn default() -> Self {
        CmoParams { period: Some(14) }
    }
}
59
60#[derive(Debug, Clone)]
61pub struct CmoInput<'a> {
62    pub data: CmoData<'a>,
63    pub params: CmoParams,
64}
65
66impl<'a> CmoInput<'a> {
67    #[inline]
68    pub fn from_candles(c: &'a Candles, s: &'a str, p: CmoParams) -> Self {
69        Self {
70            data: CmoData::Candles {
71                candles: c,
72                source: s,
73            },
74            params: p,
75        }
76    }
77    #[inline]
78    pub fn from_slice(sl: &'a [f64], p: CmoParams) -> Self {
79        Self {
80            data: CmoData::Slice(sl),
81            params: p,
82        }
83    }
84    #[inline]
85    pub fn with_default_candles(c: &'a Candles) -> Self {
86        Self::from_candles(c, "close", CmoParams::default())
87    }
88    #[inline]
89    pub fn get_period(&self) -> usize {
90        self.params.period.unwrap_or(14)
91    }
92    #[inline]
93    pub fn data_len(&self) -> usize {
94        match &self.data {
95            CmoData::Slice(slice) => slice.len(),
96            CmoData::Candles { candles, .. } => candles.close.len(),
97        }
98    }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct CmoBuilder {
103    period: Option<usize>,
104    kernel: Kernel,
105}
106
107impl Default for CmoBuilder {
108    fn default() -> Self {
109        Self {
110            period: None,
111            kernel: Kernel::Auto,
112        }
113    }
114}
115
116impl CmoBuilder {
117    #[inline(always)]
118    pub fn new() -> Self {
119        Self::default()
120    }
121    #[inline(always)]
122    pub fn period(mut self, n: usize) -> Self {
123        self.period = Some(n);
124        self
125    }
126    #[inline(always)]
127    pub fn kernel(mut self, k: Kernel) -> Self {
128        self.kernel = k;
129        self
130    }
131
132    #[inline(always)]
133    pub fn apply(self, c: &Candles) -> Result<CmoOutput, CmoError> {
134        let p = CmoParams {
135            period: self.period,
136        };
137        let i = CmoInput::from_candles(c, "close", p);
138        cmo_with_kernel(&i, self.kernel)
139    }
140
141    #[inline(always)]
142    pub fn apply_slice(self, d: &[f64]) -> Result<CmoOutput, CmoError> {
143        let p = CmoParams {
144            period: self.period,
145        };
146        let i = CmoInput::from_slice(d, p);
147        cmo_with_kernel(&i, self.kernel)
148    }
149
150    #[inline(always)]
151    pub fn into_stream(self) -> Result<CmoStream, CmoError> {
152        let p = CmoParams {
153            period: self.period,
154        };
155        CmoStream::try_new(p)
156    }
157}
158
159#[derive(Debug, Error)]
160pub enum CmoError {
161    #[error("cmo: Empty data provided.")]
162    EmptyData,
163
164    #[error("cmo: Invalid period: period={period}, data_len={data_len}")]
165    InvalidPeriod { period: usize, data_len: usize },
166
167    #[error("cmo: All values are NaN.")]
168    AllValuesNaN,
169
170    #[error("cmo: Not enough valid data: needed={needed}, valid={valid}")]
171    NotEnoughValidData { needed: usize, valid: usize },
172
173    #[error("cmo: Invalid range: start={start}, end={end}, step={step}")]
174    InvalidRange {
175        start: usize,
176        end: usize,
177        step: usize,
178    },
179
180    #[error("cmo: Invalid kernel for batch: {0:?}")]
181    InvalidKernelForBatch(Kernel),
182
183    #[error("cmo: Output length mismatch: expected={expected}, got={got}")]
184    OutputLengthMismatch { expected: usize, got: usize },
185}
186
187#[inline]
188pub fn cmo(input: &CmoInput) -> Result<CmoOutput, CmoError> {
189    cmo_with_kernel(input, Kernel::Auto)
190}
191
192#[inline(always)]
193fn cmo_prepare<'a>(
194    input: &'a CmoInput,
195    k: Kernel,
196) -> Result<(&'a [f64], usize, usize, Kernel), CmoError> {
197    let data: &[f64] = input.as_ref();
198    let len = data.len();
199    if len == 0 {
200        return Err(CmoError::EmptyData);
201    }
202    let period = input.get_period();
203    if period == 0 || period > len {
204        return Err(CmoError::InvalidPeriod {
205            period,
206            data_len: len,
207        });
208    }
209    let first = data
210        .iter()
211        .position(|x| !x.is_nan())
212        .ok_or(CmoError::AllValuesNaN)?;
213    if len - first <= period {
214        return Err(CmoError::NotEnoughValidData {
215            needed: period + 1,
216            valid: len - first,
217        });
218    }
219    let mut chosen = match k {
220        Kernel::Auto => Kernel::Scalar,
221        other => other,
222    };
223
224    if chosen.is_batch() {
225        chosen = match chosen {
226            Kernel::Avx512Batch => Kernel::Avx512,
227            Kernel::Avx2Batch => Kernel::Avx2,
228            Kernel::ScalarBatch => Kernel::Scalar,
229            _ => chosen,
230        };
231    }
232    Ok((data, period, first, chosen))
233}
234
235#[inline(always)]
236fn cmo_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
237    unsafe {
238        match kernel {
239            Kernel::Scalar | Kernel::ScalarBatch => cmo_scalar(data, period, first, out),
240            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
241            Kernel::Avx2 | Kernel::Avx2Batch => cmo_avx2(data, period, first, out),
242            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
243            Kernel::Avx512 | Kernel::Avx512Batch => cmo_avx512(data, period, first, out),
244            _ => unreachable!(),
245        }
246    }
247}
248
249pub fn cmo_with_kernel(input: &CmoInput, kernel: Kernel) -> Result<CmoOutput, CmoError> {
250    let (data, period, first, chosen) = cmo_prepare(input, kernel)?;
251    let mut out = alloc_with_nan_prefix(data.len(), first + period);
252    cmo_compute_into(data, period, first, chosen, &mut out);
253    Ok(CmoOutput { values: out })
254}
255
256#[inline]
257pub fn cmo_into_slice(dst: &mut [f64], input: &CmoInput, kern: Kernel) -> Result<(), CmoError> {
258    let (data, period, first, chosen) = cmo_prepare(input, kern)?;
259    if dst.len() != data.len() {
260        return Err(CmoError::OutputLengthMismatch {
261            expected: data.len(),
262            got: dst.len(),
263        });
264    }
265    cmo_compute_into(data, period, first, chosen, dst);
266    let warmup_end = first + period;
267    for v in &mut dst[..warmup_end] {
268        *v = f64::NAN;
269    }
270    Ok(())
271}
272
273#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
274#[inline]
275pub fn cmo_into(input: &CmoInput, out: &mut [f64]) -> Result<(), CmoError> {
276    let (data, period, first, chosen) = cmo_prepare(input, Kernel::Auto)?;
277
278    if out.len() != data.len() {
279        return Err(CmoError::OutputLengthMismatch {
280            expected: data.len(),
281            got: out.len(),
282        });
283    }
284
285    let warmup_end = first + period;
286    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
287    let warm = warmup_end.min(out.len());
288    for v in &mut out[..warm] {
289        *v = qnan;
290    }
291
292    cmo_compute_into(data, period, first, chosen, out);
293
294    Ok(())
295}
296
/// Scalar Chande Momentum Oscillator kernel.
///
/// Starting at `first_valid`, each one-step price change is split into a gain
/// (`max(diff, 0)`) and a loss (`max(-diff, 0)`).
///
/// - Seeding: the first `period` gains and losses are averaged, and the CMO is written
///   at index `first_valid + period`.
/// - After that, both averages follow Wilder's recurrence
///   `avg = (avg * (period - 1) + x) / period`.
///
/// Each output is `100 * (gain - loss) / (gain + loss)`, or `0.0` when both averages
/// are zero. Entries before `first_valid + period` are not written.
#[inline]
pub fn cmo_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    #[inline(always)]
    fn oscillator(up: f64, down: f64) -> f64 {
        let total = up + down;
        if total != 0.0 {
            100.0 * ((up - down) / total)
        } else {
            0.0
        }
    }

    let mut last = data[first_valid];
    let seed_end = first_valid + period;
    let inv_n = 1.0 / period as f64;
    let keep = (period - 1) as f64;

    let mut up = 0.0;
    let mut down = 0.0;

    for (idx, &price) in data.iter().enumerate().skip(first_valid + 1) {
        let delta = price - last;
        last = price;

        let magnitude = delta.abs();
        let g = 0.5 * (delta + magnitude);
        let l = 0.5 * (magnitude - delta);

        if idx > seed_end {
            up = (up * keep + g) * inv_n;
            down = (down * keep + l) * inv_n;
            out[idx] = oscillator(up, down);
        } else {
            up += g;
            down += l;
            if idx == seed_end {
                up *= inv_n;
                down *= inv_n;
                out[idx] = oscillator(up, down);
            }
        }
    }
}
348
/// Checks the preconditions shared by the AVX2/AVX-512 CMO kernels.
///
/// Panics if `out.len() != data.len()`, because the kernels store through unchecked
/// indices bounded only by `data.len()`.
///
/// Returns `false` when `period == 0` or when the seed index `first_valid + period`
/// is outside `data`. In that case there is nothing valid to write, and the kernel
/// must not run: its unchecked loads would read past the end of the slice.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn cmo_simd_args_fit(data: &[f64], period: usize, first_valid: usize, out: &[f64]) -> bool {
    assert_eq!(
        out.len(),
        data.len(),
        "cmo: output length must equal input length"
    );
    period > 0
        && first_valid
            .checked_add(period)
            .map_or(false, |seed| seed < data.len())
}

/// AVX-512 CMO kernel: writes CMO values to `out[first_valid + period..]`.
///
/// The warm-up prefix is left untouched; the callers in this module pre-fill it with NaN.
///
/// # Panics
/// If the running CPU lacks AVX-512F, or if `out.len() != data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cmo_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert!(
        std::arch::is_x86_feature_detected!("avx512f"),
        "cmo: AVX-512F kernel requested on a CPU without AVX-512F"
    );
    if !cmo_simd_args_fit(data, period, first_valid, out) {
        return;
    }
    // SAFETY: AVX-512F support was verified above. `cmo_simd_args_fit` guarantees
    // equal lengths and `first_valid + period < data.len()`, which bounds every
    // unchecked load and store in the kernel.
    unsafe {
        if period <= 32 {
            cmo_avx512_short(data, period, first_valid, out)
        } else {
            cmo_avx512_long(data, period, first_valid, out)
        }
    }
}

/// AVX2 CMO kernel: writes CMO values to `out[first_valid + period..]`.
///
/// The warm-up prefix is left untouched; the callers in this module pre-fill it with NaN.
///
/// # Panics
/// If the running CPU lacks AVX2, or if `out.len() != data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cmo_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    assert!(
        std::arch::is_x86_feature_detected!("avx2"),
        "cmo: AVX2 kernel requested on a CPU without AVX2"
    );
    if !cmo_simd_args_fit(data, period, first_valid, out) {
        return;
    }
    // SAFETY: AVX2 support was verified above. `cmo_simd_args_fit` guarantees equal
    // lengths and `first_valid + period < data.len()`, which bounds every unchecked
    // load and store in the kernel.
    unsafe { cmo_avx2_impl(data, period, first_valid, out) }
}

/// AVX-512 kernel variant for short periods (`period <= 32`).
///
/// # Safety
/// The CPU must support AVX-512F, `out.len() == data.len()`, `period >= 1`, and
/// `first_valid + period < data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cmo_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cmo_avx512_impl(data, period, first_valid, out)
}

/// AVX-512 kernel variant for long periods (`period > 32`).
///
/// # Safety
/// Same requirements as [`cmo_avx512_short`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cmo_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cmo_avx512_impl(data, period, first_valid, out)
}
378
/// AVX2 body of the CMO kernel.
///
/// Seeding: the seed sums of gains and losses over `first_valid + 1 ..= first_valid + period`
/// are accumulated four lanes at a time, with a scalar tail. After that, Wilder's
/// recurrence runs serially, because each step depends on the previous one.
///
/// # Safety
/// The CPU must support AVX2, `out.len() == data.len()`, `period >= 1`, and
/// `first_valid + period < data.len()`. Loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn cmo_avx2_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    debug_assert!(out.len() == data.len());

    /// Horizontal sum of four lanes: `(l0 + l2) + (l1 + l3)`.
    #[inline(always)]
    unsafe fn lane_sum(v: __m256d) -> f64 {
        let pair = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
        _mm_cvtsd_f64(_mm_add_sd(pair, _mm_unpackhi_pd(pair, pair)))
    }

    /// `100 * (up - down) / (up + down)`, or `0.0` when both averages are zero.
    #[inline(always)]
    fn oscillator(up: f64, down: f64) -> f64 {
        let total = up + down;
        if total != 0.0 {
            100.0 * ((up - down) / total)
        } else {
            0.0
        }
    }

    let n = data.len();
    let base = data.as_ptr();
    let seed_end = first_valid + period;
    let inv_n = 1.0 / (period as f64);
    let keep = (period - 1) as f64;

    // Clearing the sign bit gives |x| without a branch.
    let sign_off = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFu64 as i64));
    let half = _mm256_set1_pd(0.5);
    let mut up_lanes = _mm256_setzero_pd();
    let mut down_lanes = _mm256_setzero_pd();

    let mut idx = first_valid + 1;
    while idx + 3 <= seed_end {
        let delta = _mm256_sub_pd(_mm256_loadu_pd(base.add(idx)), _mm256_loadu_pd(base.add(idx - 1)));
        let magnitude = _mm256_and_pd(delta, sign_off);
        up_lanes = _mm256_add_pd(up_lanes, _mm256_mul_pd(_mm256_add_pd(magnitude, delta), half));
        down_lanes = _mm256_add_pd(down_lanes, _mm256_mul_pd(_mm256_sub_pd(magnitude, delta), half));
        idx += 4;
    }

    let mut up_sum = 0.0f64;
    let mut down_sum = 0.0f64;
    up_sum += lane_sum(up_lanes);
    down_sum += lane_sum(down_lanes);

    // `idx - 1` is `first_valid` when the vector loop did not run.
    let mut last = *data.get_unchecked(idx - 1);

    while idx <= seed_end {
        let price = *data.get_unchecked(idx);
        let delta = price - last;
        last = price;
        let magnitude = delta.abs();
        up_sum += 0.5 * (magnitude + delta);
        down_sum += 0.5 * (magnitude - delta);
        idx += 1;
    }

    let mut up = up_sum * inv_n;
    let mut down = down_sum * inv_n;
    *out.get_unchecked_mut(seed_end) = oscillator(up, down);

    while idx < n {
        let price = *data.get_unchecked(idx);
        let delta = price - last;
        last = price;
        let magnitude = delta.abs();
        up = (up * keep + 0.5 * (magnitude + delta)) * inv_n;
        down = (down * keep + 0.5 * (magnitude - delta)) * inv_n;
        *out.get_unchecked_mut(idx) = oscillator(up, down);
        idx += 1;
    }
}
483
/// AVX-512F body of the CMO kernel.
///
/// Seeding: the seed sums of gains and losses over `first_valid + 1 ..= first_valid + period`
/// are accumulated eight lanes at a time, with a scalar tail. After that, Wilder's
/// recurrence runs serially.
///
/// # Safety
/// The CPU must support AVX-512F, `out.len() == data.len()`, `period >= 1`, and
/// `first_valid + period < data.len()`. Loads and stores are unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn cmo_avx512_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    debug_assert!(out.len() == data.len());

    /// Horizontal sum of four lanes: `(l0 + l2) + (l1 + l3)`.
    #[inline(always)]
    unsafe fn fold4(v: __m256d) -> f64 {
        let pair = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
        _mm_cvtsd_f64(_mm_add_sd(pair, _mm_unpackhi_pd(pair, pair)))
    }

    /// Horizontal sum of eight lanes: add the 256-bit halves, then `fold4`.
    #[inline(always)]
    unsafe fn fold8(v: __m512d) -> f64 {
        fold4(_mm256_add_pd(
            _mm512_castpd512_pd256(v),
            _mm512_extractf64x4_pd(v, 1),
        ))
    }

    /// `100 * (up - down) / (up + down)`, or `0.0` when both averages are zero.
    #[inline(always)]
    fn oscillator(up: f64, down: f64) -> f64 {
        let total = up + down;
        if total != 0.0 {
            100.0 * ((up - down) / total)
        } else {
            0.0
        }
    }

    let n = data.len();
    let base = data.as_ptr();
    let seed_end = first_valid + period;
    let inv_n = 1.0 / (period as f64);
    let keep = (period - 1) as f64;

    // |x| via integer AND on the bit pattern (avoids needing AVX-512DQ's `_mm512_and_pd`).
    let sign_off = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFFu64 as i64);
    let half = _mm512_set1_pd(0.5);
    let mut up_lanes = _mm512_setzero_pd();
    let mut down_lanes = _mm512_setzero_pd();

    let mut idx = first_valid + 1;
    while idx + 7 <= seed_end {
        let delta = _mm512_sub_pd(_mm512_loadu_pd(base.add(idx)), _mm512_loadu_pd(base.add(idx - 1)));
        let magnitude = _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(delta), sign_off));
        up_lanes = _mm512_add_pd(up_lanes, _mm512_mul_pd(_mm512_add_pd(magnitude, delta), half));
        down_lanes = _mm512_add_pd(down_lanes, _mm512_mul_pd(_mm512_sub_pd(magnitude, delta), half));
        idx += 8;
    }

    let mut up_sum = 0.0f64;
    let mut down_sum = 0.0f64;
    up_sum += fold8(up_lanes);
    down_sum += fold8(down_lanes);

    // `idx - 1` is `first_valid` when the vector loop did not run.
    let mut last = *data.get_unchecked(idx - 1);

    while idx <= seed_end {
        let price = *data.get_unchecked(idx);
        let delta = price - last;
        last = price;
        let magnitude = delta.abs();
        up_sum += 0.5 * (magnitude + delta);
        down_sum += 0.5 * (magnitude - delta);
        idx += 1;
    }

    let mut up = up_sum * inv_n;
    let mut down = down_sum * inv_n;
    *out.get_unchecked_mut(seed_end) = oscillator(up, down);

    while idx < n {
        let price = *data.get_unchecked(idx);
        let delta = price - last;
        last = price;
        let magnitude = delta.abs();
        up = (up * keep + 0.5 * (magnitude + delta)) * inv_n;
        down = (down * keep + 0.5 * (magnitude - delta)) * inv_n;
        *out.get_unchecked_mut(idx) = oscillator(up, down);
        idx += 1;
    }
}
598
599#[inline(always)]
600pub fn cmo_batch_with_kernel(
601    data: &[f64],
602    sweep: &CmoBatchRange,
603    k: Kernel,
604) -> Result<CmoBatchOutput, CmoError> {
605    let kernel = match k {
606        Kernel::Auto => Kernel::ScalarBatch,
607        other if other.is_batch() => other,
608        _ => return Err(CmoError::InvalidKernelForBatch(k)),
609    };
610    let simd = match kernel {
611        Kernel::Avx512Batch => Kernel::Avx512,
612        Kernel::Avx2Batch => Kernel::Avx2,
613        Kernel::ScalarBatch => Kernel::Scalar,
614        _ => unreachable!(),
615    };
616    cmo_batch_par_slice(data, sweep, simd)
617}
618
/// Sweep specification for batch CMO: `period = (start, end, step)`.
#[derive(Clone, Debug)]
pub struct CmoBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for CmoBatchRange {
    /// Periods 14 through 263 with step 1.
    fn default() -> Self {
        CmoBatchRange {
            period: (14, 263, 1),
        }
    }
}
631
632#[derive(Clone, Debug, Default)]
633pub struct CmoBatchBuilder {
634    range: CmoBatchRange,
635    kernel: Kernel,
636}
637
638impl CmoBatchBuilder {
639    pub fn new() -> Self {
640        Self::default()
641    }
642    pub fn kernel(mut self, k: Kernel) -> Self {
643        self.kernel = k;
644        self
645    }
646    #[inline]
647    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
648        self.range.period = (start, end, step);
649        self
650    }
651    #[inline]
652    pub fn period_static(mut self, p: usize) -> Self {
653        self.range.period = (p, p, 0);
654        self
655    }
656    pub fn apply_slice(self, data: &[f64]) -> Result<CmoBatchOutput, CmoError> {
657        cmo_batch_with_kernel(data, &self.range, self.kernel)
658    }
659    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CmoBatchOutput, CmoError> {
660        CmoBatchBuilder::new().kernel(k).apply_slice(data)
661    }
662    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CmoBatchOutput, CmoError> {
663        let slice = source_type(c, src);
664        self.apply_slice(slice)
665    }
666    pub fn with_default_candles(c: &Candles) -> Result<CmoBatchOutput, CmoError> {
667        CmoBatchBuilder::new()
668            .kernel(Kernel::Auto)
669            .apply_candles(c, "close")
670    }
671}
672
673#[derive(Clone, Debug)]
674pub struct CmoBatchOutput {
675    pub values: Vec<f64>,
676    pub combos: Vec<CmoParams>,
677    pub rows: usize,
678    pub cols: usize,
679}
680
681impl CmoBatchOutput {
682    pub fn row_for_params(&self, p: &CmoParams) -> Option<usize> {
683        self.combos
684            .iter()
685            .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
686    }
687    pub fn values_for(&self, p: &CmoParams) -> Option<&[f64]> {
688        self.row_for_params(p).map(|row| {
689            let start = row * self.cols;
690            &self.values[start..start + self.cols]
691        })
692    }
693}
694
695#[inline(always)]
696fn expand_grid(r: &CmoBatchRange) -> Vec<CmoParams> {
697    fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
698        if step == 0 || start == end {
699            return vec![start];
700        }
701        let mut vals = Vec::new();
702        if start < end {
703            let mut x = start;
704            while x <= end {
705                vals.push(x);
706                let next = x.saturating_add(step);
707                if next == x {
708                    break;
709                }
710                x = next;
711            }
712        } else {
713            let mut x = start;
714            loop {
715                vals.push(x);
716                if x <= end {
717                    break;
718                }
719                let next = x.saturating_sub(step);
720                if next >= x {
721                    break;
722                }
723                x = next;
724            }
725        }
726        vals
727    }
728    let periods = axis_usize(r.period);
729    let mut out = Vec::with_capacity(periods.len());
730    for &p in &periods {
731        out.push(CmoParams { period: Some(p) });
732    }
733    out
734}
735
736#[inline(always)]
737pub fn cmo_batch_slice(
738    data: &[f64],
739    sweep: &CmoBatchRange,
740    kern: Kernel,
741) -> Result<CmoBatchOutput, CmoError> {
742    cmo_batch_inner(data, sweep, kern, false)
743}
744
745#[inline(always)]
746pub fn cmo_batch_par_slice(
747    data: &[f64],
748    sweep: &CmoBatchRange,
749    kern: Kernel,
750) -> Result<CmoBatchOutput, CmoError> {
751    cmo_batch_inner(data, sweep, kern, true)
752}
753
754#[inline(always)]
755fn cmo_batch_inner(
756    data: &[f64],
757    sweep: &CmoBatchRange,
758    kern: Kernel,
759    parallel: bool,
760) -> Result<CmoBatchOutput, CmoError> {
761    let combos = expand_grid(sweep);
762    if combos.is_empty() {
763        return Err(CmoError::InvalidRange {
764            start: sweep.period.0,
765            end: sweep.period.1,
766            step: sweep.period.2,
767        });
768    }
769    let first = data
770        .iter()
771        .position(|x| !x.is_nan())
772        .ok_or(CmoError::AllValuesNaN)?;
773    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
774    let _ = combos
775        .len()
776        .checked_mul(max_p)
777        .ok_or(CmoError::InvalidRange {
778            start: sweep.period.0,
779            end: sweep.period.1,
780            step: sweep.period.2,
781        })?;
782    if data.len() - first <= max_p {
783        return Err(CmoError::NotEnoughValidData {
784            needed: max_p + 1,
785            valid: data.len() - first,
786        });
787    }
788    let rows = combos.len();
789    let cols = data.len();
790    let _expected = rows.checked_mul(cols).ok_or(CmoError::InvalidRange {
791        start: sweep.period.0,
792        end: sweep.period.1,
793        step: sweep.period.2,
794    })?;
795
796    let mut buf_mu = make_uninit_matrix(rows, cols);
797
798    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
799
800    init_matrix_prefixes(&mut buf_mu, cols, &warm);
801
802    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
803    let out: &mut [f64] = unsafe {
804        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
805    };
806
807    let len = data.len();
808    let start = first + 1;
809    let mut gains = vec![0.0f64; len];
810    let mut losses = vec![0.0f64; len];
811    for i in start..len {
812        let diff = data[i] - data[i - 1];
813        let ad = diff.abs();
814        gains[i] = 0.5 * (ad + diff);
815        losses[i] = 0.5 * (ad - diff);
816    }
817    let mut pg = vec![0.0f64; len + 1];
818    let mut pl = vec![0.0f64; len + 1];
819    for i in 0..len {
820        pg[i + 1] = pg[i] + gains[i];
821        pl[i + 1] = pl[i] + losses[i];
822    }
823
824    let do_row = |row: usize, out_row: &mut [f64]| unsafe {
825        let period = combos[row].period.unwrap();
826        cmo_row_from_gl(&gains, &losses, &pg, &pl, first, period, out_row);
827    };
828
829    if parallel {
830        #[cfg(not(target_arch = "wasm32"))]
831        {
832            out.par_chunks_mut(cols)
833                .enumerate()
834                .for_each(|(row, slice)| do_row(row, slice));
835        }
836
837        #[cfg(target_arch = "wasm32")]
838        {
839            for (row, slice) in out.chunks_mut(cols).enumerate() {
840                do_row(row, slice);
841            }
842        }
843    } else {
844        for (row, slice) in out.chunks_mut(cols).enumerate() {
845            do_row(row, slice);
846        }
847    }
848
849    let values = unsafe {
850        Vec::from_raw_parts(
851            buf_guard.as_mut_ptr() as *mut f64,
852            buf_guard.len(),
853            buf_guard.capacity(),
854        )
855    };
856
857    Ok(CmoBatchOutput {
858        values,
859        combos,
860        rows,
861        cols,
862    })
863}
864
865#[inline(always)]
866fn cmo_batch_inner_into(
867    data: &[f64],
868    sweep: &CmoBatchRange,
869    kern: Kernel,
870    parallel: bool,
871    out: &mut [f64],
872) -> Result<Vec<CmoParams>, CmoError> {
873    let combos = expand_grid(sweep);
874    if combos.is_empty() {
875        return Err(CmoError::InvalidRange {
876            start: sweep.period.0,
877            end: sweep.period.1,
878            step: sweep.period.2,
879        });
880    }
881    let first = data
882        .iter()
883        .position(|x| !x.is_nan())
884        .ok_or(CmoError::AllValuesNaN)?;
885    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
886    let _ = combos
887        .len()
888        .checked_mul(max_p)
889        .ok_or(CmoError::InvalidRange {
890            start: sweep.period.0,
891            end: sweep.period.1,
892            step: sweep.period.2,
893        })?;
894    if data.len() - first <= max_p {
895        return Err(CmoError::NotEnoughValidData {
896            needed: max_p + 1,
897            valid: data.len() - first,
898        });
899    }
900    let cols = data.len();
901    let rows = combos.len();
902    let expected = rows.checked_mul(cols).ok_or(CmoError::InvalidRange {
903        start: sweep.period.0,
904        end: sweep.period.1,
905        step: sweep.period.2,
906    })?;
907    if out.len() != expected {
908        return Err(CmoError::OutputLengthMismatch {
909            expected,
910            got: out.len(),
911        });
912    }
913
914    let out_mu: &mut [MaybeUninit<f64>] = unsafe {
915        std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
916    };
917    let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
918    init_matrix_prefixes(out_mu, cols, &warm);
919
920    let len = data.len();
921    let start = first + 1;
922    let mut gains = vec![0.0f64; len];
923    let mut losses = vec![0.0f64; len];
924    for i in start..len {
925        let diff = data[i] - data[i - 1];
926        let ad = diff.abs();
927        gains[i] = 0.5 * (ad + diff);
928        losses[i] = 0.5 * (ad - diff);
929    }
930    let mut pg = vec![0.0f64; len + 1];
931    let mut pl = vec![0.0f64; len + 1];
932    for i in 0..len {
933        pg[i + 1] = pg[i] + gains[i];
934        pl[i + 1] = pl[i] + losses[i];
935    }
936
937    let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
938        let period = combos[row].period.unwrap();
939        let row_dst: &mut [f64] =
940            std::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
941        cmo_row_from_gl(&gains, &losses, &pg, &pl, first, period, row_dst);
942    };
943
944    if parallel {
945        #[cfg(not(target_arch = "wasm32"))]
946        {
947            out_mu
948                .par_chunks_mut(cols)
949                .enumerate()
950                .for_each(|(r, row_mu)| do_row(r, row_mu));
951        }
952        #[cfg(target_arch = "wasm32")]
953        {
954            for (r, row_mu) in out_mu.chunks_mut(cols).enumerate() {
955                do_row(r, row_mu);
956            }
957        }
958    } else {
959        for (r, row_mu) in out_mu.chunks_mut(cols).enumerate() {
960            do_row(r, row_mu);
961        }
962    }
963
964    Ok(combos)
965}
966
967#[inline]
968pub fn cmo_batch_into_slice(
969    out: &mut [f64],
970    data: &[f64],
971    sweep: &CmoBatchRange,
972    k: Kernel,
973) -> Result<Vec<CmoParams>, CmoError> {
974    let kernel = match k {
975        Kernel::Auto => detect_best_batch_kernel(),
976        other if other.is_batch() => other,
977        _ => return Err(CmoError::InvalidKernelForBatch(k)),
978    };
979    let simd = match kernel {
980        Kernel::Avx512Batch => Kernel::Avx512,
981        Kernel::Avx2Batch => Kernel::Avx2,
982        Kernel::ScalarBatch => Kernel::Scalar,
983        _ => unreachable!(),
984    };
985    cmo_batch_inner_into(data, sweep, simd, true, out)
986}
987
988#[inline(always)]
989unsafe fn cmo_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
990    cmo_scalar(data, period, first, out);
991}
992
993#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
994#[inline(always)]
995unsafe fn cmo_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
996    cmo_avx2(data, period, first, out);
997}
998
999#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1000#[inline(always)]
1001unsafe fn cmo_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1002    if period <= 32 {
1003        cmo_row_avx512_short(data, first, period, out);
1004    } else {
1005        cmo_row_avx512_long(data, first, period, out);
1006    }
1007}
1008
1009#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1010#[inline(always)]
1011unsafe fn cmo_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1012    cmo_avx512_short(data, period, first, out)
1013}
1014
1015#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1016#[inline(always)]
1017unsafe fn cmo_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1018    cmo_avx512_long(data, period, first, out)
1019}
1020
/// Computes one CMO row from precomputed per-sample gains/losses and their prefix sums.
///
/// `pg[k]` / `pl[k]` must hold the sum of the first `k` gains / losses. The seed averages
/// over `first + 1 ..= first + period` therefore come from one subtraction each. Output
/// is written from `first + period` onward; the prefix is left untouched.
///
/// # Safety
/// `gains` and `losses` must be at least `out.len()` long, and
/// `first + period < out.len()`. Elements after the seed are accessed without bounds
/// checks.
#[inline(always)]
unsafe fn cmo_row_from_gl(
    gains: &[f64],
    losses: &[f64],
    pg: &[f64],
    pl: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let row_len = out.len();
    let seed_lo = first + 1;
    let seed_hi = first + period;
    let inv_n = 1.0 / (period as f64);
    let keep = (period - 1) as f64;

    let mut up = (pg[seed_hi + 1] - pg[seed_lo]) * inv_n;
    let mut down = (pl[seed_hi + 1] - pl[seed_lo]) * inv_n;

    // 100 * (up - down) / (up + down), or 0.0 when both averages are zero.
    let score = |g: f64, l: f64| {
        let total = g + l;
        if total != 0.0 {
            100.0 * ((g - l) / total)
        } else {
            0.0
        }
    };

    *out.get_unchecked_mut(seed_hi) = score(up, down);

    for idx in seed_hi + 1..row_len {
        up = (up * keep + *gains.get_unchecked(idx)) * inv_n;
        down = (down * keep + *losses.get_unchecked(idx)) * inv_n;
        *out.get_unchecked_mut(idx) = score(up, down);
    }
}
1072
/// Incremental (one sample at a time) Chande Momentum Oscillator state.
///
/// Construct with `CmoStream::try_new` and feed samples through `update`.
#[derive(Debug, Clone)]
pub struct CmoStream {
    /// Smoothing length: number of price differences in the seed window.
    period: usize,
    /// Cached `1.0 / period`.
    inv_period: f64,
    /// Running sum of gains while warming up, then the smoothed average gain.
    avg_gain: f64,
    /// Running sum of losses while warming up, then the smoothed average loss.
    avg_loss: f64,
    /// Previous input sample, used to form the next difference.
    prev: f64,
    /// Number of differences accumulated during warm-up.
    head: usize,
    /// Set once a first sample has been stored in `prev`.
    started: bool,
    /// Set once the seed window is complete and smoothing is active.
    filled: bool,
}
1084
1085impl CmoStream {
1086    pub fn try_new(params: CmoParams) -> Result<Self, CmoError> {
1087        let period = params.period.unwrap_or(14);
1088        if period == 0 {
1089            return Err(CmoError::InvalidPeriod {
1090                period,
1091                data_len: 0,
1092            });
1093        }
1094        Ok(Self {
1095            period,
1096            inv_period: 1.0 / (period as f64),
1097            avg_gain: 0.0,
1098            avg_loss: 0.0,
1099            prev: 0.0,
1100            head: 0,
1101            started: false,
1102            filled: false,
1103        })
1104    }
1105    #[inline(always)]
1106    pub fn update(&mut self, value: f64) -> Option<f64> {
1107        if !self.started {
1108            self.prev = value;
1109            self.started = true;
1110            return None;
1111        }
1112
1113        let diff = value - self.prev;
1114        self.prev = value;
1115
1116        let ad = diff.abs();
1117        let gain = 0.5 * (ad + diff);
1118        let loss = 0.5 * (ad - diff);
1119
1120        if !self.filled {
1121            self.avg_gain += gain;
1122            self.avg_loss += loss;
1123            self.head += 1;
1124
1125            if self.head == self.period {
1126                self.avg_gain *= self.inv_period;
1127                self.avg_loss *= self.inv_period;
1128                self.filled = true;
1129
1130                let denom = self.avg_gain + self.avg_loss;
1131                return Some(if denom != 0.0 {
1132                    100.0 * (self.avg_gain - self.avg_loss) / denom
1133                } else {
1134                    0.0
1135                });
1136            }
1137            return None;
1138        }
1139
1140        let ip = self.inv_period;
1141        self.avg_gain = (gain - self.avg_gain).mul_add(ip, self.avg_gain);
1142        self.avg_loss = (loss - self.avg_loss).mul_add(ip, self.avg_loss);
1143
1144        let denom = self.avg_gain + self.avg_loss;
1145        Some(if denom != 0.0 {
1146            100.0 * (self.avg_gain - self.avg_loss) / denom
1147        } else {
1148            0.0
1149        })
1150    }
1151}
1152
1153#[cfg(feature = "python")]
1154use crate::utilities::kernel_validation::validate_kernel;
1155#[cfg(feature = "python")]
1156use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1157#[cfg(feature = "python")]
1158use pyo3::exceptions::PyValueError;
1159#[cfg(feature = "python")]
1160use pyo3::prelude::*;
1161#[cfg(feature = "python")]
1162use pyo3::types::PyDict;
1163
1164#[cfg(all(feature = "python", feature = "cuda"))]
1165use crate::cuda::oscillators::CudaCmo;
1166#[cfg(all(feature = "python", feature = "cuda"))]
1167use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
1168
/// Python binding: computes CMO over a contiguous 1-D float64 NumPy array.
///
/// `kernel` is checked with `validate_kernel(kernel, false)`. The computation
/// runs with the GIL released, and indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cmo")]
#[pyo3(signature = (data, period=None, kernel=None))]
pub fn cmo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let input = CmoInput::from_slice(prices, CmoParams { period });

    let computed = py.allow_threads(|| cmo_with_kernel(&input, chosen_kernel).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1190
/// Python-visible `CmoStream` class wrapping the Rust `CmoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "CmoStream")]
pub struct CmoStreamPy {
    // Underlying incremental CMO state.
    stream: CmoStream,
}
1196
#[cfg(feature = "python")]
#[pymethods]
impl CmoStreamPy {
    /// Creates a stream for `period`. Raises `ValueError` if `CmoStream::try_new`
    /// rejects the parameters.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        CmoStream::try_new(CmoParams { period })
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one sample and returns the CMO value, or `None` while the
    /// underlying stream has no value yet.
    fn update(&mut self, value: f64) -> Option<f64> {
        let CmoStreamPy { stream } = self;
        stream.update(value)
    }
}
1212
/// Python binding: evaluates CMO for every period in `period_range` in one pass.
///
/// `period_range` is `(start, end, step)` and is expanded with `expand_grid`.
/// Returns a dict with:
/// * `values` — a `(rows, cols)` float64 array with one row per period;
/// * `periods` — the period used for each row, as uint64.
///
/// The kernel string is validated before the output buffer is allocated. The
/// computation runs with the GIL released and writes straight into the NumPy buffer.
///
/// # Errors
/// Raises `ValueError` in these cases:
/// * an invalid or non-batch kernel;
/// * a `rows * cols` size overflow;
/// * any error reported by the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "cmo_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cmo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    // Reject a bad kernel before allocating a potentially large output array.
    let kern = validate_kernel(kernel, true)?;

    let sweep = CmoBatchRange {
        period: period_range,
    };

    let rows = expand_grid(&sweep).len();
    let cols = slice_in.len();
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "cmo_batch: size overflow for rows={} cols={}",
            rows, cols
        ))
    })?;

    // The array starts uninitialized. `cmo_batch_inner_into` is called with its
    // init flag set and is presumably responsible for writing every element,
    // warm-up prefixes included (NOTE(review): confirm in its definition).
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was created just above; no other view of it exists yet.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k if k.is_batch() => k,
                // Same policy as the Rust batch entry point: a non-batch kernel is
                // reported as an error rather than hitting a panic.
                k => return Err(CmoError::InvalidKernelForBatch(k)),
            };
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            cmo_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1272
/// Python binding: runs the CMO period sweep on a CUDA device.
///
/// Returns a device-resident `f32` array handle (which keeps the CUDA context
/// alive) and a dict whose `periods` entry lists the period of each row.
///
/// # Errors
/// Raises `ValueError` in these cases:
/// * CUDA is unavailable;
/// * the device cannot be initialised;
/// * the kernel launch fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cmo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn cmo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = CmoBatchRange {
        period: period_range,
    };

    let (inner, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaCmo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let arr = cuda
            .cmo_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev))
    })?;

    let meta = PyDict::new(py);
    let mut periods: Vec<u64> = Vec::new();
    for combo in expand_grid(&sweep).iter() {
        periods.push(combo.period.unwrap_or(14) as u64);
    }
    meta.set_item("periods", periods.into_pyarray(py))?;

    let handle = DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev),
    };
    Ok((handle, meta))
}
1318
/// Python binding: computes CMO with one `period` for many series on a CUDA device.
///
/// The input is a contiguous 2-D float32 array. `shape[0]` is passed as the
/// series length and `shape[1]` as the series count (time-major layout).
/// Returns a device-resident array handle that keeps the CUDA context alive.
///
/// # Errors
/// Raises `ValueError` when any of these occur:
/// * CUDA is unavailable;
/// * the array is not contiguous;
/// * device setup or the kernel fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cmo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn cmo_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (rows, cols) = (shape[0], shape[1]);
    let params = CmoParams {
        period: Some(period),
    };

    let (inner, ctx, dev) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaCmo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let arr = cuda
            .cmo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev))
    })?;

    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev),
    })
}
1353
#[cfg(test)]
mod tests {
    // Unit, regression and property tests for CMO. Per-kernel `#[test]` wrappers
    // are generated by the macros below; each wrapper unwraps the check's `Result`
    // so that an `Err` actually fails the test.
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// `cmo_into` must match the allocating API bit-for-bit (or within 1e-12),
    /// including the NaN prefix produced by leading NaN inputs.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_cmo_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = vec![f64::NAN; 3];
        data.extend((0..256).map(|i| {
            let x = i as f64;
            (x * 0.07).sin() * 5.0 + x * 0.1
        }));

        let input = CmoInput::from_slice(&data, CmoParams::default());

        let baseline = cmo_with_kernel(&input, Kernel::Auto)?.values;

        let mut out = vec![0.0; data.len()];
        cmo_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "mismatch at {}: baseline={} out={}",
                i,
                baseline[i],
                out[i]
            );
        }

        Ok(())
    }

    /// Omitted and explicit periods, and different candle sources, all yield full-length output.
    fn check_cmo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = CmoParams { period: None };
        let input = CmoInput::from_candles(&candles, "close", default_params);
        let output = cmo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        let params_10 = CmoParams { period: Some(10) };
        let input_10 = CmoInput::from_candles(&candles, "hl2", params_10);
        let output_10 = cmo_with_kernel(&input_10, kernel)?;
        assert_eq!(output_10.values.len(), candles.close.len());
        Ok(())
    }

    /// Last five period-14 values on the reference dataset match known results.
    fn check_cmo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = CmoParams { period: Some(14) };
        let input = CmoInput::from_candles(&candles, "close", params);
        let cmo_result = cmo_with_kernel(&input, kernel)?;
        let expected_last_five = [
            -13.152504931406101,
            -14.649876201213106,
            -16.760170709240303,
            -14.274505732779227,
            -21.984038127126716,
        ];
        let start_idx = cmo_result.values.len() - 5;
        let last_five = &cmo_result.values[start_idx..];
        for (i, &actual) in last_five.iter().enumerate() {
            let expected = expected_last_five[i];
            assert!(
                (actual - expected).abs() < 1e-6,
                "[{}] CMO mismatch at final 5 index {}: expected {}, got {}",
                test_name,
                i,
                expected,
                actual
            );
        }
        Ok(())
    }

    /// `with_default_candles` selects the `close` source.
    fn check_cmo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CmoInput::with_default_candles(&candles);
        match input.data {
            CmoData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected CmoData::Candles variant"),
        }
        let output = cmo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A zero period is rejected.
    fn check_cmo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0];
        let params = CmoParams { period: Some(0) };
        let input = CmoInput::from_slice(&data, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for period=0",
            test_name
        );
        Ok(())
    }

    /// A period longer than the data is rejected.
    fn check_cmo_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0];
        let params = CmoParams { period: Some(10) };
        let input = CmoInput::from_slice(&data, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for period>data.len()",
            test_name
        );
        Ok(())
    }

    /// A single sample is not enough data.
    fn check_cmo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single = [42.0];
        let params = CmoParams { period: Some(14) };
        let input = CmoInput::from_slice(&single, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeding CMO output back in (with its NaN prefix) produces valid values after both warm-ups.
    fn check_cmo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = CmoParams { period: Some(14) };
        let first_input = CmoInput::from_candles(&candles, "close", first_params);
        let first_result = cmo_with_kernel(&first_input, kernel)?;
        let second_params = CmoParams { period: Some(14) };
        let second_input = CmoInput::from_slice(&first_result.values, second_params);
        let second_result = cmo_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] Expected no NaN after index 28, found NaN at {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    /// Debug builds: no output slot may still hold one of the allocator poison patterns.
    #[cfg(debug_assertions)]
    fn check_cmo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_periods = vec![7, 14, 21, 28];

        for period in test_periods {
            let params = CmoParams {
                period: Some(period),
            };
            let input = CmoInput::from_candles(&candles, "close", params);
            let output = cmo_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_cmo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates `<check>_{scalar,avx2,avx512}_f64` tests. The SIMD variants carry
    // their own `#[cfg]`: an attribute placed before a `$(...)*` repetition would
    // only attach to the first generated item.
    macro_rules! generate_all_cmo_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    /// Property test: warm-up NaNs, the [-100, 100] range, agreement with the scalar
    /// kernel, and sign behaviour on flat or monotone inputs.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cmo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite and non-zero", |x| x.is_finite() && x.abs() > 1e-10),
                    (period + 1).max(2)..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = CmoParams {
                    period: Some(period),
                };
                let input = CmoInput::from_slice(&data, params);

                let CmoOutput { values: out } = cmo_with_kernel(&input, kernel).unwrap();
                let CmoOutput { values: ref_out } =
                    cmo_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup = first_valid + period;

                for i in 0..warmup.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                if warmup < out.len() {
                    prop_assert!(
                        !out[warmup].is_nan(),
                        "Expected valid value at index {} (first after warmup), got NaN",
                        warmup
                    );
                }

                for i in warmup..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_nan() || (y >= -100.0 - 1e-9 && y <= 100.0 + 1e-9),
                        "CMO value {} at index {} outside bounds [-100, 100]",
                        y,
                        i
                    );

                    if data[..=i].iter().all(|x| x.is_finite()) {
                        prop_assert!(
                            y.is_finite(),
                            "Expected finite output at index {}, got {}",
                            i,
                            y
                        );
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y_bits,
                            r_bits,
                            "Finite/NaN mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 8,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && warmup < data.len() {
                    let cmo_val = out[warmup];
                    prop_assert!(
                        cmo_val.abs() <= 1e-9,
                        "Constant data should produce CMO of 0, got {} at index {}",
                        cmo_val,
                        warmup
                    );
                }

                let is_increasing = data.windows(2).all(|w| w[1] >= w[0] - 1e-10);
                if is_increasing && warmup < data.len() {
                    for i in warmup..data.len() {
                        prop_assert!(
                            out[i].is_nan() || out[i] >= -1e-6,
                            "Monotonically increasing data should produce non-negative CMO, got {} at index {}",
                            out[i],
                            i
                        );
                    }
                }

                let is_decreasing = data.windows(2).all(|w| w[1] <= w[0] + 1e-10);
                if is_decreasing && warmup < data.len() {
                    for i in warmup..data.len() {
                        prop_assert!(
                            out[i].is_nan() || out[i] <= 1e-6,
                            "Monotonically decreasing data should produce non-positive CMO, got {} at index {}",
                            out[i],
                            i
                        );
                    }
                }

                if period > 1 && warmup + 5 < data.len() {
                    let has_strong_gains = (warmup..data.len().min(warmup + 10))
                        .zip(warmup.saturating_sub(1)..data.len().saturating_sub(1).min(warmup + 9))
                        .all(|(i, j)| data[i] > data[j] * 1.1);

                    if has_strong_gains {
                        let last_idx = data.len() - 1;
                        prop_assert!(
                            out[last_idx].is_nan() || out[last_idx] >= 50.0,
                            "Strong gains should produce CMO > 50, got {} at index {}",
                            out[last_idx],
                            last_idx
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_cmo_tests!(
        check_cmo_partial_params,
        check_cmo_accuracy,
        check_cmo_default_candles,
        check_cmo_zero_period,
        check_cmo_period_exceeds_length,
        check_cmo_very_small_dataset,
        check_cmo_reinput,
        check_cmo_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_cmo_tests!(check_cmo_property);

    /// The batch builder's default sweep contains the default-params row at full length.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = CmoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = CmoParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds: no batch matrix cell may still hold an allocator poison pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CmoBatchBuilder::new()
            .kernel(kernel)
            .period_range(7, 28, 7)
            .apply_candles(&c, "close")?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            let row = idx / output.cols;
            let col = idx % output.cols;

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates `<check>_{scalar,avx2,avx512,auto_detect}` batch tests; each one
    // unwraps the check's result so errors are not silently discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1856
/// WASM binding: computes CMO over `data` and returns a freshly allocated vector.
///
/// Errors from the indicator are returned to JS as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_js(data: &[f64], period: Option<usize>) -> Result<Vec<f64>, JsValue> {
    let input = CmoInput::from_slice(data, CmoParams { period });
    let mut values = vec![0.0; data.len()];
    match cmo_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1869
/// WASM binding: computes CMO from `len` f64 values at `in_ptr` into `len`
/// f64 slots at `out_ptr` (e.g. buffers obtained from `cmo_alloc`).
///
/// If the input and output ranges overlap at all, including the in-place case
/// `in_ptr == out_ptr`, the result is first computed into a temporary buffer and
/// then copied. A shared slice and a mutable slice therefore never cover the same
/// memory at the same time. The previous equality-only check missed partially
/// overlapping buffers.
///
/// # Errors
/// Returns a string `JsValue` for null pointers or any CMO computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: Option<usize>,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Half-open ranges [p, p + len) overlap iff each begins before the other ends.
    let out_const = out_ptr as *const f64;
    let overlaps = in_ptr < out_const.wrapping_add(len) && out_const < in_ptr.wrapping_add(len);

    unsafe {
        // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = CmoParams { period };
        let input = CmoInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            cmo_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // SAFETY: `data`/`input` are not used past this point, so writing
            // through `out_ptr` cannot be observed through the shared slice.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            // SAFETY: the ranges are disjoint, so this mutable slice aliases nothing.
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            cmo_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1901
/// Reserves room for `len` `f64` values and hands the raw pointer to JS.
///
/// The memory is uninitialized and intentionally leaked so it outlives this call. Release it
/// with `cmo_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, exactly like `mem::forget`.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1910
/// Releases a buffer previously returned by `cmo_alloc` with the same `len`.
///
/// Null pointers are ignored.
///
/// The buffer is rebuilt with length 0 and capacity `len`. This frees exactly the same
/// allocation, but it no longer asserts that the contents were ever initialized:
/// `Vec::from_raw_parts` requires the first `length` elements to be initialized, and a buffer
/// from `cmo_alloc` may never have been written. `f64` has no destructor, so dropping zero
/// elements instead of `len` changes nothing else.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller must pass a pointer obtained from `cmo_alloc(len)` that has not already
    // been freed; capacity `len` then matches the original allocation.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1920
/// Batch-sweep configuration deserialized from JS by `cmo_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CmoBatchConfig {
    /// Period sweep as `(start, end, step)`; forwarded as `CmoBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1926
/// Result of `cmo_batch_js`, serialized back to JS.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CmoBatchJsOutput {
    /// Flattened result matrix copied from the batch output.
    // NOTE(review): presumably row-major with `rows * cols` entries (one row per combo) —
    // confirm against cmo_batch_with_kernel.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<CmoParams>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of columns in `values`.
    pub cols: usize,
}
1935
/// JS entry point (exported as `cmo_batch`): runs a CMO period sweep over `data`.
///
/// `config` must deserialize into `CmoBatchConfig`; otherwise an "Invalid config" error is
/// returned. On success the batch result is serialized as a `CmoBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cmo_batch)]
pub fn cmo_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: CmoBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CmoBatchRange {
        period: parsed.period_range,
    };
    let result = cmo_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|err| JsValue::from_str(&err.to_string()))?;

    let payload = CmoBatchJsOutput {
        values: result.values,
        combos: result.combos,
        rows: result.rows,
        cols: result.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|err| JsValue::from_str(&err.to_string()))
}
1960
/// Runs a CMO period sweep over `len` inputs at `in_ptr` and writes the results to `out_ptr`.
///
/// `out_ptr` must have room for `rows * len` `f64`s, where `rows` is the number of period
/// combinations produced by `(period_start, period_end, period_step)`. Returns `rows`.
///
/// # Errors
/// Returns a `JsValue` string in three cases:
/// - either pointer is null;
/// - `rows * len` overflows `usize`;
/// - `cmo_batch_inner_into` fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cmo_batch_into"));
    }

    let sweep = CmoBatchRange {
        period: (period_start, period_end, period_step),
    };

    // The row count is needed up front to size the output view.
    let rows = expand_grid(&sweep).len();
    let cols = len;
    // usize is 32-bit on wasm32 and release builds wrap on overflow. A wrapped product would
    // describe an output slice smaller than the result matrix, so reject it explicitly.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow in cmo_batch_into"))?;

    unsafe {
        // SAFETY: the caller must guarantee `in_ptr` addresses `len` readable f64s and that
        // `out_ptr` addresses `total` writable f64s that do not overlap the input.
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        cmo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}