Skip to main content

vector_ta/indicators/
cfo.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyUntypedArrayMethods};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9#[cfg(feature = "python")]
10use pyo3::PyErr;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use std::mem::MaybeUninit;
33use thiserror::Error;
34
35#[derive(Debug, Clone)]
36pub enum CfoData<'a> {
37    Candles {
38        candles: &'a Candles,
39        source: &'a str,
40    },
41    Slice(&'a [f64]),
42}
43
44impl<'a> AsRef<[f64]> for CfoInput<'a> {
45    #[inline(always)]
46    fn as_ref(&self) -> &[f64] {
47        match &self.data {
48            CfoData::Slice(slice) => slice,
49            CfoData::Candles { candles, source } => source_type(candles, source),
50        }
51    }
52}
53
/// Result of a single-series CFO run: one value per input sample, `NaN`
/// where no value could be produced (warm-up prefix, zero or non-finite price).
#[derive(Clone, Debug)]
pub struct CfoOutput {
    pub values: Vec<f64>,
}
58
/// Tunable parameters of the Chande Forecast Oscillator.
///
/// Unset fields fall back to `period = 14` and `scalar = 100.0` when read
/// through `CfoInput::get_period` / `CfoInput::get_scalar`.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CfoParams {
    /// Length of the linear-regression window.
    pub period: Option<usize>,
    /// Multiplier applied to the relative forecast error `(v - f) / v`.
    pub scalar: Option<f64>,
}

impl Default for CfoParams {
    /// `period = Some(14)`, `scalar = Some(100.0)`.
    fn default() -> Self {
        let period = Some(14);
        let scalar = Some(100.0);
        CfoParams { period, scalar }
    }
}
77
78#[derive(Debug, Clone)]
79pub struct CfoInput<'a> {
80    pub data: CfoData<'a>,
81    pub params: CfoParams,
82}
83
84impl<'a> CfoInput<'a> {
85    #[inline]
86    pub fn from_candles(c: &'a Candles, s: &'a str, p: CfoParams) -> Self {
87        Self {
88            data: CfoData::Candles {
89                candles: c,
90                source: s,
91            },
92            params: p,
93        }
94    }
95    #[inline]
96    pub fn from_slice(sl: &'a [f64], p: CfoParams) -> Self {
97        Self {
98            data: CfoData::Slice(sl),
99            params: p,
100        }
101    }
102    #[inline]
103    pub fn with_default_candles(c: &'a Candles) -> Self {
104        Self::from_candles(c, "close", CfoParams::default())
105    }
106    #[inline]
107    pub fn get_period(&self) -> usize {
108        self.params.period.unwrap_or(14)
109    }
110    #[inline]
111    pub fn get_scalar(&self) -> f64 {
112        self.params.scalar.unwrap_or(100.0)
113    }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub struct CfoBuilder {
118    period: Option<usize>,
119    scalar: Option<f64>,
120    kernel: Kernel,
121}
122
123impl Default for CfoBuilder {
124    fn default() -> Self {
125        Self {
126            period: None,
127            scalar: None,
128            kernel: Kernel::Auto,
129        }
130    }
131}
132
133impl CfoBuilder {
134    #[inline(always)]
135    pub fn new() -> Self {
136        Self::default()
137    }
138    #[inline(always)]
139    pub fn period(mut self, n: usize) -> Self {
140        self.period = Some(n);
141        self
142    }
143    #[inline(always)]
144    pub fn scalar(mut self, x: f64) -> Self {
145        self.scalar = Some(x);
146        self
147    }
148    #[inline(always)]
149    pub fn kernel(mut self, k: Kernel) -> Self {
150        self.kernel = k;
151        self
152    }
153    #[inline(always)]
154    pub fn apply(self, c: &Candles) -> Result<CfoOutput, CfoError> {
155        let p = CfoParams {
156            period: self.period,
157            scalar: self.scalar,
158        };
159        let i = CfoInput::from_candles(c, "close", p);
160        cfo_with_kernel(&i, self.kernel)
161    }
162    #[inline(always)]
163    pub fn apply_slice(self, d: &[f64]) -> Result<CfoOutput, CfoError> {
164        let p = CfoParams {
165            period: self.period,
166            scalar: self.scalar,
167        };
168        let i = CfoInput::from_slice(d, p);
169        cfo_with_kernel(&i, self.kernel)
170    }
171    #[inline(always)]
172    pub fn into_stream(self) -> Result<CfoStream, CfoError> {
173        let p = CfoParams {
174            period: self.period,
175            scalar: self.scalar,
176        };
177        CfoStream::try_new(p)
178    }
179}
180
181#[derive(Debug, Error)]
182pub enum CfoError {
183    #[error("cfo: All values are NaN.")]
184    AllValuesNaN,
185    #[error("cfo: Invalid period: period = {period}, data length = {data_len}")]
186    InvalidPeriod { period: usize, data_len: usize },
187    #[error("cfo: Not enough valid data: needed = {needed}, valid = {valid}")]
188    NotEnoughValidData { needed: usize, valid: usize },
189    #[error("cfo: No data provided.")]
190    EmptyInputData,
191    #[error("cfo: output length mismatch: expected = {expected}, got = {got}")]
192    OutputLengthMismatch { expected: usize, got: usize },
193    #[error("cfo: invalid kernel for batch: {0:?}")]
194    InvalidKernelForBatch(Kernel),
195    #[error("cfo: Invalid range: start={start} end={end} step={step}")]
196    InvalidRange {
197        start: usize,
198        end: usize,
199        step: usize,
200    },
201}
202
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<CfoError> for JsValue {
    /// Hands the error to JavaScript as its `Display` text.
    fn from(err: CfoError) -> Self {
        let message = err.to_string();
        JsValue::from_str(&message)
    }
}
209
210#[inline]
211pub fn cfo(input: &CfoInput) -> Result<CfoOutput, CfoError> {
212    cfo_with_kernel(input, Kernel::Auto)
213}
214
215#[inline]
216pub fn cfo_into_slice(dst: &mut [f64], input: &CfoInput, kernel: Kernel) -> Result<(), CfoError> {
217    let data: &[f64] = input.as_ref();
218    let len = data.len();
219    let period = input.get_period();
220    let scalar = input.get_scalar();
221
222    if len == 0 {
223        return Err(CfoError::EmptyInputData);
224    }
225
226    if dst.len() != data.len() {
227        return Err(CfoError::OutputLengthMismatch {
228            expected: data.len(),
229            got: dst.len(),
230        });
231    }
232
233    let first = data
234        .iter()
235        .position(|x| !x.is_nan())
236        .ok_or(CfoError::AllValuesNaN)?;
237
238    if period == 0 || period > len {
239        return Err(CfoError::InvalidPeriod {
240            period,
241            data_len: len,
242        });
243    }
244    if (len - first) < period {
245        return Err(CfoError::NotEnoughValidData {
246            needed: period,
247            valid: len - first,
248        });
249    }
250
251    let chosen = match kernel {
252        Kernel::Auto => Kernel::Scalar,
253        other => other,
254    };
255
256    let warmup_end = first + period - 1;
257    let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
258    for v in &mut dst[..warmup_end] {
259        *v = qnan;
260    }
261
262    unsafe {
263        match chosen {
264            Kernel::Scalar | Kernel::ScalarBatch => cfo_scalar(data, period, scalar, first, dst),
265            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266            Kernel::Avx2 | Kernel::Avx2Batch => cfo_avx2(data, period, scalar, first, dst),
267            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
268            Kernel::Avx512 | Kernel::Avx512Batch => cfo_avx512(data, period, scalar, first, dst),
269            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
270            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
271                cfo_scalar(data, period, scalar, first, dst)
272            }
273            _ => unreachable!(),
274        }
275    }
276
277    Ok(())
278}
279
280#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
281#[inline]
282pub fn cfo_into(input: &CfoInput, out: &mut [f64]) -> Result<(), CfoError> {
283    cfo_into_slice(out, input, Kernel::Auto)
284}
285
286pub fn cfo_with_kernel(input: &CfoInput, kernel: Kernel) -> Result<CfoOutput, CfoError> {
287    let data: &[f64] = input.as_ref();
288    let len = data.len();
289    let period = input.get_period();
290    let scalar = input.get_scalar();
291
292    if len == 0 {
293        return Err(CfoError::EmptyInputData);
294    }
295
296    let first = data
297        .iter()
298        .position(|x| !x.is_nan())
299        .ok_or(CfoError::AllValuesNaN)?;
300
301    if period == 0 || period > len {
302        return Err(CfoError::InvalidPeriod {
303            period,
304            data_len: len,
305        });
306    }
307    if (len - first) < period {
308        return Err(CfoError::NotEnoughValidData {
309            needed: period,
310            valid: len - first,
311        });
312    }
313
314    let chosen = match kernel {
315        Kernel::Auto => Kernel::Scalar,
316        other => other,
317    };
318
319    let mut out = alloc_with_nan_prefix(len, first + period - 1);
320
321    unsafe {
322        match chosen {
323            Kernel::Scalar | Kernel::ScalarBatch => {
324                cfo_scalar(data, period, scalar, first, &mut out)
325            }
326            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
327            Kernel::Avx2 | Kernel::Avx2Batch => cfo_avx2(data, period, scalar, first, &mut out),
328            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
329            Kernel::Avx512 | Kernel::Avx512Batch => {
330                cfo_avx512(data, period, scalar, first, &mut out)
331            }
332            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
333            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
334                cfo_scalar(data, period, scalar, first, &mut out)
335            }
336            _ => unreachable!(),
337        }
338    }
339
340    Ok(CfoOutput { values: out })
341}
342
/// Scalar CFO kernel built on a rolling least-squares line.
///
/// For each `i >= first_valid + period - 1`, the window
/// `data[i + 1 - period ..= i]` is regressed on x = 1..=period (oldest
/// sample = 1). The line is evaluated at x = period to get the forecast `f`,
/// and the output is `out[i] = (data[i] - f) * (scalar / data[i])`, or `NaN`
/// when `data[i]` is zero or non-finite. Indices before that are not written;
/// callers pre-fill them. Σy and Σx·y are updated in O(1) per step.
///
/// # Panics
/// * `out.len() < data.len()`: previously this wrote out of bounds through
///   `get_unchecked_mut`, which is undefined behaviour in a safe function.
/// * `period == 0`.
/// * `first_valid + period - 1 > data.len()`.
#[inline]
pub fn cfo_scalar(data: &[f64], period: usize, scalar: f64, first_valid: usize, out: &mut [f64]) {
    let size = data.len();
    // One up-front bounds check: writes below are at indices `< size`, so the
    // compiler can drop the per-iteration check.
    let out = &mut out[..size];

    let n = period as f64;
    let inv_n = 1.0 / n;
    // Σx and Σx² for x = 1..=period, evaluated in f64. The integer form
    // `p * (p + 1) * (2p + 1)` overflows usize for p >= 1290 on 32-bit targets
    // (wasm32) and for p of roughly 2.1 million on 64-bit targets. While the
    // product stays below 2^53 these f64 values are exact, so results for
    // ordinary periods are unchanged.
    let sx = n * (n + 1.0) * 0.5;
    let sx2 = n * (n + 1.0) * (2.0 * n + 1.0) / 6.0;
    let inv_denom = 1.0 / (n * sx2 - sx * sx);
    let half_nm1 = 0.5 * (n - 1.0);

    let start = first_valid;
    let pre = period - 1;

    // Seed the sums with the first `period - 1` samples at weights 1..period-1.
    let mut sum_y = 0.0;
    let mut sum_xy = 0.0;
    for k in 0..pre {
        let v = data[start + k];
        let w = (k as f64) + 1.0;
        sum_y += v;
        sum_xy = v.mul_add(w, sum_xy);
    }

    for i in (start + pre)..size {
        let v = data[i];
        // The newest sample enters at weight `period`.
        sum_xy = v.mul_add(n, sum_xy);
        sum_y += v;
        let b = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
        let f = b.mul_add(half_nm1, sum_y * inv_n);
        out[i] = if v.is_finite() && v != 0.0 {
            (v - f) * (scalar / v)
        } else {
            f64::NAN
        };
        // Shift every weight down by one (the oldest sample drops to weight
        // 0), then remove the oldest sample from Σy.
        sum_xy -= sum_y;
        sum_y -= data[i - pre];
    }
}
384
/// AVX-512 entry point for the single-series CFO: `period <= 32` goes to the
/// short variant, longer periods to the long one (both currently forward to
/// `cfo_scalar`).
///
/// NOTE(review): this is a safe function calling `avx512f` target-feature
/// code without its own CPU check; presumably callers select
/// `Kernel::Avx512` only after runtime detection — confirm.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cfo_avx512(data: &[f64], period: usize, scalar: f64, first_valid: usize, out: &mut [f64]) {
    let short_window = period <= 32;
    unsafe {
        if short_window {
            cfo_avx512_short(data, period, scalar, first_valid, out)
        } else {
            cfo_avx512_long(data, period, scalar, first_valid, out)
        }
    }
}
394
/// AVX2 CFO kernel; currently forwards to `cfo_scalar`.
///
/// # Safety
/// The CPU must support AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn cfo_avx2(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
407
/// AVX-512 CFO kernel for `period <= 32`; currently forwards to `cfo_scalar`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cfo_avx512_short(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
420
/// AVX-512 CFO kernel for `period > 32`; currently forwards to `cfo_scalar`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cfo_avx512_long(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
433
/// AVX-512 variant of the scalar CFO kernel.
///
/// The first `period - 1` weighted sums are accumulated eight lanes at a time;
/// the rolling update then runs exactly as in `cfo_scalar`. `period < 2` is
/// delegated to `cfo_scalar`.
///
/// NOTE(review): nothing visible in this file calls this function; the
/// public AVX-512 entry points forward to `cfo_scalar` instead.
///
/// # Safety
/// * The CPU must support AVX-512F.
/// * `first_valid + period - 1 <= data.len()`, because warm-up loads go
///   through a raw pointer.
/// * `out.len() >= data.len()`, because results are written unchecked.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cfo_avx512_impl(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;
    if period < 2 {
        return cfo_scalar(data, period, scalar, first_valid, out);
    }

    let size = data.len();
    let n = period as f64;
    let inv_n = 1.0 / n;
    // Σx and Σx² for x = 1..=period, in f64. The integer triple product
    // overflows usize for periods of roughly 2.1 million; the f64 values are
    // exact while the product stays below 2^53.
    let sx = n * (n + 1.0) * 0.5;
    let sx2 = n * (n + 1.0) * (2.0 * n + 1.0) / 6.0;
    let inv_denom = 1.0 / (n * sx2 - sx * sx);
    let half_nm1 = 0.5 * (n - 1.0);

    let start = first_valid;
    let pre = period - 1;
    let end_init = start + pre;

    let dp = data.as_ptr();

    // Vectorised warm-up: lanes carry weights j-start+1 for j in this block.
    let mut acc_y = _mm512_set1_pd(0.0);
    let mut acc_xy = _mm512_set1_pd(0.0);
    let mut w = _mm512_setr_pd(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
    let stepw = _mm512_set1_pd(8.0);

    let mut j = start;
    while j + 8 <= end_init {
        let v = _mm512_loadu_pd(dp.add(j));
        acc_y = _mm512_add_pd(acc_y, v);
        let prod = _mm512_mul_pd(v, w);
        acc_xy = _mm512_add_pd(acc_xy, prod);
        w = _mm512_add_pd(w, stepw);
        j += 8;
    }

    // Horizontal reduction of the lane accumulators.
    let mut tmp8 = [0.0f64; 8];
    _mm512_storeu_pd(tmp8.as_mut_ptr(), acc_y);
    let mut sum_y = tmp8.iter().sum::<f64>();
    _mm512_storeu_pd(tmp8.as_mut_ptr(), acc_xy);
    let mut sum_xy = tmp8.iter().sum::<f64>();

    // Scalar tail of the warm-up.
    let mut w_scalar = 1.0 + ((j - start) as f64);
    while j < end_init {
        let v = *dp.add(j);
        sum_y += v;
        sum_xy = v.mul_add(w_scalar, sum_xy);
        w_scalar += 1.0;
        j += 1;
    }

    // Rolling update, identical to `cfo_scalar`.
    let mut i = start + pre;
    while i < size {
        let v0 = *dp.add(i);
        sum_xy = v0.mul_add(n, sum_xy);
        sum_y += v0;
        let b0 = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
        let f0 = b0.mul_add(half_nm1, sum_y * inv_n);
        *out.get_unchecked_mut(i) = if v0.is_finite() && v0 != 0.0 {
            (v0 - f0) * (scalar / v0)
        } else {
            f64::NAN
        };
        sum_xy -= sum_y;
        sum_y -= *dp.add(i - pre);
        i += 1;
    }
}
510
511#[inline(always)]
512pub fn cfo_row_scalar(
513    data: &[f64],
514    first: usize,
515    period: usize,
516    _stride: usize,
517    scalar: f64,
518    out: &mut [f64],
519) {
520    cfo_scalar(data, period, scalar, first, out)
521}
522
/// AVX2 batch row kernel; currently the same computation as `cfo_scalar`.
/// `_stride` is unused.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
535
/// AVX-512 batch row kernel: picks the short variant for `period <= 32`,
/// otherwise the long one.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => cfo_row_avx512_short(data, first, period, stride, scalar, out),
        _ => cfo_row_avx512_long(data, first, period, stride, scalar, out),
    }
}
552
/// AVX-512 batch row kernel for short periods; currently forwards to
/// `cfo_scalar`. `_stride` is unused.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
565
/// AVX-512 batch row kernel for long periods; currently forwards to
/// `cfo_scalar`. `_stride` is unused.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
578
579#[derive(Debug, Clone)]
580pub struct CfoStream {
581    period: usize,
582    scalar: f64,
583
584    buf: Vec<f64>,
585    idx: usize,
586    filled: bool,
587
588    sum_y: f64,
589    sum_xy: f64,
590
591    n: f64,
592    inv_n: f64,
593    sx: f64,
594    inv_denom: f64,
595    half_nm1: f64,
596}
597
598impl CfoStream {
599    pub fn try_new(params: CfoParams) -> Result<Self, CfoError> {
600        let period = params.period.unwrap_or(14);
601        if period == 0 {
602            return Err(CfoError::InvalidPeriod {
603                period,
604                data_len: 0,
605            });
606        }
607        let scalar = params.scalar.unwrap_or(100.0);
608        let n = period as f64;
609        let inv_n = 1.0 / n;
610        let sx = ((period * (period + 1)) / 2) as f64;
611        let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
612        let inv_denom = 1.0 / (n * sx2 - sx * sx);
613        let half_nm1 = 0.5 * (n - 1.0);
614
615        Ok(Self {
616            period,
617            scalar,
618            buf: vec![f64::NAN; period],
619            idx: 0,
620            filled: false,
621            sum_y: 0.0,
622            sum_xy: 0.0,
623            n,
624            inv_n,
625            sx,
626            inv_denom,
627            half_nm1,
628        })
629    }
630
631    #[inline(always)]
632    pub fn update(&mut self, value: f64) -> Option<f64> {
633        if !self.filled {
634            let k = (self.idx as f64) + 1.0;
635            self.sum_y += value;
636            self.sum_xy = value.mul_add(k, self.sum_xy);
637
638            self.buf[self.idx] = value;
639            self.idx += 1;
640
641            if self.idx == self.period {
642                self.idx = 0;
643                self.filled = true;
644
645                return Some(self.calc_current());
646            } else {
647                return None;
648            }
649        }
650
651        let y_old = self.buf[self.idx];
652
653        let new_sum_xy = (self.n * value) + (self.sum_xy - self.sum_y);
654        let new_sum_y = self.sum_y - y_old + value;
655
656        self.buf[self.idx] = value;
657        self.sum_xy = new_sum_xy;
658        self.sum_y = new_sum_y;
659
660        self.idx = (self.idx + 1) % self.period;
661
662        Some(self.calc_current())
663    }
664
665    #[inline(always)]
666    fn calc_current(&self) -> f64 {
667        debug_assert!(self.filled, "calc_current() called before buffer filled");
668
669        let cur = self.buf[(self.idx + self.period - 1) % self.period];
670
671        if cur.is_finite() && cur != 0.0 {
672            let b = (-self.sx).mul_add(self.sum_y, self.n * self.sum_xy) * self.inv_denom;
673            let f = b.mul_add(self.half_nm1, self.sum_y * self.inv_n);
674
675            self.scalar.mul_add(-f / cur, self.scalar)
676        } else {
677            f64::NAN
678        }
679    }
680}
681
/// Parameter sweep for batch CFO; each axis is `(start, end, step)`.
///
/// A zero step (for `scalar`, `|step| < 1e-12`) or `start == end` yields the
/// single value `start`.
#[derive(Clone, Debug)]
pub struct CfoBatchRange {
    pub period: (usize, usize, usize),
    pub scalar: (f64, f64, f64),
}

impl Default for CfoBatchRange {
    /// Period fixed at 14; scalar swept from 100.0 to 124.9 in steps of 0.1.
    fn default() -> Self {
        let period = (14, 14, 0);
        let scalar = (100.0, 124.9, 0.1);
        CfoBatchRange { period, scalar }
    }
}
696
697#[derive(Clone, Debug, Default)]
698pub struct CfoBatchBuilder {
699    range: CfoBatchRange,
700    kernel: Kernel,
701}
702
703impl CfoBatchBuilder {
704    pub fn new() -> Self {
705        Self::default()
706    }
707    pub fn kernel(mut self, k: Kernel) -> Self {
708        self.kernel = k;
709        self
710    }
711    #[inline]
712    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
713        self.range.period = (start, end, step);
714        self
715    }
716    #[inline]
717    pub fn period_static(mut self, p: usize) -> Self {
718        self.range.period = (p, p, 0);
719        self
720    }
721    #[inline]
722    pub fn scalar_range(mut self, start: f64, end: f64, step: f64) -> Self {
723        self.range.scalar = (start, end, step);
724        self
725    }
726    #[inline]
727    pub fn scalar_static(mut self, s: f64) -> Self {
728        self.range.scalar = (s, s, 0.0);
729        self
730    }
731    pub fn apply_slice(self, data: &[f64]) -> Result<CfoBatchOutput, CfoError> {
732        cfo_batch_with_kernel(data, &self.range, self.kernel)
733    }
734    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CfoBatchOutput, CfoError> {
735        CfoBatchBuilder::new().kernel(k).apply_slice(data)
736    }
737    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CfoBatchOutput, CfoError> {
738        let slice = source_type(c, src);
739        self.apply_slice(slice)
740    }
741    pub fn with_default_candles(c: &Candles) -> Result<CfoBatchOutput, CfoError> {
742        CfoBatchBuilder::new()
743            .kernel(Kernel::Auto)
744            .apply_candles(c, "close")
745    }
746}
747
748pub fn cfo_batch_with_kernel(
749    data: &[f64],
750    sweep: &CfoBatchRange,
751    k: Kernel,
752) -> Result<CfoBatchOutput, CfoError> {
753    let kernel = match k {
754        Kernel::Auto => detect_best_batch_kernel(),
755        other if other.is_batch() => other,
756        other => return Err(CfoError::InvalidKernelForBatch(other)),
757    };
758    let simd = match kernel {
759        Kernel::Avx512Batch => Kernel::Avx512,
760        Kernel::Avx2Batch => Kernel::Avx2,
761        Kernel::ScalarBatch => Kernel::Scalar,
762        _ => unreachable!(),
763    };
764    cfo_batch_par_slice(data, sweep, simd)
765}
766
767#[derive(Clone, Debug)]
768pub struct CfoBatchOutput {
769    pub values: Vec<f64>,
770    pub combos: Vec<CfoParams>,
771    pub rows: usize,
772    pub cols: usize,
773}
774
775impl CfoBatchOutput {
776    pub fn row_for_params(&self, p: &CfoParams) -> Option<usize> {
777        self.combos.iter().position(|c| {
778            c.period.unwrap_or(14) == p.period.unwrap_or(14)
779                && (c.scalar.unwrap_or(100.0) - p.scalar.unwrap_or(100.0)).abs() < 1e-12
780        })
781    }
782    pub fn values_for(&self, p: &CfoParams) -> Option<&[f64]> {
783        self.row_for_params(p).map(|row| {
784            let start = row * self.cols;
785            &self.values[start..start + self.cols]
786        })
787    }
788}
789
790#[inline(always)]
791fn expand_grid(r: &CfoBatchRange) -> Result<Vec<CfoParams>, CfoError> {
792    fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CfoError> {
793        if step == 0 || start == end {
794            return Ok(vec![start]);
795        }
796        let mut vals = Vec::new();
797        if start < end {
798            let mut cur = start;
799            while cur <= end {
800                vals.push(cur);
801                match cur.checked_add(step) {
802                    Some(next) => cur = next,
803                    None => break,
804                }
805            }
806        } else {
807            return Err(CfoError::InvalidRange { start, end, step });
808        }
809        if vals.is_empty() {
810            return Err(CfoError::InvalidRange { start, end, step });
811        }
812        Ok(vals)
813    }
814    fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, CfoError> {
815        if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
816            return Ok(vec![start]);
817        }
818        let mut vals = Vec::new();
819        let delta = if start <= end {
820            step.abs()
821        } else {
822            -step.abs()
823        };
824        if delta.is_sign_positive() {
825            let mut x = start;
826            while x <= end + 1e-12 {
827                vals.push(x);
828                x += delta;
829                if !x.is_finite() {
830                    break;
831                }
832            }
833        } else {
834            let mut x = start;
835            while x >= end - 1e-12 {
836                vals.push(x);
837                x += delta;
838                if !x.is_finite() {
839                    break;
840                }
841            }
842        }
843        if vals.is_empty() {
844            return Err(CfoError::InvalidRange {
845                start: 0,
846                end: 0,
847                step: 0,
848            });
849        }
850        Ok(vals)
851    }
852    let periods = axis_usize(r.period)?;
853    let scalars = axis_f64(r.scalar)?;
854    let combos_len = periods
855        .len()
856        .checked_mul(scalars.len())
857        .ok_or(CfoError::InvalidRange {
858            start: periods.len(),
859            end: scalars.len(),
860            step: 0,
861        })?;
862    let mut out = Vec::with_capacity(combos_len);
863    for &p in &periods {
864        for &s in &scalars {
865            out.push(CfoParams {
866                period: Some(p),
867                scalar: Some(s),
868            });
869        }
870    }
871    Ok(out)
872}
873
874#[inline(always)]
875pub fn cfo_batch_slice(
876    data: &[f64],
877    sweep: &CfoBatchRange,
878    kern: Kernel,
879) -> Result<CfoBatchOutput, CfoError> {
880    cfo_batch_inner(data, sweep, kern, false)
881}
882
883#[inline(always)]
884pub fn cfo_batch_par_slice(
885    data: &[f64],
886    sweep: &CfoBatchRange,
887    kern: Kernel,
888) -> Result<CfoBatchOutput, CfoError> {
889    cfo_batch_inner(data, sweep, kern, true)
890}
891
892#[inline(always)]
893fn cfo_batch_inner(
894    data: &[f64],
895    sweep: &CfoBatchRange,
896    kern: Kernel,
897    parallel: bool,
898) -> Result<CfoBatchOutput, CfoError> {
899    let combos = expand_grid(sweep)?;
900    if data.is_empty() {
901        return Err(CfoError::EmptyInputData);
902    }
903
904    for combo in &combos {
905        let period = combo.period.unwrap();
906        if period == 0 || period > data.len() {
907            return Err(CfoError::InvalidPeriod {
908                period,
909                data_len: data.len(),
910            });
911        }
912    }
913    let first = data
914        .iter()
915        .position(|x| !x.is_nan())
916        .ok_or(CfoError::AllValuesNaN)?;
917    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
918    if data.len() - first < max_p {
919        return Err(CfoError::NotEnoughValidData {
920            needed: max_p,
921            valid: data.len() - first,
922        });
923    }
924    let rows = combos.len();
925    let cols = data.len();
926    rows.checked_mul(cols).ok_or(CfoError::InvalidRange {
927        start: rows,
928        end: cols,
929        step: 0,
930    })?;
931
932    let mut buf_mu = make_uninit_matrix(rows, cols);
933
934    let warm: Vec<usize> = combos
935        .iter()
936        .map(|c| first + c.period.unwrap() - 1)
937        .collect();
938    init_matrix_prefixes(&mut buf_mu, cols, &warm);
939
940    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
941    let values: &mut [f64] = unsafe {
942        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
943    };
944
945    cfo_batch_prefix_scalar_rows(data, first, &combos, values, rows, cols, parallel);
946
947    let values = unsafe {
948        Vec::from_raw_parts(
949            buf_guard.as_mut_ptr() as *mut f64,
950            buf_guard.len(),
951            buf_guard.capacity(),
952        )
953    };
954
955    Ok(CfoBatchOutput {
956        values,
957        combos,
958        rows,
959        cols,
960    })
961}
962
963#[inline]
964fn cfo_batch_prefix_scalar_rows(
965    data: &[f64],
966    first: usize,
967    combos: &[CfoParams],
968    values: &mut [f64],
969    rows: usize,
970    cols: usize,
971    parallel: bool,
972) {
973    let valid_len = cols - first;
974    if valid_len == 0 || rows == 0 {
975        return;
976    }
977
978    let mut p = Vec::with_capacity(valid_len + 1);
979    let mut q = Vec::with_capacity(valid_len + 1);
980    p.push(0.0);
981    q.push(0.0);
982    let mut ps = 0.0f64;
983    let mut qs = 0.0f64;
984    for (idx, &v) in data[first..].iter().enumerate() {
985        let j = (idx as f64) + 1.0;
986        ps += v;
987        qs = v.mul_add(j, qs);
988        p.push(ps);
989        q.push(qs);
990    }
991
992    let compute_row = |row: usize, out_row: &mut [f64]| {
993        let period = combos[row].period.unwrap();
994        let scalar = combos[row].scalar.unwrap();
995        let n = period as f64;
996        let inv_n = 1.0 / n;
997        let sx = ((period * (period + 1)) / 2) as f64;
998        let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
999        let inv_denom = 1.0 / (n * sx2 - sx * sx);
1000        let half_nm1 = 0.5 * (n - 1.0);
1001
1002        if cols < first + period {
1003            return;
1004        }
1005
1006        let start_idx = first + period - 1;
1007        for di in start_idx..cols {
1008            let idx = di - first;
1009
1010            let r1 = idx + 1;
1011            let l1_minus1 = idx + 1 - period;
1012
1013            let sum_y = p[r1] - p[l1_minus1];
1014            let sum_xy = (q[r1] - q[l1_minus1]) - (l1_minus1 as f64) * sum_y;
1015
1016            let b = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
1017            let f = b.mul_add(half_nm1, sum_y * inv_n);
1018            let v = data[di];
1019            let y = if v.is_finite() && v != 0.0 {
1020                (v - f) * (scalar / v)
1021            } else {
1022                f64::NAN
1023            };
1024            out_row[di] = y;
1025        }
1026    };
1027
1028    if parallel {
1029        #[cfg(not(target_arch = "wasm32"))]
1030        {
1031            values
1032                .par_chunks_mut(cols)
1033                .enumerate()
1034                .for_each(|(row, slice)| compute_row(row, slice));
1035        }
1036        #[cfg(target_arch = "wasm32")]
1037        {
1038            for (row, slice) in values.chunks_mut(cols).enumerate() {
1039                compute_row(row, slice);
1040            }
1041        }
1042    } else {
1043        for (row, slice) in values.chunks_mut(cols).enumerate() {
1044            compute_row(row, slice);
1045        }
1046    }
1047}
1048
1049#[inline(always)]
1050pub fn cfo_batch_inner_into(
1051    data: &[f64],
1052    sweep: &CfoBatchRange,
1053    kern: Kernel,
1054    parallel: bool,
1055    output: &mut [f64],
1056) -> Result<Vec<CfoParams>, CfoError> {
1057    let combos = expand_grid(sweep)?;
1058    if data.is_empty() {
1059        return Err(CfoError::EmptyInputData);
1060    }
1061
1062    for combo in &combos {
1063        let period = combo.period.unwrap();
1064        if period == 0 || period > data.len() {
1065            return Err(CfoError::InvalidPeriod {
1066                period,
1067                data_len: data.len(),
1068            });
1069        }
1070    }
1071    let first = data
1072        .iter()
1073        .position(|x| !x.is_nan())
1074        .ok_or(CfoError::AllValuesNaN)?;
1075    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1076    if data.len() - first < max_p {
1077        return Err(CfoError::NotEnoughValidData {
1078            needed: max_p,
1079            valid: data.len() - first,
1080        });
1081    }
1082    let rows = combos.len();
1083    let cols = data.len();
1084    let total = rows.checked_mul(cols).ok_or(CfoError::InvalidRange {
1085        start: rows,
1086        end: cols,
1087        step: 0,
1088    })?;
1089    if output.len() != total {
1090        return Err(CfoError::OutputLengthMismatch {
1091            expected: total,
1092            got: output.len(),
1093        });
1094    }
1095
1096    let out_mu: &mut [core::mem::MaybeUninit<f64>] = unsafe {
1097        core::slice::from_raw_parts_mut(
1098            output.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
1099            output.len(),
1100        )
1101    };
1102    let warm: Vec<usize> = combos
1103        .iter()
1104        .map(|c| first + c.period.unwrap() - 1)
1105        .collect();
1106    init_matrix_prefixes(out_mu, cols, &warm);
1107
1108    let values: &mut [f64] =
1109        unsafe { core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };
1110    cfo_batch_prefix_scalar_rows(data, first, &combos, values, rows, cols, parallel);
1111
1112    Ok(combos)
1113}
1114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Runs CFO on candle closes with every parameter left as `None` and checks only that the
    /// output has the same length as the input.
    fn check_cfo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = CfoParams {
            period: None,
            scalar: None,
        };
        let input = CfoInput::from_candles(&candles, "close", default_params);
        let output = cfo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// Compares the last five default-parameter CFO values on the reference CSV against
    /// known-good values (absolute tolerance 1e-6).
    fn check_cfo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CfoInput::from_candles(&candles, "close", CfoParams::default());
        let result = cfo_with_kernel(&input, kernel)?;
        let expected_last_five = [
            0.5998626489475746,
            0.47578011282578453,
            0.20349744599816233,
            0.0919617952835795,
            -0.5676291145560617,
        ];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-6,
                "[{}] CFO {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }

    /// Checks that `with_default_candles` selects the `"close"` source and produces an
    /// output aligned with the input.
    fn check_cfo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CfoInput::with_default_candles(&candles);
        match input.data {
            CfoData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected CfoData::Candles"),
        }
        let output = cfo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A period of zero must be rejected with an error.
    fn check_cfo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = CfoParams {
            period: Some(0),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&input_data, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input must be rejected with an error.
    fn check_cfo_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = CfoParams {
            period: Some(10),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&data_small, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with period exceeding length",
            test_name
        );
        Ok(())
    }

    /// A single data point with period 14 must be rejected with an error.
    fn check_cfo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = CfoParams {
            period: Some(14),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&single_point, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeds CFO output back in as input and checks the second pass has no NaN from
    /// index 240 onwards.
    fn check_cfo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let first_params = CfoParams {
            period: Some(14),
            scalar: Some(100.0),
        };
        let first_input = CfoInput::from_candles(&candles, "close", first_params);
        let first_result = cfo_with_kernel(&first_input, kernel)?;

        let second_params = CfoParams {
            period: Some(14),
            scalar: Some(100.0),
        };
        let second_input = CfoInput::from_slice(&first_result.values, second_params);
        let second_result = cfo_with_kernel(&second_input, kernel)?;

        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 240..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] Expected no NaN after idx 240, found NaN at {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    /// Checks that no NaN appears in the output from index 240 onwards on the reference CSV.
    fn check_cfo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = CfoInput::from_candles(
            &candles,
            "close",
            CfoParams {
                period: Some(14),
                scalar: Some(100.0),
            },
        );
        let res = cfo_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        if res.values.len() > 240 {
            for (i, &val) in res.values[240..].iter().enumerate() {
                assert!(
                    !val.is_nan(),
                    "[{}] Found unexpected NaN at out-index {}",
                    test_name,
                    240 + i
                );
            }
        }
        Ok(())
    }

    /// Compares the batch API against `CfoStream` fed one close at a time. `None` from the
    /// stream is treated as NaN, and positions where both are NaN are skipped.
    fn check_cfo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let period = 14;
        let scalar = 100.0;

        let input = CfoInput::from_candles(
            &candles,
            "close",
            CfoParams {
                period: Some(period),
                scalar: Some(scalar),
            },
        );
        let batch_output = cfo_with_kernel(&input, kernel)?.values;

        let mut stream = CfoStream::try_new(CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        })?;

        let mut stream_values = Vec::with_capacity(candles.close.len());
        for &price in &candles.close {
            match stream.update(price) {
                Some(val) => stream_values.push(val),
                None => stream_values.push(f64::NAN),
            }
        }

        assert_eq!(batch_output.len(), stream_values.len());
        for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
            if b.is_nan() && s.is_nan() {
                continue;
            }
            let diff = (b - s).abs();
            assert!(
                diff < 1e-9,
                "[{}] CFO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
                test_name,
                i,
                b,
                s,
                diff
            );
        }
        Ok(())
    }

    /// The into-buffer API (`cfo_into`, or `cfo_into_slice` on wasm) must match `cfo`
    /// element-wise, including a leading NaN prefix.
    #[test]
    fn test_cfo_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = Vec::with_capacity(256);

        data.push(f64::NAN);
        data.push(f64::NAN);
        for i in 1..=254u32 {
            data.push(50.0 + (i as f64) * 0.25);
        }

        let input = CfoInput::from_slice(&data, CfoParams::default());

        let baseline = cfo(&input)?.values;

        let mut out = vec![0.0; data.len()];
        #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
        {
            cfo_into(&input, &mut out)?;
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            cfo_into_slice(&mut out, &input, Kernel::Auto)?;
        }

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "mismatch at {}: baseline={} out={}",
                i,
                baseline[i],
                out[i]
            );
        }

        Ok(())
    }

    /// Debug builds only: scans outputs for several parameter sets and panics if any non-NaN
    /// value carries one of the allocation-helper poison bit patterns.
    #[cfg(debug_assertions)]
    fn check_cfo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let param_sets = vec![
            CfoParams {
                period: Some(14),
                scalar: Some(100.0),
            },
            CfoParams {
                period: Some(7),
                scalar: Some(50.0),
            },
            CfoParams {
                period: Some(21),
                scalar: Some(200.0),
            },
            CfoParams {
                period: Some(30),
                scalar: Some(100.0),
            },
        ];

        for params in param_sets {
            let input = CfoInput::from_candles(&candles, "close", params.clone());
            let output = cfo_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params {:?}",
                        test_name, val, bits, i, params
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params {:?}",
                        test_name, val, bits, i, params
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params {:?}",
                        test_name, val, bits, i, params
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: the poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_cfo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Generates one `#[test]` per helper and kernel. The helper's `Result` is unwrapped so an
    // `Err` (e.g. a failed fixture load or an unexpected `CfoError`) fails the test instead of
    // being silently discarded. The `cfg` gate is repeated on every AVX function: an attribute
    // placed outside `$(...)*` is emitted once and would only gate the first generated item.
    macro_rules! generate_all_cfo_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar).unwrap();
                    }
                )*
                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2).unwrap();
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512).unwrap();
                    }
                )*
            }
        }
    }

    /// Property test (requires the `proptest` feature). Over random finite, non-zero series,
    /// it checks:
    /// - NaN output during warm-up and finite output afterwards;
    /// - agreement with the scalar kernel (within 1e-9 or 4 ULP);
    /// - CFO ≈ 0 on constant windows;
    /// - output scales linearly with `scalar`.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cfo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite and non-zero", |x| x.is_finite() && x.abs() > 1e-10),
                    period..400,
                ),
                Just(period),
                50.0f64..200.0f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, scalar)| {
                let params = CfoParams {
                    period: Some(period),
                    scalar: Some(scalar),
                };
                let input = CfoInput::from_slice(&data, params);

                let CfoOutput { values: out } = cfo_with_kernel(&input, kernel).unwrap();
                let CfoOutput { values: ref_out } =
                    cfo_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                for i in 0..(period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                if data.len() >= period {
                    let first_valid_idx = period - 1;
                    prop_assert!(
                        out[first_valid_idx].is_finite(),
                        "Expected finite value at first valid index {}, got {}",
                        first_valid_idx,
                        out[first_valid_idx]
                    );
                }

                for i in (period - 1)..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_finite(),
                        "Expected finite output at index {}, got {}",
                        i,
                        y
                    );

                    let window_start = i.saturating_sub(period - 1);
                    let window = &data[window_start..=i];
                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
                        let expected = 0.0;
                        prop_assert!(
                            (y - expected).abs() < 1e-6,
                            "For constant data at idx {}, expected CFO ≈ 0, got {}",
                            i,
                            y
                        );
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch at idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                if data.len() > period {
                    let params_double = CfoParams {
                        period: Some(period),
                        scalar: Some(scalar * 2.0),
                    };
                    let input_double = CfoInput::from_slice(&data, params_double);
                    let CfoOutput { values: out_double } =
                        cfo_with_kernel(&input_double, kernel).unwrap();

                    for i in (period - 1)..data.len().min(period + 10) {
                        if out[i].is_finite() && out_double[i].is_finite() && out[i].abs() > 1e-6 {
                            let ratio = out_double[i] / out[i];
                            prop_assert!(
                                (ratio - 2.0).abs() < 1e-9,
                                "Scale linearity failed at idx {}: ratio = {} (expected 2.0)",
                                i,
                                ratio
                            );
                        }
                    }
                }

                if period == 2 && data.len() >= 2 {
                    prop_assert!(out[0].is_nan(), "Period=2: first value should be NaN");

                    prop_assert!(
                        out[1].is_finite(),
                        "Period=2: second value should be finite, got {}",
                        out[1]
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    /// Without the `proptest` feature the property test only performs the kernel-support check.
    #[cfg(not(feature = "proptest"))]
    fn check_cfo_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }

    generate_all_cfo_tests!(
        check_cfo_partial_params,
        check_cfo_accuracy,
        check_cfo_default_candles,
        check_cfo_zero_period,
        check_cfo_period_exceeds_length,
        check_cfo_very_small_dataset,
        check_cfo_reinput,
        check_cfo_nan_handling,
        check_cfo_streaming,
        check_cfo_no_poison,
        check_cfo_property
    );

    /// The batch builder's default-parameter row must reproduce the single-series reference
    /// values for the last five points.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CfoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = CfoParams::default();
        let row = output.values_for(&def).expect("default row missing");

        assert_eq!(row.len(), c.close.len());

        let expected = [
            0.5998626489475746,
            0.47578011282578453,
            0.20349744599816233,
            0.0919617952835795,
            -0.5676291145560617,
        ];
        let start = row.len() - 5;
        for (i, &v) in row[start..].iter().enumerate() {
            assert!(
                (v - expected[i]).abs() < 1e-6,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
            );
        }
        Ok(())
    }

    // Generates batch-kernel variants of a batch test helper, unwrapping its `Result` so an
    // `Err` fails the test instead of being discarded.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]()      {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]()        {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch).unwrap();
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]()      {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch).unwrap();
                }
                #[test] fn [<$fn_name _auto_detect>]() {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto).unwrap();
                }
            }
        };
    }

    /// Debug builds only: scans a 6x4 period/scalar batch grid for allocation-helper poison
    /// bit patterns and reports the offending row/column.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CfoBatchBuilder::new()
            .kernel(kernel)
            .period_range(5, 30, 5)
            .scalar_range(50.0, 200.0, 50.0)
            .apply_candles(&c, "close")?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            let row = idx / output.cols;
            let col = idx % output.cols;

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    /// Release builds: the batch poison check is a no-op.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1745
// Python entry point `cfo(data, period=14, scalar=100.0, kernel=None)`.
//
// Validates the optional kernel name and runs CFO over the contiguous float64 input with
// the GIL released. Returns a new 1-D numpy array of the same length. Any `CfoError` is
// surfaced to Python as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cfo")]
#[pyo3(signature = (data, period=14, scalar=100.0, kernel=None))]
pub fn cfo_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    scalar: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let input_slice = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;

    let input = CfoInput::from_slice(
        input_slice,
        CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        },
    );

    // The computation touches only Rust data, so other Python threads may run meanwhile.
    let computed = py.allow_threads(|| cfo_with_kernel(&input, chosen_kernel));

    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1773
// Python-visible `CfoStream`: a thin wrapper that owns a native `CfoStream` and forwards
// incremental updates to it.
#[cfg(feature = "python")]
#[pyclass(name = "CfoStream")]
pub struct CfoStreamPy {
    stream: CfoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl CfoStreamPy {
    // Builds the underlying stream from `period` and `scalar`; construction errors are
    // raised in Python as `ValueError`.
    #[new]
    fn new(period: usize, scalar: f64) -> PyResult<Self> {
        CfoStream::try_new(CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes one value; returns whatever the native stream yields (`None` maps to Python
    // `None`).
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1793
1794    fn update(&mut self, value: f64) -> Option<f64> {
1795        self.stream.update(value)
1796    }
1797}
1798
1799#[cfg(feature = "python")]
1800#[pyfunction(name = "cfo_batch")]
1801#[pyo3(signature = (data, period_range, scalar_range, kernel=None))]
1802pub fn cfo_batch_py<'py>(
1803    py: Python<'py>,
1804    data: numpy::PyReadonlyArray1<'py, f64>,
1805    period_range: (usize, usize, usize),
1806    scalar_range: (f64, f64, f64),
1807    kernel: Option<&str>,
1808) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
1809    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1810    use pyo3::types::PyDict;
1811
1812    let slice_in = data.as_slice()?;
1813    let kern = validate_kernel(kernel, true)?;
1814
1815    let sweep = CfoBatchRange {
1816        period: period_range,
1817        scalar: scalar_range,
1818    };
1819
1820    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
1821    let rows = combos.len();
1822    let cols = slice_in.len();
1823
1824    let out_arr = unsafe { PyArray1::<f64>::new(py, [rows * cols], false) };
1825    let slice_out = unsafe { out_arr.as_slice_mut()? };
1826
1827    let combos = py
1828        .allow_threads(|| {
1829            let kernel = match kern {
1830                Kernel::Auto => detect_best_batch_kernel(),
1831                k => k,
1832            };
1833            let simd = match kernel {
1834                Kernel::Avx512Batch => Kernel::Avx512,
1835                Kernel::Avx2Batch => Kernel::Avx2,
1836                Kernel::ScalarBatch => Kernel::Scalar,
1837                _ => unreachable!(),
1838            };
1839            cfo_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
1840        })
1841        .map_err(|e| PyValueError::new_err(e.to_string()))?;
1842
1843    let dict = PyDict::new(py);
1844    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
1845    dict.set_item(
1846        "periods",
1847        combos
1848            .iter()
1849            .map(|p| p.period.unwrap() as u64)
1850            .collect::<Vec<_>>()
1851            .into_pyarray(py),
1852    )?;
1853    dict.set_item(
1854        "scalars",
1855        combos
1856            .iter()
1857            .map(|p| p.scalar.unwrap())
1858            .collect::<Vec<_>>()
1859            .into_pyarray(py),
1860    )?;
1861
1862    Ok(dict)
1863}
1864
1865#[cfg(all(feature = "python", feature = "cuda"))]
1866#[pyfunction(name = "cfo_cuda_batch_dev")]
1867#[pyo3(signature = (data_f32, period_range, scalar_range=(100.0, 100.0, 0.0), device_id=0))]
1868pub fn cfo_cuda_batch_dev_py<'py>(
1869    py: Python<'py>,
1870    data_f32: numpy::PyReadonlyArray1<'py, f32>,
1871    period_range: (usize, usize, usize),
1872    scalar_range: (f64, f64, f64),
1873    device_id: usize,
1874) -> PyResult<(DeviceArrayF32CfoPy, Bound<'py, PyDict>)> {
1875    use crate::cuda::cuda_available;
1876    if !cuda_available() {
1877        return Err(PyValueError::new_err("CUDA not available"));
1878    }
1879
1880    let slice_in = data_f32.as_slice()?;
1881    let sweep = CfoBatchRange {
1882        period: period_range,
1883        scalar: scalar_range,
1884    };
1885
1886    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
1887    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
1888        let cuda = crate::cuda::oscillators::cfo_wrapper::CudaCfo::new(device_id)
1889            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1890
1891        let ctx = cuda.context_arc();
1892        let dev = cuda.device_id();
1893        let arr = cuda
1894            .cfo_batch_dev(slice_in, &sweep)
1895            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1896        Ok::<_, PyErr>((arr, ctx, dev as i32))
1897    })?;
1898
1899    let dict = PyDict::new(py);
1900    dict.set_item(
1901        "periods",
1902        combos
1903            .iter()
1904            .map(|p| p.period.unwrap() as u64)
1905            .collect::<Vec<_>>()
1906            .into_pyarray(py),
1907    )?;
1908    dict.set_item(
1909        "scalars",
1910        combos
1911            .iter()
1912            .map(|p| p.scalar.unwrap())
1913            .collect::<Vec<_>>()
1914            .into_pyarray(py),
1915    )?;
1916    Ok((
1917        DeviceArrayF32CfoPy {
1918            inner,
1919            ctx: ctx_arc,
1920            device_id: dev_id,
1921        },
1922        dict,
1923    ))
1924}
1925
1926#[cfg(all(feature = "python", feature = "cuda"))]
1927#[pyfunction(name = "cfo_cuda_many_series_one_param_dev")]
1928#[pyo3(signature = (data_tm_f32, period, scalar=100.0, device_id=0))]
1929pub fn cfo_cuda_many_series_one_param_dev_py<'py>(
1930    py: Python<'py>,
1931    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
1932    period: usize,
1933    scalar: f64,
1934    device_id: usize,
1935) -> PyResult<DeviceArrayF32CfoPy> {
1936    use crate::cuda::cuda_available;
1937    if !cuda_available() {
1938        return Err(PyValueError::new_err("CUDA not available"));
1939    }
1940
1941    let shape = data_tm_f32.shape();
1942    if shape.len() != 2 {
1943        return Err(PyValueError::new_err("expected 2D array"));
1944    }
1945    let rows = shape[0];
1946    let cols = shape[1];
1947    let flat = data_tm_f32.as_slice()?;
1948    let params = CfoParams {
1949        period: Some(period),
1950        scalar: Some(scalar),
1951    };
1952    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
1953        let cuda = crate::cuda::oscillators::cfo_wrapper::CudaCfo::new(device_id)
1954            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1955        let ctx = cuda.context_arc();
1956        let dev = cuda.device_id();
1957        let arr = cuda
1958            .cfo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
1959            .map_err(|e| PyValueError::new_err(e.to_string()))?;
1960        Ok::<_, PyErr>((arr, ctx, dev as i32))
1961    })?;
1962    Ok(DeviceArrayF32CfoPy {
1963        inner,
1964        ctx: ctx_arc,
1965        device_id: dev_id,
1966    })
1967}
1968
/// Computes CFO over `data` for one `(period, scalar)` pair.
///
/// Returns a new vector with the same length as `data`. Errors from `cfo_into_slice`
/// are converted to JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_js(data: &[f64], period: usize, scalar: f64) -> Result<Vec<f64>, JsValue> {
    let input = CfoInput::from_slice(
        data,
        CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        },
    );

    let mut values = vec![0.0; data.len()];
    if let Err(err) = cfo_into_slice(&mut values, &input, Kernel::Auto) {
        return Err(JsValue::from_str(&err.to_string()));
    }
    Ok(values)
}
1984
/// Legacy flat-argument sweep export that returns only the flattened output values.
///
/// The sweep always runs on the scalar kernel. Prefer `cfo_batch`, which also returns
/// the parameter combinations and the output dimensions.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use cfo_batch instead")]
pub fn cfo_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<Vec<f64>, JsValue> {
    let sweep = CfoBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
    };

    match cfo_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
2006
/// Legacy export that describes the parameter combinations of a sweep.
///
/// Returns a flat list `[period0, scalar0, period1, scalar1, ...]`, one pair per
/// combination produced by `expand_grid`, in the same order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use cfo_batch instead")]
pub fn cfo_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<Vec<f64>, JsValue> {
    let sweep = CfoBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut flat = Vec::with_capacity(combos.len() * 2);
    flat.extend(
        combos
            .iter()
            .flat_map(|c| [c.period.unwrap() as f64, c.scalar.unwrap()]),
    );
    Ok(flat)
}
2033
2034#[cfg(all(feature = "python", feature = "cuda"))]
2035use crate::cuda::moving_averages::DeviceArrayF32;
2036#[cfg(all(feature = "python", feature = "cuda"))]
2037use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2038#[cfg(all(feature = "python", feature = "cuda"))]
2039use cust::context::Context as CudaContext;
2040#[cfg(all(feature = "python", feature = "cuda"))]
2041use std::sync::Arc as StdArc;
2042
/// Python-visible handle to a CFO result stored in CUDA device memory.
///
/// The buffer is exposed to Python through `__cuda_array_interface__` and the DLPack
/// protocol methods implemented for this type. `unsendable` tells PyO3 that the object
/// may only be used from the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32CfoPy {
    /// Device buffer together with its `rows` x `cols` shape.
    pub(crate) inner: DeviceArrayF32,
    /// CUDA context handle captured from the wrapper that produced `inner`. This type's
    /// `#[pymethods]` impl never reads it. It is presumably held so the context outlives
    /// the buffer; confirm before removing.
    pub(crate) ctx: StdArc<CudaContext>,
    /// Device ordinal reported by `__dlpack_device__` and passed to the DLPack export.
    pub(crate) device_id: i32,
}
2050
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32CfoPy {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Advertises a row-major 2-D `float32` array of shape `(rows, cols)`, with byte
    /// strides. The `data` entry is `(pointer, read_only)`. The pointer is reported as
    /// 0 when the array is empty. No `stream` key is emitted.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let inner = &self.inner;
        d.set_item("shape", (inner.rows, inner.cols))?;
        // Little-endian 4-byte float.
        d.set_item("typestr", "<f4")?;
        // Byte strides for a contiguous row-major layout: one row is `cols` f32 values.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // Empty arrays report a null pointer instead of the buffer's address.
        let ptr_val: usize = if inner.rows == 0 || inner.cols == 0 {
            0
        } else {
            inner.device_ptr() as usize
        };
        // `false`: the buffer is not advertised as read-only.
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// Returns `(device_type, device_id)` for DLPack.
    ///
    /// `2` is `kDLCUDA` in DLPack's `DLDeviceType` enum.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id))
    }

    /// Exports the buffer as a DLPack capsule and moves the device memory out of `self`.
    ///
    /// * `dl_device`: if it parses as an `(i32, i32)` tuple that differs from this
    ///   array's device, a `ValueError` is raised. Cross-device copies are not
    ///   implemented, even with `copy=True`. Values that do not parse as such a tuple
    ///   are ignored.
    /// * `copy`: only consulted to pick the error message on a device mismatch.
    /// * `stream`: accepted but unused; no stream synchronization happens here.
    ///   NOTE(review): confirm the producing kernels have completed before a consumer
    ///   reads the buffer.
    /// * `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    ///
    /// Once the device check passes, `self.inner` is replaced by an empty 0x0
    /// placeholder, even if the export itself then fails. Later
    /// `__cuda_array_interface__` or `__dlpack__` calls therefore see an empty array.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // `stream` is intentionally unused (see doc comment).
        let _ = stream;

        // Allocate the placeholder before swapping, so a failed allocation leaves
        // `self.inner` untouched.
        let dummy = cust::memory::DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // The buffer is moved into the exporter, which takes ownership of the allocation.
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2131
/// Configuration object accepted by the `cfo_batch` wasm export, deserialized from JS.
///
/// Each range is a `(start, end, step)` triple that is passed unchanged into
/// `CfoBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CfoBatchConfig {
    /// `(start, end, step)` for the `period` parameter.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the `scalar` parameter.
    pub scalar_range: (f64, f64, f64),
}
2138
/// Result object serialized back to JS by `cfo_batch`.
///
/// Field declaration order determines the serialized key order.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CfoBatchJsOutput {
    /// Flattened sweep output, taken from the batch result's `values`.
    /// Presumably row-major with one row per combo; confirm against `cfo_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combinations evaluated by the sweep.
    pub combos: Vec<CfoParams>,
    /// Row count reported by the batch result (presumably `combos.len()`).
    pub rows: usize,
    /// Column count reported by the batch result (presumably the input length).
    pub cols: usize,
}
2147
/// wasm entry point for a CFO parameter sweep driven by a JS config object.
///
/// `config` must deserialize into `CfoBatchConfig`. The sweep runs with the kernel
/// returned by `detect_best_kernel()`, and the result is serialized back to JS as a
/// `CfoBatchJsOutput`.
///
/// # Errors
/// Returns a JS string error:
/// * prefixed with `Invalid config:` when deserialization fails;
/// * containing the batch error message when the sweep fails;
/// * prefixed with `Serialization error:` when the result cannot be converted back.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_batch(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CfoBatchConfig {
        period_range,
        scalar_range,
    } = serde_wasm_bindgen::from_value::<CfoBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CfoBatchRange {
        period: period_range,
        scalar: scalar_range,
    };

    let batch = cfo_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = CfoBatchJsOutput {
        rows: batch.rows,
        cols: batch.cols,
        values: batch.values,
        combos: batch.combos,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2172
/// Reserves uninitialized room for `len` `f64` values and hands the pointer to JS.
///
/// Ownership passes to the caller, who must release the buffer with
/// `cfo_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut storage = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    storage.as_mut_ptr()
}
2181
/// Releases a buffer previously returned by `cfo_alloc`.
///
/// `len` must be the value that was passed to `cfo_alloc` for this pointer. A null `ptr`
/// is ignored, so JS cleanup paths can call this unconditionally.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer; treat null as "nothing to free".
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from `cfo_alloc(len)`, which leaked a `Vec<f64>` built with
    // `Vec::with_capacity(len)`, so the capacity is assumed to be `len`. The length is set
    // to 0 so no element is treated as initialized; `f64` has no destructor, so only the
    // allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2189
/// Computes CFO from a caller-owned input buffer into a caller-owned output buffer.
///
/// Both pointers must reference at least `len` `f64` values, for example buffers from
/// `cfo_alloc`. The two buffers may be identical or partially overlap. In that case the
/// result is computed into a temporary and then copied out, so input values are never
/// overwritten while they are still being read.
///
/// # Errors
/// Returns a JS string error for:
/// * null pointers;
/// * a `period` of `0` or greater than `len`;
/// * any error reported by `cfo_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    scalar: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cfo_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // Byte ranges `[start, start + span)` of both buffers. Any intersection, not only
    // identical start pointers, would alias the `&[f64]` input with the `&mut [f64]`
    // output.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps =
        in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span);

    let params = CfoParams {
        period: Some(period),
        scalar: Some(scalar),
    };

    // SAFETY: both pointers are non-null, and the caller guarantees each addresses `len`
    // f64s. The mutable output slice is either disjoint from `data`, or is created only
    // after `data`'s last use (overlapping branch).
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = CfoInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            cfo_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            cfo_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
2230
/// Computes a CFO parameter sweep into a caller-owned output buffer.
///
/// Returns the number of rows written, which is the number of `(period, scalar)`
/// combinations produced by `expand_grid` for the given ranges.
///
/// `in_ptr` must reference `len` `f64` values. `out_ptr` must have room for
/// `rows * len` `f64` values. If the two regions overlap, including the case of
/// identical pointers, the sweep is computed into a temporary buffer and copied out
/// afterwards. This keeps the input from being overwritten while it is still read.
///
/// The best available batch kernel is mapped to its single-series counterpart before
/// calling `cfo_batch_inner_into`:
/// * `Avx512Batch` becomes `Avx512`;
/// * `Avx2Batch` becomes `Avx2`;
/// * anything else becomes `Scalar`.
///
/// # Errors
/// Returns a JS string error for:
/// * null pointers;
/// * an invalid sweep;
/// * a `rows * len` size that overflows `usize`;
/// * any error from `cfo_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cfo_batch_into"));
    }

    let sweep = CfoBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    // On wasm32 `usize` is 32 bits, so large sweeps can overflow; reject instead of
    // panicking or wrapping into a too-short output slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("cfo_batch_into: rows * cols overflows usize"))?;

    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // Byte ranges of the input (`len` f64s) and output (`total` f64s). Any intersection
    // would alias the `&[f64]` input with the `&mut [f64]` output.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: both pointers are non-null. The caller guarantees `in_ptr` addresses `len`
    // f64s and `out_ptr` addresses `total` writable f64s. The mutable output slice is
    // either disjoint from `data`, or is created only after `data`'s last use
    // (overlapping branch).
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if overlaps {
            let mut temp = vec![0.0; total];
            cfo_batch_inner_into(data, &sweep, simd, false, &mut temp)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            cfo_batch_inner_into(data, &sweep, simd, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}