1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyUntypedArrayMethods};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9#[cfg(feature = "python")]
10use pyo3::PyErr;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25
26#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
27use core::arch::x86_64::*;
28#[cfg(not(target_arch = "wasm32"))]
29use rayon::prelude::*;
30use std::convert::AsRef;
31use std::error::Error;
32use std::mem::MaybeUninit;
33use thiserror::Error;
34
35#[derive(Debug, Clone)]
36pub enum CfoData<'a> {
37 Candles {
38 candles: &'a Candles,
39 source: &'a str,
40 },
41 Slice(&'a [f64]),
42}
43
44impl<'a> AsRef<[f64]> for CfoInput<'a> {
45 #[inline(always)]
46 fn as_ref(&self) -> &[f64] {
47 match &self.data {
48 CfoData::Slice(slice) => slice,
49 CfoData::Candles { candles, source } => source_type(candles, source),
50 }
51 }
52}
53
/// Result of a single-parameter CFO run.
#[derive(Debug, Clone)]
pub struct CfoOutput {
    /// One value per input sample; the leading warm-up positions hold NaN.
    pub values: Vec<f64>,
}
58
/// Tunable CFO parameters. A `None` field falls back to its default
/// (`period = 14`, `scalar = 100.0`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CfoParams {
    /// Length of the regression window.
    pub period: Option<usize>,
    /// Output multiplier; 100.0 expresses the oscillator as a percentage.
    pub scalar: Option<f64>,
}

impl Default for CfoParams {
    fn default() -> Self {
        let (period, scalar) = (14, 100.0);
        CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        }
    }
}
77
78#[derive(Debug, Clone)]
79pub struct CfoInput<'a> {
80 pub data: CfoData<'a>,
81 pub params: CfoParams,
82}
83
84impl<'a> CfoInput<'a> {
85 #[inline]
86 pub fn from_candles(c: &'a Candles, s: &'a str, p: CfoParams) -> Self {
87 Self {
88 data: CfoData::Candles {
89 candles: c,
90 source: s,
91 },
92 params: p,
93 }
94 }
95 #[inline]
96 pub fn from_slice(sl: &'a [f64], p: CfoParams) -> Self {
97 Self {
98 data: CfoData::Slice(sl),
99 params: p,
100 }
101 }
102 #[inline]
103 pub fn with_default_candles(c: &'a Candles) -> Self {
104 Self::from_candles(c, "close", CfoParams::default())
105 }
106 #[inline]
107 pub fn get_period(&self) -> usize {
108 self.params.period.unwrap_or(14)
109 }
110 #[inline]
111 pub fn get_scalar(&self) -> f64 {
112 self.params.scalar.unwrap_or(100.0)
113 }
114}
115
116#[derive(Copy, Clone, Debug)]
117pub struct CfoBuilder {
118 period: Option<usize>,
119 scalar: Option<f64>,
120 kernel: Kernel,
121}
122
123impl Default for CfoBuilder {
124 fn default() -> Self {
125 Self {
126 period: None,
127 scalar: None,
128 kernel: Kernel::Auto,
129 }
130 }
131}
132
133impl CfoBuilder {
134 #[inline(always)]
135 pub fn new() -> Self {
136 Self::default()
137 }
138 #[inline(always)]
139 pub fn period(mut self, n: usize) -> Self {
140 self.period = Some(n);
141 self
142 }
143 #[inline(always)]
144 pub fn scalar(mut self, x: f64) -> Self {
145 self.scalar = Some(x);
146 self
147 }
148 #[inline(always)]
149 pub fn kernel(mut self, k: Kernel) -> Self {
150 self.kernel = k;
151 self
152 }
153 #[inline(always)]
154 pub fn apply(self, c: &Candles) -> Result<CfoOutput, CfoError> {
155 let p = CfoParams {
156 period: self.period,
157 scalar: self.scalar,
158 };
159 let i = CfoInput::from_candles(c, "close", p);
160 cfo_with_kernel(&i, self.kernel)
161 }
162 #[inline(always)]
163 pub fn apply_slice(self, d: &[f64]) -> Result<CfoOutput, CfoError> {
164 let p = CfoParams {
165 period: self.period,
166 scalar: self.scalar,
167 };
168 let i = CfoInput::from_slice(d, p);
169 cfo_with_kernel(&i, self.kernel)
170 }
171 #[inline(always)]
172 pub fn into_stream(self) -> Result<CfoStream, CfoError> {
173 let p = CfoParams {
174 period: self.period,
175 scalar: self.scalar,
176 };
177 CfoStream::try_new(p)
178 }
179}
180
181#[derive(Debug, Error)]
182pub enum CfoError {
183 #[error("cfo: All values are NaN.")]
184 AllValuesNaN,
185 #[error("cfo: Invalid period: period = {period}, data length = {data_len}")]
186 InvalidPeriod { period: usize, data_len: usize },
187 #[error("cfo: Not enough valid data: needed = {needed}, valid = {valid}")]
188 NotEnoughValidData { needed: usize, valid: usize },
189 #[error("cfo: No data provided.")]
190 EmptyInputData,
191 #[error("cfo: output length mismatch: expected = {expected}, got = {got}")]
192 OutputLengthMismatch { expected: usize, got: usize },
193 #[error("cfo: invalid kernel for batch: {0:?}")]
194 InvalidKernelForBatch(Kernel),
195 #[error("cfo: Invalid range: start={start} end={end} step={step}")]
196 InvalidRange {
197 start: usize,
198 end: usize,
199 step: usize,
200 },
201}
202
/// Converts a CFO error into a JS string value carrying its display message.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
impl From<CfoError> for JsValue {
    fn from(err: CfoError) -> Self {
        let message = err.to_string();
        JsValue::from_str(&message)
    }
}
209
210#[inline]
211pub fn cfo(input: &CfoInput) -> Result<CfoOutput, CfoError> {
212 cfo_with_kernel(input, Kernel::Auto)
213}
214
215#[inline]
216pub fn cfo_into_slice(dst: &mut [f64], input: &CfoInput, kernel: Kernel) -> Result<(), CfoError> {
217 let data: &[f64] = input.as_ref();
218 let len = data.len();
219 let period = input.get_period();
220 let scalar = input.get_scalar();
221
222 if len == 0 {
223 return Err(CfoError::EmptyInputData);
224 }
225
226 if dst.len() != data.len() {
227 return Err(CfoError::OutputLengthMismatch {
228 expected: data.len(),
229 got: dst.len(),
230 });
231 }
232
233 let first = data
234 .iter()
235 .position(|x| !x.is_nan())
236 .ok_or(CfoError::AllValuesNaN)?;
237
238 if period == 0 || period > len {
239 return Err(CfoError::InvalidPeriod {
240 period,
241 data_len: len,
242 });
243 }
244 if (len - first) < period {
245 return Err(CfoError::NotEnoughValidData {
246 needed: period,
247 valid: len - first,
248 });
249 }
250
251 let chosen = match kernel {
252 Kernel::Auto => Kernel::Scalar,
253 other => other,
254 };
255
256 let warmup_end = first + period - 1;
257 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
258 for v in &mut dst[..warmup_end] {
259 *v = qnan;
260 }
261
262 unsafe {
263 match chosen {
264 Kernel::Scalar | Kernel::ScalarBatch => cfo_scalar(data, period, scalar, first, dst),
265 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
266 Kernel::Avx2 | Kernel::Avx2Batch => cfo_avx2(data, period, scalar, first, dst),
267 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
268 Kernel::Avx512 | Kernel::Avx512Batch => cfo_avx512(data, period, scalar, first, dst),
269 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
270 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
271 cfo_scalar(data, period, scalar, first, dst)
272 }
273 _ => unreachable!(),
274 }
275 }
276
277 Ok(())
278}
279
280#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
281#[inline]
282pub fn cfo_into(input: &CfoInput, out: &mut [f64]) -> Result<(), CfoError> {
283 cfo_into_slice(out, input, Kernel::Auto)
284}
285
286pub fn cfo_with_kernel(input: &CfoInput, kernel: Kernel) -> Result<CfoOutput, CfoError> {
287 let data: &[f64] = input.as_ref();
288 let len = data.len();
289 let period = input.get_period();
290 let scalar = input.get_scalar();
291
292 if len == 0 {
293 return Err(CfoError::EmptyInputData);
294 }
295
296 let first = data
297 .iter()
298 .position(|x| !x.is_nan())
299 .ok_or(CfoError::AllValuesNaN)?;
300
301 if period == 0 || period > len {
302 return Err(CfoError::InvalidPeriod {
303 period,
304 data_len: len,
305 });
306 }
307 if (len - first) < period {
308 return Err(CfoError::NotEnoughValidData {
309 needed: period,
310 valid: len - first,
311 });
312 }
313
314 let chosen = match kernel {
315 Kernel::Auto => Kernel::Scalar,
316 other => other,
317 };
318
319 let mut out = alloc_with_nan_prefix(len, first + period - 1);
320
321 unsafe {
322 match chosen {
323 Kernel::Scalar | Kernel::ScalarBatch => {
324 cfo_scalar(data, period, scalar, first, &mut out)
325 }
326 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
327 Kernel::Avx2 | Kernel::Avx2Batch => cfo_avx2(data, period, scalar, first, &mut out),
328 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
329 Kernel::Avx512 | Kernel::Avx512Batch => {
330 cfo_avx512(data, period, scalar, first, &mut out)
331 }
332 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
333 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
334 cfo_scalar(data, period, scalar, first, &mut out)
335 }
336 _ => unreachable!(),
337 }
338 }
339
340 Ok(CfoOutput { values: out })
341}
342
/// Scalar CFO kernel: fills `out[first_valid + period - 1 ..data.len()]` with the
/// Chande Forecast Oscillator of `data`.
///
/// For each window of `period` samples ending at `i`, a least-squares line is
/// fitted against x = 1..=period and evaluated at x = period (the forecast `f`).
/// The output is `(data[i] - f) * (scalar / data[i])`, or NaN when `data[i]` is
/// zero or non-finite. Window sums are maintained incrementally, so the cost is
/// O(len) regardless of `period`. Entries before `first_valid + period - 1` are
/// left untouched.
///
/// With `period == 1` the regression denominator is zero, so every output is NaN.
///
/// # Panics
/// Panics if `out` is shorter than `data`, if `period == 0`, or if `data` holds
/// fewer than `first_valid + period - 1` samples.
#[inline]
pub fn cfo_scalar(data: &[f64], period: usize, scalar: f64, first_valid: usize, out: &mut [f64]) {
    let size = data.len();
    // `out` is written with unchecked indexing below. This single up-front check
    // is what makes that sound for a safe, public function; without it, a short
    // `out` was undefined behaviour.
    assert!(
        out.len() >= size,
        "cfo_scalar: output buffer too short ({} < {})",
        out.len(),
        size
    );

    // Regression constants for x = 1..=period.
    let n = period as f64;
    let inv_n = 1.0 / n;
    let sx = ((period * (period + 1)) / 2) as f64;
    let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let inv_denom = 1.0 / (n * sx2 - sx * sx);
    let half_nm1 = 0.5 * (n - 1.0);

    let start = first_valid;
    let pre = period - 1;
    let mut sum_y = 0.0;
    let mut sum_xy = 0.0;
    // Prime with the first `period - 1` samples at weights 1..period-1.
    for k in 0..pre {
        let v = data[start + k];
        let w = (k as f64) + 1.0;
        sum_y += v;
        sum_xy = v.mul_add(w, sum_xy);
    }

    let mut i = start + pre;
    while i < size {
        let v = data[i];
        // The newest sample enters at weight `n`.
        sum_xy = v.mul_add(n, sum_xy);
        sum_y += v;
        let b = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
        // Forecast at x = n: mean + slope * (n - 1) / 2.
        let f = b.mul_add(half_nm1, sum_y * inv_n);
        // SAFETY: `i < size <= out.len()`, guaranteed by the assertion above.
        unsafe {
            *out.get_unchecked_mut(i) = if v.is_finite() && v != 0.0 {
                (v - f) * (scalar / v)
            } else {
                f64::NAN
            };
        }
        // Lower every weight by one (the oldest sample reaches 0), then drop it.
        sum_xy -= sum_y;
        sum_y -= data[i - pre];
        i += 1;
    }
}
384
/// AVX-512 entry point: windows of up to 32 samples go to the "short" variant,
/// longer ones to the "long" variant (both currently forward to `cfo_scalar`).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cfo_avx512(data: &[f64], period: usize, scalar: f64, first_valid: usize, out: &mut [f64]) {
    // SAFETY(review): this safe fn does not verify `avx512f` support itself; it
    // relies on callers dispatching here only on capable CPUs.
    unsafe {
        match period {
            0..=32 => cfo_avx512_short(data, period, scalar, first_valid, out),
            _ => cfo_avx512_long(data, period, scalar, first_valid, out),
        }
    }
}
394
/// AVX2 entry point. There is no AVX2-specific code yet; it forwards to
/// `cfo_scalar`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[inline]
pub unsafe fn cfo_avx2(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
407
/// AVX-512 variant for short windows (`period <= 32`); currently forwards to
/// `cfo_scalar`.
///
/// # Safety
/// The caller must ensure the CPU supports `avx512f`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cfo_avx512_short(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
420
/// AVX-512 variant for long windows (`period > 32`); currently forwards to
/// `cfo_scalar`.
///
/// # Safety
/// The caller must ensure the CPU supports `avx512f`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
pub unsafe fn cfo_avx512_long(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    cfo_scalar(data, period, scalar, first_valid, out)
}
433
/// Candidate AVX-512 CFO kernel. It vectorizes only the warm-up accumulation of
/// the first `period - 1` samples (8 lanes at a time), then runs the same
/// rolling recurrence as `cfo_scalar` one sample at a time.
///
/// NOTE(review): nothing in this section calls this function —
/// `cfo_avx512_short`/`cfo_avx512_long` forward to `cfo_scalar` instead — so it
/// looks like dead or staged code; confirm before relying on it.
///
/// The warm-up sums are reduced lane-wise, so results can differ from
/// `cfo_scalar` in the last bits.
///
/// # Safety
/// The caller must ensure the CPU supports `avx512f`, that
/// `first_valid + period - 1 <= data.len()` (warm-up reads go through raw
/// pointers), and that `out.len() >= data.len()` (outputs are written with
/// unchecked indexing).
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[inline]
unsafe fn cfo_avx512_impl(
    data: &[f64],
    period: usize,
    scalar: f64,
    first_valid: usize,
    out: &mut [f64],
) {
    use core::arch::x86_64::*;
    // Too short to vectorize anything; defer to the scalar path.
    if period < 2 {
        return cfo_scalar(data, period, scalar, first_valid, out);
    }

    let size = data.len();
    // Regression constants for x = 1..=period (same formulas as `cfo_scalar`).
    let n = period as f64;
    let inv_n = 1.0 / n;
    let sx = ((period * (period + 1)) / 2) as f64;
    let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let inv_denom = 1.0 / (n * sx2 - sx * sx);
    let half_nm1 = 0.5 * (n - 1.0);

    let start = first_valid;
    let pre = period - 1;
    // Exclusive end of the warm-up samples.
    let end_init = start + pre;

    let dp = data.as_ptr();

    // Lane-wise partial sums of y and w * y; `w` holds the weights of the
    // current 8-sample chunk and advances by 8 per chunk.
    let mut acc_y = _mm512_set1_pd(0.0);
    let mut acc_xy = _mm512_set1_pd(0.0);
    let mut w = _mm512_setr_pd(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
    let stepw = _mm512_set1_pd(8.0);

    let mut j = start;
    while j + 8 <= end_init {
        let v = _mm512_loadu_pd(dp.add(j));
        acc_y = _mm512_add_pd(acc_y, v);
        let prod = _mm512_mul_pd(v, w);
        acc_xy = _mm512_add_pd(acc_xy, prod);
        w = _mm512_add_pd(w, stepw);
        j += 8;
    }

    // Horizontal reduction of the 8 lanes.
    let mut tmp8 = [0.0f64; 8];
    _mm512_storeu_pd(tmp8.as_mut_ptr(), acc_y);
    let mut sum_y = tmp8.iter().sum::<f64>();
    _mm512_storeu_pd(tmp8.as_mut_ptr(), acc_xy);
    let mut sum_xy = tmp8.iter().sum::<f64>();

    // Scalar tail of the warm-up, continuing the weight sequence.
    let mut w_scalar = 1.0 + ((j - start) as f64);
    while j < end_init {
        let v = *dp.add(j);
        sum_y += v;
        sum_xy = v.mul_add(w_scalar, sum_xy);
        w_scalar += 1.0;
        j += 1;
    }

    // Rolling phase: identical recurrence to `cfo_scalar`.
    let mut i = start + pre;
    while i < size {
        let v0 = *dp.add(i);
        sum_xy = v0.mul_add(n, sum_xy);
        sum_y += v0;
        let b0 = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
        let f0 = b0.mul_add(half_nm1, sum_y * inv_n);
        *out.get_unchecked_mut(i) = if v0.is_finite() && v0 != 0.0 {
            (v0 - f0) * (scalar / v0)
        } else {
            f64::NAN
        };
        // Lower all weights by one, then drop the oldest sample.
        sum_xy -= sum_y;
        sum_y -= *dp.add(i - pre);
        i += 1;
    }
}
510
511#[inline(always)]
512pub fn cfo_row_scalar(
513 data: &[f64],
514 first: usize,
515 period: usize,
516 _stride: usize,
517 scalar: f64,
518 out: &mut [f64],
519) {
520 cfo_scalar(data, period, scalar, first, out)
521}
522
/// Row-oriented AVX2 kernel; `_stride` is ignored and the call currently
/// forwards to `cfo_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx2(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
535
/// Row-oriented AVX-512 dispatcher: windows of up to 32 samples use the short
/// variant, longer ones the long variant.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    match period {
        0..=32 => cfo_row_avx512_short(data, first, period, stride, scalar, out),
        _ => cfo_row_avx512_long(data, first, period, stride, scalar, out),
    }
}
552
/// Row-oriented AVX-512 kernel for short windows; `_stride` is ignored and the
/// call currently forwards to `cfo_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
565
/// Row-oriented AVX-512 kernel for long windows; `_stride` is ignored and the
/// call currently forwards to `cfo_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn cfo_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    _stride: usize,
    scalar: f64,
    out: &mut [f64],
) {
    let first_valid = first;
    cfo_scalar(data, period, scalar, first_valid, out)
}
578
579#[derive(Debug, Clone)]
580pub struct CfoStream {
581 period: usize,
582 scalar: f64,
583
584 buf: Vec<f64>,
585 idx: usize,
586 filled: bool,
587
588 sum_y: f64,
589 sum_xy: f64,
590
591 n: f64,
592 inv_n: f64,
593 sx: f64,
594 inv_denom: f64,
595 half_nm1: f64,
596}
597
598impl CfoStream {
599 pub fn try_new(params: CfoParams) -> Result<Self, CfoError> {
600 let period = params.period.unwrap_or(14);
601 if period == 0 {
602 return Err(CfoError::InvalidPeriod {
603 period,
604 data_len: 0,
605 });
606 }
607 let scalar = params.scalar.unwrap_or(100.0);
608 let n = period as f64;
609 let inv_n = 1.0 / n;
610 let sx = ((period * (period + 1)) / 2) as f64;
611 let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
612 let inv_denom = 1.0 / (n * sx2 - sx * sx);
613 let half_nm1 = 0.5 * (n - 1.0);
614
615 Ok(Self {
616 period,
617 scalar,
618 buf: vec![f64::NAN; period],
619 idx: 0,
620 filled: false,
621 sum_y: 0.0,
622 sum_xy: 0.0,
623 n,
624 inv_n,
625 sx,
626 inv_denom,
627 half_nm1,
628 })
629 }
630
631 #[inline(always)]
632 pub fn update(&mut self, value: f64) -> Option<f64> {
633 if !self.filled {
634 let k = (self.idx as f64) + 1.0;
635 self.sum_y += value;
636 self.sum_xy = value.mul_add(k, self.sum_xy);
637
638 self.buf[self.idx] = value;
639 self.idx += 1;
640
641 if self.idx == self.period {
642 self.idx = 0;
643 self.filled = true;
644
645 return Some(self.calc_current());
646 } else {
647 return None;
648 }
649 }
650
651 let y_old = self.buf[self.idx];
652
653 let new_sum_xy = (self.n * value) + (self.sum_xy - self.sum_y);
654 let new_sum_y = self.sum_y - y_old + value;
655
656 self.buf[self.idx] = value;
657 self.sum_xy = new_sum_xy;
658 self.sum_y = new_sum_y;
659
660 self.idx = (self.idx + 1) % self.period;
661
662 Some(self.calc_current())
663 }
664
665 #[inline(always)]
666 fn calc_current(&self) -> f64 {
667 debug_assert!(self.filled, "calc_current() called before buffer filled");
668
669 let cur = self.buf[(self.idx + self.period - 1) % self.period];
670
671 if cur.is_finite() && cur != 0.0 {
672 let b = (-self.sx).mul_add(self.sum_y, self.n * self.sum_xy) * self.inv_denom;
673 let f = b.mul_add(self.half_nm1, self.sum_y * self.inv_n);
674
675 self.scalar.mul_add(-f / cur, self.scalar)
676 } else {
677 f64::NAN
678 }
679 }
680}
681
/// Inclusive `(start, end, step)` sweeps over the CFO parameters for batch runs.
///
/// A zero step (or `start == end`) selects just `start`.
#[derive(Clone, Debug)]
pub struct CfoBatchRange {
    pub period: (usize, usize, usize),
    pub scalar: (f64, f64, f64),
}

impl Default for CfoBatchRange {
    /// Fixes the period at 14 and sweeps the scalar from 100.0 to 124.9 in 0.1 steps.
    fn default() -> Self {
        let period = (14, 14, 0);
        let scalar = (100.0, 124.9, 0.1);
        Self { period, scalar }
    }
}
696
697#[derive(Clone, Debug, Default)]
698pub struct CfoBatchBuilder {
699 range: CfoBatchRange,
700 kernel: Kernel,
701}
702
703impl CfoBatchBuilder {
704 pub fn new() -> Self {
705 Self::default()
706 }
707 pub fn kernel(mut self, k: Kernel) -> Self {
708 self.kernel = k;
709 self
710 }
711 #[inline]
712 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
713 self.range.period = (start, end, step);
714 self
715 }
716 #[inline]
717 pub fn period_static(mut self, p: usize) -> Self {
718 self.range.period = (p, p, 0);
719 self
720 }
721 #[inline]
722 pub fn scalar_range(mut self, start: f64, end: f64, step: f64) -> Self {
723 self.range.scalar = (start, end, step);
724 self
725 }
726 #[inline]
727 pub fn scalar_static(mut self, s: f64) -> Self {
728 self.range.scalar = (s, s, 0.0);
729 self
730 }
731 pub fn apply_slice(self, data: &[f64]) -> Result<CfoBatchOutput, CfoError> {
732 cfo_batch_with_kernel(data, &self.range, self.kernel)
733 }
734 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CfoBatchOutput, CfoError> {
735 CfoBatchBuilder::new().kernel(k).apply_slice(data)
736 }
737 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CfoBatchOutput, CfoError> {
738 let slice = source_type(c, src);
739 self.apply_slice(slice)
740 }
741 pub fn with_default_candles(c: &Candles) -> Result<CfoBatchOutput, CfoError> {
742 CfoBatchBuilder::new()
743 .kernel(Kernel::Auto)
744 .apply_candles(c, "close")
745 }
746}
747
748pub fn cfo_batch_with_kernel(
749 data: &[f64],
750 sweep: &CfoBatchRange,
751 k: Kernel,
752) -> Result<CfoBatchOutput, CfoError> {
753 let kernel = match k {
754 Kernel::Auto => detect_best_batch_kernel(),
755 other if other.is_batch() => other,
756 other => return Err(CfoError::InvalidKernelForBatch(other)),
757 };
758 let simd = match kernel {
759 Kernel::Avx512Batch => Kernel::Avx512,
760 Kernel::Avx2Batch => Kernel::Avx2,
761 Kernel::ScalarBatch => Kernel::Scalar,
762 _ => unreachable!(),
763 };
764 cfo_batch_par_slice(data, sweep, simd)
765}
766
767#[derive(Clone, Debug)]
768pub struct CfoBatchOutput {
769 pub values: Vec<f64>,
770 pub combos: Vec<CfoParams>,
771 pub rows: usize,
772 pub cols: usize,
773}
774
775impl CfoBatchOutput {
776 pub fn row_for_params(&self, p: &CfoParams) -> Option<usize> {
777 self.combos.iter().position(|c| {
778 c.period.unwrap_or(14) == p.period.unwrap_or(14)
779 && (c.scalar.unwrap_or(100.0) - p.scalar.unwrap_or(100.0)).abs() < 1e-12
780 })
781 }
782 pub fn values_for(&self, p: &CfoParams) -> Option<&[f64]> {
783 self.row_for_params(p).map(|row| {
784 let start = row * self.cols;
785 &self.values[start..start + self.cols]
786 })
787 }
788}
789
790#[inline(always)]
791fn expand_grid(r: &CfoBatchRange) -> Result<Vec<CfoParams>, CfoError> {
792 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, CfoError> {
793 if step == 0 || start == end {
794 return Ok(vec![start]);
795 }
796 let mut vals = Vec::new();
797 if start < end {
798 let mut cur = start;
799 while cur <= end {
800 vals.push(cur);
801 match cur.checked_add(step) {
802 Some(next) => cur = next,
803 None => break,
804 }
805 }
806 } else {
807 return Err(CfoError::InvalidRange { start, end, step });
808 }
809 if vals.is_empty() {
810 return Err(CfoError::InvalidRange { start, end, step });
811 }
812 Ok(vals)
813 }
814 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, CfoError> {
815 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
816 return Ok(vec![start]);
817 }
818 let mut vals = Vec::new();
819 let delta = if start <= end {
820 step.abs()
821 } else {
822 -step.abs()
823 };
824 if delta.is_sign_positive() {
825 let mut x = start;
826 while x <= end + 1e-12 {
827 vals.push(x);
828 x += delta;
829 if !x.is_finite() {
830 break;
831 }
832 }
833 } else {
834 let mut x = start;
835 while x >= end - 1e-12 {
836 vals.push(x);
837 x += delta;
838 if !x.is_finite() {
839 break;
840 }
841 }
842 }
843 if vals.is_empty() {
844 return Err(CfoError::InvalidRange {
845 start: 0,
846 end: 0,
847 step: 0,
848 });
849 }
850 Ok(vals)
851 }
852 let periods = axis_usize(r.period)?;
853 let scalars = axis_f64(r.scalar)?;
854 let combos_len = periods
855 .len()
856 .checked_mul(scalars.len())
857 .ok_or(CfoError::InvalidRange {
858 start: periods.len(),
859 end: scalars.len(),
860 step: 0,
861 })?;
862 let mut out = Vec::with_capacity(combos_len);
863 for &p in &periods {
864 for &s in &scalars {
865 out.push(CfoParams {
866 period: Some(p),
867 scalar: Some(s),
868 });
869 }
870 }
871 Ok(out)
872}
873
874#[inline(always)]
875pub fn cfo_batch_slice(
876 data: &[f64],
877 sweep: &CfoBatchRange,
878 kern: Kernel,
879) -> Result<CfoBatchOutput, CfoError> {
880 cfo_batch_inner(data, sweep, kern, false)
881}
882
883#[inline(always)]
884pub fn cfo_batch_par_slice(
885 data: &[f64],
886 sweep: &CfoBatchRange,
887 kern: Kernel,
888) -> Result<CfoBatchOutput, CfoError> {
889 cfo_batch_inner(data, sweep, kern, true)
890}
891
/// Shared implementation of `cfo_batch_slice` / `cfo_batch_par_slice`.
///
/// Validates the sweep against `data` and allocates a `rows x cols` matrix
/// (row-major, one row per parameter combination). It initializes each row's
/// warm-up prefix (`first + period - 1` entries) via `init_matrix_prefixes`
/// (presumably NaN — confirm), then fills the rest with
/// `cfo_batch_prefix_scalar_rows`.
///
/// NOTE(review): `kern` is currently unused; every row is computed by the
/// prefix-sum scalar path regardless of the requested kernel.
///
/// # Errors
/// `InvalidRange` (bad sweep, or `rows * cols` overflow), `EmptyInputData`,
/// `InvalidPeriod`, `AllValuesNaN`, and `NotEnoughValidData` (checked against
/// the largest period).
#[inline(always)]
fn cfo_batch_inner(
    data: &[f64],
    sweep: &CfoBatchRange,
    kern: Kernel,
    parallel: bool,
) -> Result<CfoBatchOutput, CfoError> {
    let combos = expand_grid(sweep)?;
    if data.is_empty() {
        return Err(CfoError::EmptyInputData);
    }

    for combo in &combos {
        let period = combo.period.unwrap();
        if period == 0 || period > data.len() {
            return Err(CfoError::InvalidPeriod {
                period,
                data_len: data.len(),
            });
        }
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(CfoError::AllValuesNaN)?;
    // The longest window determines whether every row has enough valid data.
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if data.len() - first < max_p {
        return Err(CfoError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }
    let rows = combos.len();
    let cols = data.len();
    // Overflow guard only; the product itself is not needed here.
    rows.checked_mul(cols).ok_or(CfoError::InvalidRange {
        start: rows,
        end: cols,
        step: 0,
    })?;

    let mut buf_mu = make_uninit_matrix(rows, cols);

    // Per-row warm-up length: index of that row's first computed output.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(&mut buf_mu, cols, &warm);

    // Keep the allocation alive without dropping it, so it can be reinterpreted
    // as a Vec<f64> below.
    let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
    // SAFETY(review): assumes every element is written before being read —
    // prefixes by `init_matrix_prefixes`, the remainder of each row by
    // `cfo_batch_prefix_scalar_rows` (it writes `first + period - 1 .. cols`).
    let values: &mut [f64] = unsafe {
        core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
    };

    cfo_batch_prefix_scalar_rows(data, first, &combos, values, rows, cols, parallel);

    // Transfer ownership of the (now fully initialized) buffer into a Vec<f64>;
    // MaybeUninit<f64> has the same layout as f64.
    let values = unsafe {
        Vec::from_raw_parts(
            buf_guard.as_mut_ptr() as *mut f64,
            buf_guard.len(),
            buf_guard.capacity(),
        )
    };

    Ok(CfoBatchOutput {
        values,
        combos,
        rows,
        cols,
    })
}
962
/// Fills every row of the row-major `rows x cols` matrix `values` (one row per
/// entry of `combos`) from index `first + period - 1` onward. It uses prefix
/// sums, so each output costs O(1) regardless of `period`.
///
/// `first` is the index of the first non-NaN sample in `data`. Positions before
/// each row's warm-up end are not written here. When `parallel` is set, rows are
/// processed with rayon, except on wasm32, which always runs sequentially.
///
/// NOTE(review): `q` accumulates `j * y_j` over the whole series, so its
/// magnitude grows roughly quadratically with series length. The window
/// differences below may therefore lose more precision than the rolling
/// recurrence in `cfo_scalar` on very long inputs — confirm the tolerance is
/// acceptable.
#[inline]
fn cfo_batch_prefix_scalar_rows(
    data: &[f64],
    first: usize,
    combos: &[CfoParams],
    values: &mut [f64],
    rows: usize,
    cols: usize,
    parallel: bool,
) {
    let valid_len = cols - first;
    if valid_len == 0 || rows == 0 {
        return;
    }

    // p[j] = sum of the first j valid samples; q[j] = sum over t = 1..=j of
    // t * y_t (positions counted 1-based from `first`). Index 0 is the empty prefix.
    let mut p = Vec::with_capacity(valid_len + 1);
    let mut q = Vec::with_capacity(valid_len + 1);
    p.push(0.0);
    q.push(0.0);
    let mut ps = 0.0f64;
    let mut qs = 0.0f64;
    for (idx, &v) in data[first..].iter().enumerate() {
        let j = (idx as f64) + 1.0;
        ps += v;
        qs = v.mul_add(j, qs);
        p.push(ps);
        q.push(qs);
    }

    let compute_row = |row: usize, out_row: &mut [f64]| {
        let period = combos[row].period.unwrap();
        let scalar = combos[row].scalar.unwrap();
        // Regression constants for x = 1..=period (same formulas as `cfo_scalar`).
        let n = period as f64;
        let inv_n = 1.0 / n;
        let sx = ((period * (period + 1)) / 2) as f64;
        let sx2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
        let inv_denom = 1.0 / (n * sx2 - sx * sx);
        let half_nm1 = 0.5 * (n - 1.0);

        // Not enough samples for a single window: leave the row as initialized.
        if cols < first + period {
            return;
        }

        let start_idx = first + period - 1;
        for di in start_idx..cols {
            let idx = di - first;

            // The window covers local 1-based positions (l1_minus1, r1].
            let r1 = idx + 1;
            let l1_minus1 = idx + 1 - period;

            let sum_y = p[r1] - p[l1_minus1];
            // Re-base the global weights so the window's oldest sample has weight 1.
            let sum_xy = (q[r1] - q[l1_minus1]) - (l1_minus1 as f64) * sum_y;

            let b = (-sx).mul_add(sum_y, n * sum_xy) * inv_denom;
            // Forecast at x = n: mean + slope * (n - 1) / 2.
            let f = b.mul_add(half_nm1, sum_y * inv_n);
            let v = data[di];
            let y = if v.is_finite() && v != 0.0 {
                (v - f) * (scalar / v)
            } else {
                f64::NAN
            };
            out_row[di] = y;
        }
    };

    if parallel {
        #[cfg(not(target_arch = "wasm32"))]
        {
            values
                .par_chunks_mut(cols)
                .enumerate()
                .for_each(|(row, slice)| compute_row(row, slice));
        }
        #[cfg(target_arch = "wasm32")]
        {
            for (row, slice) in values.chunks_mut(cols).enumerate() {
                compute_row(row, slice);
            }
        }
    } else {
        for (row, slice) in values.chunks_mut(cols).enumerate() {
            compute_row(row, slice);
        }
    }
}
1048
/// Like `cfo_batch_inner`, but writes into the caller-provided `output`, which
/// must hold exactly `rows * cols` values (row-major, one row per combination).
/// Returns the expanded parameter combinations in row order.
///
/// NOTE(review): `kern` is currently unused; rows are always computed by the
/// prefix-sum scalar path.
///
/// # Errors
/// `InvalidRange`, `EmptyInputData`, `InvalidPeriod`, `AllValuesNaN`,
/// `NotEnoughValidData` (for the largest period), and `OutputLengthMismatch`
/// when `output.len() != rows * cols`.
#[inline(always)]
pub fn cfo_batch_inner_into(
    data: &[f64],
    sweep: &CfoBatchRange,
    kern: Kernel,
    parallel: bool,
    output: &mut [f64],
) -> Result<Vec<CfoParams>, CfoError> {
    let combos = expand_grid(sweep)?;
    if data.is_empty() {
        return Err(CfoError::EmptyInputData);
    }

    for combo in &combos {
        let period = combo.period.unwrap();
        if period == 0 || period > data.len() {
            return Err(CfoError::InvalidPeriod {
                period,
                data_len: data.len(),
            });
        }
    }
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or(CfoError::AllValuesNaN)?;
    // The longest window determines whether every row has enough valid data.
    let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
    if data.len() - first < max_p {
        return Err(CfoError::NotEnoughValidData {
            needed: max_p,
            valid: data.len() - first,
        });
    }
    let rows = combos.len();
    let cols = data.len();
    let total = rows.checked_mul(cols).ok_or(CfoError::InvalidRange {
        start: rows,
        end: cols,
        step: 0,
    })?;
    if output.len() != total {
        return Err(CfoError::OutputLengthMismatch {
            expected: total,
            got: output.len(),
        });
    }

    // View the initialized f64 buffer as MaybeUninit<f64> (same layout) so the
    // shared prefix initializer can be reused.
    let out_mu: &mut [core::mem::MaybeUninit<f64>] = unsafe {
        core::slice::from_raw_parts_mut(
            output.as_mut_ptr() as *mut core::mem::MaybeUninit<f64>,
            output.len(),
        )
    };
    // Per-row warm-up length: index of that row's first computed output.
    let warm: Vec<usize> = combos
        .iter()
        .map(|c| first + c.period.unwrap() - 1)
        .collect();
    init_matrix_prefixes(out_mu, cols, &warm);

    // Back to an f64 view over the same memory for the row kernel.
    let values: &mut [f64] =
        unsafe { core::slice::from_raw_parts_mut(out_mu.as_mut_ptr() as *mut f64, out_mu.len()) };
    cfo_batch_prefix_scalar_rows(data, first, &combos, values, rows, cols, parallel);

    Ok(combos)
}
1114
1115#[cfg(test)]
1116mod tests {
1117 use super::*;
1118 use crate::skip_if_unsupported;
1119 use crate::utilities::data_loader::read_candles_from_csv;
1120
    /// Parameters left as `None` must fall back to defaults and still yield one
    /// output per close price.
    fn check_cfo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        // Presumably returns early when `kernel` is unavailable on this CPU/build.
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let default_params = CfoParams {
            period: None,
            scalar: None,
        };
        let input = CfoInput::from_candles(&candles, "close", default_params);
        let output = cfo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }
1135
    /// Regression check: with default parameters (period 14, scalar 100), the
    /// last five values on the 4h close series must match the reference values
    /// within 1e-6.
    fn check_cfo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CfoInput::from_candles(&candles, "close", CfoParams::default());
        let result = cfo_with_kernel(&input, kernel)?;
        let expected_last_five = [
            0.5998626489475746,
            0.47578011282578453,
            0.20349744599816233,
            0.0919617952835795,
            -0.5676291145560617,
        ];
        let start = result.values.len().saturating_sub(5);
        for (i, &val) in result.values[start..].iter().enumerate() {
            let diff = (val - expected_last_five[i]).abs();
            assert!(
                diff < 1e-6,
                "[{}] CFO {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected_last_five[i]
            );
        }
        Ok(())
    }
1164
    /// `with_default_candles` must select the `"close"` column and produce one
    /// output per candle.
    fn check_cfo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CfoInput::with_default_candles(&candles);
        match input.data {
            CfoData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected CfoData::Candles"),
        }
        let output = cfo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }
1178
    /// A zero period must be rejected with an error rather than panicking.
    fn check_cfo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let input_data = [10.0, 20.0, 30.0];
        let params = CfoParams {
            period: Some(0),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&input_data, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with zero period",
            test_name
        );
        Ok(())
    }
1195
    /// A period longer than the input must be rejected with an error.
    fn check_cfo_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data_small = [10.0, 20.0, 30.0];
        let params = CfoParams {
            period: Some(10),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&data_small, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with period exceeding length",
            test_name
        );
        Ok(())
    }
1215
    /// A single sample cannot fill a 14-sample window, so the call must error.
    fn check_cfo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single_point = [42.0];
        let params = CfoParams {
            period: Some(14),
            scalar: Some(100.0),
        };
        let input = CfoInput::from_slice(&single_point, params);
        let res = cfo_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] CFO should fail with insufficient data",
            test_name
        );
        Ok(())
    }
1232
1233 fn check_cfo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1234 skip_if_unsupported!(kernel, test_name);
1235 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1236 let candles = read_candles_from_csv(file_path)?;
1237
1238 let first_params = CfoParams {
1239 period: Some(14),
1240 scalar: Some(100.0),
1241 };
1242 let first_input = CfoInput::from_candles(&candles, "close", first_params);
1243 let first_result = cfo_with_kernel(&first_input, kernel)?;
1244
1245 let second_params = CfoParams {
1246 period: Some(14),
1247 scalar: Some(100.0),
1248 };
1249 let second_input = CfoInput::from_slice(&first_result.values, second_params);
1250 let second_result = cfo_with_kernel(&second_input, kernel)?;
1251
1252 assert_eq!(second_result.values.len(), first_result.values.len());
1253 for i in 240..second_result.values.len() {
1254 assert!(
1255 !second_result.values[i].is_nan(),
1256 "[{}] Expected no NaN after idx 240, found NaN at {}",
1257 test_name,
1258 i
1259 );
1260 }
1261 Ok(())
1262 }
1263
1264 fn check_cfo_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1265 skip_if_unsupported!(kernel, test_name);
1266 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1267 let candles = read_candles_from_csv(file_path)?;
1268
1269 let input = CfoInput::from_candles(
1270 &candles,
1271 "close",
1272 CfoParams {
1273 period: Some(14),
1274 scalar: Some(100.0),
1275 },
1276 );
1277 let res = cfo_with_kernel(&input, kernel)?;
1278 assert_eq!(res.values.len(), candles.close.len());
1279 if res.values.len() > 240 {
1280 for (i, &val) in res.values[240..].iter().enumerate() {
1281 assert!(
1282 !val.is_nan(),
1283 "[{}] Found unexpected NaN at out-index {}",
1284 test_name,
1285 240 + i
1286 );
1287 }
1288 }
1289 Ok(())
1290 }
1291
1292 fn check_cfo_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1293 skip_if_unsupported!(kernel, test_name);
1294 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1295 let candles = read_candles_from_csv(file_path)?;
1296
1297 let period = 14;
1298 let scalar = 100.0;
1299
1300 let input = CfoInput::from_candles(
1301 &candles,
1302 "close",
1303 CfoParams {
1304 period: Some(period),
1305 scalar: Some(scalar),
1306 },
1307 );
1308 let batch_output = cfo_with_kernel(&input, kernel)?.values;
1309
1310 let mut stream = CfoStream::try_new(CfoParams {
1311 period: Some(period),
1312 scalar: Some(scalar),
1313 })?;
1314
1315 let mut stream_values = Vec::with_capacity(candles.close.len());
1316 for &price in &candles.close {
1317 match stream.update(price) {
1318 Some(val) => stream_values.push(val),
1319 None => stream_values.push(f64::NAN),
1320 }
1321 }
1322
1323 assert_eq!(batch_output.len(), stream_values.len());
1324 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1325 if b.is_nan() && s.is_nan() {
1326 continue;
1327 }
1328 let diff = (b - s).abs();
1329 assert!(
1330 diff < 1e-9,
1331 "[{}] CFO streaming f64 mismatch at idx {}: batch={}, stream={}, diff={}",
1332 test_name,
1333 i,
1334 b,
1335 s,
1336 diff
1337 );
1338 }
1339 Ok(())
1340 }
1341
1342 #[test]
1343 fn test_cfo_into_matches_api() -> Result<(), Box<dyn Error>> {
1344 let mut data = Vec::with_capacity(256);
1345
1346 data.push(f64::NAN);
1347 data.push(f64::NAN);
1348 for i in 1..=254u32 {
1349 data.push(50.0 + (i as f64) * 0.25);
1350 }
1351
1352 let input = CfoInput::from_slice(&data, CfoParams::default());
1353
1354 let baseline = cfo(&input)?.values;
1355
1356 let mut out = vec![0.0; data.len()];
1357 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1358 {
1359 cfo_into(&input, &mut out)?;
1360 }
1361 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1362 {
1363 cfo_into_slice(&mut out, &input, Kernel::Auto)?;
1364 }
1365
1366 assert_eq!(baseline.len(), out.len());
1367
1368 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1369 (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
1370 }
1371
1372 for i in 0..out.len() {
1373 assert!(
1374 eq_or_both_nan(baseline[i], out[i]),
1375 "mismatch at {}: baseline={} out={}",
1376 i,
1377 baseline[i],
1378 out[i]
1379 );
1380 }
1381
1382 Ok(())
1383 }
1384
1385 #[cfg(debug_assertions)]
1386 fn check_cfo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1387 skip_if_unsupported!(kernel, test_name);
1388
1389 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1390 let candles = read_candles_from_csv(file_path)?;
1391
1392 let param_sets = vec![
1393 CfoParams {
1394 period: Some(14),
1395 scalar: Some(100.0),
1396 },
1397 CfoParams {
1398 period: Some(7),
1399 scalar: Some(50.0),
1400 },
1401 CfoParams {
1402 period: Some(21),
1403 scalar: Some(200.0),
1404 },
1405 CfoParams {
1406 period: Some(30),
1407 scalar: Some(100.0),
1408 },
1409 ];
1410
1411 for params in param_sets {
1412 let input = CfoInput::from_candles(&candles, "close", params.clone());
1413 let output = cfo_with_kernel(&input, kernel)?;
1414
1415 for (i, &val) in output.values.iter().enumerate() {
1416 if val.is_nan() {
1417 continue;
1418 }
1419
1420 let bits = val.to_bits();
1421
1422 if bits == 0x11111111_11111111 {
1423 panic!(
1424 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with params {:?}",
1425 test_name, val, bits, i, params
1426 );
1427 }
1428
1429 if bits == 0x22222222_22222222 {
1430 panic!(
1431 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with params {:?}",
1432 test_name, val, bits, i, params
1433 );
1434 }
1435
1436 if bits == 0x33333333_33333333 {
1437 panic!(
1438 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with params {:?}",
1439 test_name, val, bits, i, params
1440 );
1441 }
1442 }
1443 }
1444
1445 Ok(())
1446 }
1447
    // Release-build variant: the poison scan is only compiled under `debug_assertions`
    // (presumably because the allocation helpers only poison-fill in debug builds), so this
    // stub always succeeds and ignores its arguments.
    #[cfg(not(debug_assertions))]
    fn check_cfo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1452
1453 macro_rules! generate_all_cfo_tests {
1454 ($($test_fn:ident),*) => {
1455 paste::paste! {
1456 $(
1457 #[test]
1458 fn [<$test_fn _scalar_f64>]() {
1459 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1460 }
1461 )*
1462 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1463 $(
1464 #[test]
1465 fn [<$test_fn _avx2_f64>]() {
1466 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1467 }
1468 #[test]
1469 fn [<$test_fn _avx512_f64>]() {
1470 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1471 }
1472 )*
1473 }
1474 }
1475 }
1476
    // Property-based test (only with the `proptest` feature). For random series and random
    // period/scalar it checks the following:
    // - the warm-up prefix is NaN and later outputs are finite;
    // - `kernel` agrees with the scalar reference kernel;
    // - constant windows give ~0;
    // - output scales linearly with `scalar`;
    // - the period == 2 boundary behaves.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cfo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        // The strategy draws:
        // - period in [2, 50];
        // - series length in [period, 400);
        // - values in (-1e6, 1e6) with |x| > 1e-10 (NOTE(review): presumably to keep the
        //   CFO denominator away from zero; confirm against the formula);
        // - scalar in [50, 200).
        let strat = (2usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite and non-zero", |x| x.is_finite() && x.abs() > 1e-10),
                    period..400,
                ),
                Just(period),
                50.0f64..200.0f64,
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period, scalar)| {
                let params = CfoParams {
                    period: Some(period),
                    scalar: Some(scalar),
                };
                let input = CfoInput::from_slice(&data, params);

                // Output of the kernel under test and of the scalar reference kernel.
                let CfoOutput { values: out } = cfo_with_kernel(&input, kernel).unwrap();
                let CfoOutput { values: ref_out } =
                    cfo_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                // The first `period - 1` outputs are warm-up and must be NaN.
                for i in 0..(period - 1) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                // The first full window (index period - 1) must already yield a value.
                if data.len() >= period {
                    let first_valid_idx = period - 1;
                    prop_assert!(
                        out[first_valid_idx].is_finite(),
                        "Expected finite value at first valid index {}, got {}",
                        first_valid_idx,
                        out[first_valid_idx]
                    );
                }

                for i in (period - 1)..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_finite(),
                        "Expected finite output at index {}, got {}",
                        i,
                        y
                    );

                    // If every step inside the window ending at i is below 1e-12, the window
                    // is treated as constant and the output must be ~0.
                    let window_start = i.saturating_sub(period - 1);
                    let window = &data[window_start..=i];
                    if window.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) {
                        let expected = 0.0;
                        prop_assert!(
                            (y - expected).abs() < 1e-6,
                            "For constant data at idx {}, expected CFO ≈ 0, got {}",
                            i,
                            y
                        );
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    // `y` was asserted finite above, so this branch only triggers when the
                    // reference value is non-finite; then the bit patterns must match exactly.
                    if !y.is_finite() || !r.is_finite() {
                        prop_assert!(
                            y.to_bits() == r.to_bits(),
                            "finite/NaN mismatch at idx {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    // Kernel agreement: absolute error <= 1e-9, or a bit-pattern distance of
                    // at most 4. The bit distance equals the ULP distance only when y and r
                    // share a sign.
                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 4,
                        "Kernel mismatch at idx {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                // Linearity in `scalar`: doubling it must double the output. Only the first
                // ~11 valid indices are checked, and near-zero outputs are skipped to avoid
                // unstable ratios.
                if data.len() > period {
                    let params_double = CfoParams {
                        period: Some(period),
                        scalar: Some(scalar * 2.0),
                    };
                    let input_double = CfoInput::from_slice(&data, params_double);
                    let CfoOutput { values: out_double } =
                        cfo_with_kernel(&input_double, kernel).unwrap();

                    for i in (period - 1)..data.len().min(period + 10) {
                        if out[i].is_finite() && out_double[i].is_finite() && out[i].abs() > 1e-6 {
                            let ratio = out_double[i] / out[i];
                            prop_assert!(
                                (ratio - 2.0).abs() < 1e-9,
                                "Scale linearity failed at idx {}: ratio = {} (expected 2.0)",
                                i,
                                ratio
                            );
                        }
                    }
                }

                // Smallest period: exactly one warm-up NaN, then a value at index 1.
                if period == 2 && data.len() >= 2 {
                    prop_assert!(out[0].is_nan(), "Period=2: first value should be NaN");

                    prop_assert!(
                        out[1].is_finite(),
                        "Period=2: second value should be finite, got {}",
                        out[1]
                    );
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }
1617
    // Stand-in used when the `proptest` feature is disabled. It runs only the kernel-support
    // gate and then reports success, so the generated per-kernel tests still exist.
    #[cfg(not(feature = "proptest"))]
    fn check_cfo_property(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        Ok(())
    }
1623
    // Instantiate the per-kernel `#[test]` wrappers (see `generate_all_cfo_tests!`) for every
    // single-series CFO check function.
    generate_all_cfo_tests!(
        check_cfo_partial_params,
        check_cfo_accuracy,
        check_cfo_default_candles,
        check_cfo_zero_period,
        check_cfo_period_exceeds_length,
        check_cfo_very_small_dataset,
        check_cfo_reinput,
        check_cfo_nan_handling,
        check_cfo_streaming,
        check_cfo_no_poison,
        check_cfo_property
    );
1637
1638 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1639 skip_if_unsupported!(kernel, test);
1640 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1641 let c = read_candles_from_csv(file)?;
1642
1643 let output = CfoBatchBuilder::new()
1644 .kernel(kernel)
1645 .apply_candles(&c, "close")?;
1646
1647 let def = CfoParams::default();
1648 let row = output.values_for(&def).expect("default row missing");
1649
1650 assert_eq!(row.len(), c.close.len());
1651
1652 let expected = [
1653 0.5998626489475746,
1654 0.47578011282578453,
1655 0.20349744599816233,
1656 0.0919617952835795,
1657 -0.5676291145560617,
1658 ];
1659 let start = row.len() - 5;
1660 for (i, &v) in row[start..].iter().enumerate() {
1661 assert!(
1662 (v - expected[i]).abs() < 1e-6,
1663 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1664 );
1665 }
1666 Ok(())
1667 }
1668
1669 macro_rules! gen_batch_tests {
1670 ($fn_name:ident) => {
1671 paste::paste! {
1672 #[test] fn [<$fn_name _scalar>]() {
1673 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1674 }
1675 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1676 #[test] fn [<$fn_name _avx2>]() {
1677 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1678 }
1679 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1680 #[test] fn [<$fn_name _avx512>]() {
1681 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1682 }
1683 #[test] fn [<$fn_name _auto_detect>]() {
1684 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1685 }
1686 }
1687 };
1688 }
1689
1690 #[cfg(debug_assertions)]
1691 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1692 skip_if_unsupported!(kernel, test);
1693
1694 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1695 let c = read_candles_from_csv(file)?;
1696
1697 let output = CfoBatchBuilder::new()
1698 .kernel(kernel)
1699 .period_range(5, 30, 5)
1700 .scalar_range(50.0, 200.0, 50.0)
1701 .apply_candles(&c, "close")?;
1702
1703 for (idx, &val) in output.values.iter().enumerate() {
1704 if val.is_nan() {
1705 continue;
1706 }
1707
1708 let bits = val.to_bits();
1709 let row = idx / output.cols;
1710 let col = idx % output.cols;
1711
1712 if bits == 0x11111111_11111111 {
1713 panic!(
1714 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1715 test, val, bits, row, col, idx
1716 );
1717 }
1718
1719 if bits == 0x22222222_22222222 {
1720 panic!(
1721 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1722 test, val, bits, row, col, idx
1723 );
1724 }
1725
1726 if bits == 0x33333333_33333333 {
1727 panic!(
1728 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
1729 test, val, bits, row, col, idx
1730 );
1731 }
1732 }
1733
1734 Ok(())
1735 }
1736
    // Release-build variant: the batch poison scan is only compiled under `debug_assertions`,
    // so this stub always succeeds and ignores its arguments.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
1741
    // Instantiate the per-kernel batch `#[test]` wrappers (see `gen_batch_tests!`).
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
1744}
1745
// Python binding: computes CFO over a 1-D float64 numpy array and returns a new array of
// the same length. Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cfo")]
#[pyo3(signature = (data, period=14, scalar=100.0, kernel=None))]
pub fn cfo_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    scalar: f64,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let cfo_in = CfoInput::from_slice(
        slice_in,
        CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        },
    );

    // The computation is pure Rust, so the GIL is released while it runs.
    let computed = py.allow_threads(|| cfo_with_kernel(&cfo_in, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1773
// Python-facing wrapper around `CfoStream`, exposed to Python under the class name
// `CfoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "CfoStream")]
pub struct CfoStreamPy {
    // Incremental CFO state; the Python `update` method forwards to it.
    stream: CfoStream,
}
1779
#[cfg(feature = "python")]
#[pymethods]
impl CfoStreamPy {
    // Python constructor `CfoStream(period, scalar)`; invalid parameters raise `ValueError`.
    #[new]
    fn new(period: usize, scalar: f64) -> PyResult<Self> {
        CfoStream::try_new(CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        })
        .map(|stream| CfoStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Push one value; returns `None` (Python `None`) whenever the stream yields no value.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
1798
// Python binding for the CFO parameter sweep.
//
// Returns a dict with:
// - "values": a (rows, cols) float64 array, one row per (period, scalar) combination and
//   one column per input sample;
// - "periods": the per-row periods;
// - "scalars": the per-row scalars.
//
// Invalid ranges, an output size that overflows `usize`, or a kernel that is not a batch
// kernel raise `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "cfo_batch")]
#[pyo3(signature = (data, period_range, scalar_range, kernel=None))]
pub fn cfo_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    scalar_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;

    let sweep = CfoBatchRange {
        period: period_range,
        scalar: scalar_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    // An unchecked product could wrap in release builds and undersize the output buffer.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("cfo_batch: rows * cols overflows usize"))?;

    // Resolve `Auto` and map the batch kernel to its per-row SIMD kernel up front. That way
    // an unexpected kernel becomes a Python `ValueError` rather than a panic inside
    // `allow_threads`.
    let batch_kernel = match kern {
        Kernel::Auto => detect_best_batch_kernel(),
        k => k,
    };
    let simd = match batch_kernel {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        Kernel::ScalarBatch => Kernel::Scalar,
        _ => {
            return Err(PyValueError::new_err(
                "cfo_batch: unsupported kernel for batch operation",
            ))
        }
    };

    // SAFETY: the array is freshly allocated and uninitialised; `cfo_batch_inner_into` is
    // expected to write every one of the `total` cells (warm-up prefixes included).
    // NOTE(review): that full-coverage guarantee lives in `cfo_batch_inner_into` — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: `out_arr` was just created here, so no other view of its data exists.
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| cfo_batch_inner_into(slice_in, &sweep, simd, true, slice_out))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "scalars",
        combos
            .iter()
            .map(|p| p.scalar.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1864
// Python binding: runs the CFO parameter sweep on a CUDA device over float32 input.
//
// Returns the device-resident result together with a dict of per-row "periods" and
// "scalars". Raises `ValueError` when CUDA is unavailable or any step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cfo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, scalar_range=(100.0, 100.0, 0.0), device_id=0))]
pub fn cfo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    scalar_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<(DeviceArrayF32CfoPy, Bound<'py, PyDict>)> {
    if !crate::cuda::cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = CfoBatchRange {
        period: period_range,
        scalar: scalar_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Device work runs with the GIL released; failures are converted to `PyErr` inside.
    let (inner, ctx, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = crate::cuda::oscillators::cfo_wrapper::CudaCfo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let arr = cuda
            .cfo_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((arr, ctx, dev as i32))
    })?;

    let periods: Vec<u64> = combos.iter().map(|p| p.period.unwrap() as u64).collect();
    let scalars: Vec<f64> = combos.iter().map(|p| p.scalar.unwrap()).collect();

    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;
    meta.set_item("scalars", scalars.into_pyarray(py))?;

    let handle = DeviceArrayF32CfoPy {
        inner,
        ctx,
        device_id: dev_id,
    };
    Ok((handle, meta))
}
1925
// Python binding: runs CFO with one (period, scalar) pair over many series on a CUDA
// device.
//
// The input is a 2-D float32 array named/treated as time-major; the kernel wrapper
// receives (cols, rows) = (shape[1], shape[0]).
// NOTE(review): presumably shape[0] is time and shape[1] is the series index; confirm
// against `cfo_many_series_one_param_time_major_dev`.
//
// Raises `ValueError` when CUDA is unavailable or any step fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cfo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, scalar=100.0, device_id=0))]
pub fn cfo_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    period: usize,
    scalar: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32CfoPy> {
    use crate::cuda::cuda_available;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    // Requires a contiguous array; non-contiguous input surfaces as a Python error.
    let flat = data_tm_f32.as_slice()?;
    let params = CfoParams {
        period: Some(period),
        scalar: Some(scalar),
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = crate::cuda::oscillators::cfo_wrapper::CudaCfo::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let arr = cuda
            .cfo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev as i32))
    })?;
    Ok(DeviceArrayF32CfoPy {
        inner,
        ctx: ctx_arc,
        device_id: dev_id,
    })
}
1968
// WASM export: computes CFO over `data` and returns a new vector of the same length.
// Indicator errors are returned to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_js(data: &[f64], period: usize, scalar: f64) -> Result<Vec<f64>, JsValue> {
    let input = CfoInput::from_slice(
        data,
        CfoParams {
            period: Some(period),
            scalar: Some(scalar),
        },
    );

    let mut values = vec![0.0; data.len()];
    let status = cfo_into_slice(&mut values, &input, Kernel::Auto);
    match status {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1984
// Deprecated WASM export. Runs the parameter sweep with the scalar kernel and returns the
// flat row-major values only, without metadata.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use cfo_batch instead")]
pub fn cfo_batch_js(
    data: &[f64],
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<Vec<f64>, JsValue> {
    let period = (period_start, period_end, period_step);
    let scalar = (scalar_start, scalar_end, scalar_step);
    let sweep = CfoBatchRange { period, scalar };

    match cfo_batch_inner(data, &sweep, Kernel::Scalar, false) {
        Ok(batch) => Ok(batch.values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2006
// Deprecated WASM export. Returns the sweep's combinations flattened as
// [period0, scalar0, period1, scalar1, ...], with periods cast to f64.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(note = "Use cfo_batch instead")]
pub fn cfo_batch_metadata_js(
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<Vec<f64>, JsValue> {
    let sweep = CfoBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let flattened: Vec<f64> = combos
        .iter()
        .flat_map(|combo| [combo.period.unwrap() as f64, combo.scalar.unwrap()])
        .collect();
    Ok(flattened)
}
2033
2034#[cfg(all(feature = "python", feature = "cuda"))]
2035use crate::cuda::moving_averages::DeviceArrayF32;
2036#[cfg(all(feature = "python", feature = "cuda"))]
2037use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
2038#[cfg(all(feature = "python", feature = "cuda"))]
2039use cust::context::Context as CudaContext;
2040#[cfg(all(feature = "python", feature = "cuda"))]
2041use std::sync::Arc as StdArc;
2042
// Python handle for a CFO result held in CUDA device memory. It exposes
// `__cuda_array_interface__` and the DLPack protocol.
//
// It is declared `unsendable`, so pyo3 restricts access to the thread that created it.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32CfoPy {
    // Device buffer plus its (rows, cols) shape.
    pub(crate) inner: DeviceArrayF32,
    // CUDA context handle kept alongside the buffer.
    // NOTE(review): presumably held so the context outlives the allocation; confirm.
    pub(crate) ctx: StdArc<CudaContext>,
    // Device ordinal reported through `__dlpack_device__`.
    pub(crate) device_id: i32,
}
2050
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32CfoPy {
    // CUDA Array Interface (version 3) description of the buffer: shape (rows, cols),
    // float32 elements, C-contiguous (row-major) byte strides.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        let inner = &self.inner;
        d.set_item("shape", (inner.rows, inner.cols))?;
        // "<f4" = little-endian 32-bit float.
        d.set_item("typestr", "<f4")?;
        // Byte strides: one row spans `cols` floats; elements within a row are adjacent.
        d.set_item(
            "strides",
            (
                inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        // Empty arrays (including one already exported via `__dlpack__`) report pointer 0.
        let ptr_val: usize = if inner.rows == 0 || inner.cols == 0 {
            0
        } else {
            inner.device_ptr() as usize
        };
        // `data` is (pointer, read_only); the buffer is advertised as writable.
        d.set_item("data", (ptr_val, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple (device_type, device_id); 2 is kDLCUDA in DLPack's DLDeviceType.
    fn __dlpack_device__(&self) -> PyResult<(i32, i32)> {
        Ok((2, self.device_id))
    }

    // Exports the buffer as a DLPack capsule, transferring ownership out of this object.
    //
    // After the first call, `self.inner` is replaced by an empty 0x0 buffer, so later
    // exports/interfaces see an empty array. A `dl_device` that does not match this
    // buffer's device is rejected (cross-device copy is not implemented). `stream` is
    // accepted but ignored.
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__()?;
        // A `dl_device` value that is not an (i32, i32) tuple is silently ignored.
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // No stream synchronisation is performed here.
        let _ = stream;

        // Move the real buffer out, leaving an empty placeholder behind (single-shot export).
        let dummy = cust::memory::DeviceBuffer::<f32>::from_slice(&[])
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: dummy,
                rows: 0,
                cols: 0,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
2131
/// Configuration object deserialised from JS for `cfo_batch`.
///
/// Each range is a `(start, end, step)` tuple.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CfoBatchConfig {
    /// Period sweep as `(start, end, step)`.
    pub period_range: (usize, usize, usize),
    /// Scalar sweep as `(start, end, step)`.
    pub scalar_range: (f64, f64, f64),
}
2138
/// Result object serialised back to JS by `cfo_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CfoBatchJsOutput {
    /// Flat batch output; `rows * cols` values, presumably row-major with one row per
    /// combination (NOTE(review): confirm the layout against `cfo_batch_inner`).
    pub values: Vec<f64>,
    /// Parameter combination for each row.
    pub combos: Vec<CfoParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of input samples per row.
    pub cols: usize,
}
2147
// WASM export: runs the CFO sweep described by a JS config object
// `{ period_range, scalar_range }`. Returns `{ values, combos, rows, cols }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_batch(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed = serde_wasm_bindgen::from_value::<CfoBatchConfig>(config);
    let config = match parsed {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = CfoBatchRange {
        period: config.period_range,
        scalar: config.scalar_range,
    };

    let batch = match cfo_batch_inner(data, &sweep, detect_best_kernel(), false) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = CfoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };

    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2172
// WASM export: reserves space for `len` f64 values and hands the raw pointer to JS.
// The memory is intentionally leaked here and must be released with `cfo_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2181
// WASM export: releases a buffer obtained from `cfo_alloc(len)`. A null pointer is
// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `cfo_alloc(len)`, i.e. a leaked `Vec<f64>` with capacity
    // `len`. The length is 0 because the buffer was only reserved: its elements need not be
    // initialised, and `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. f64 has no destructor, so only the allocation itself is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2189
// WASM export: computes CFO from `len` f64 values at `in_ptr` into `len` f64 slots at
// `out_ptr`.
//
// In-place calls are supported. If the input and output byte ranges overlap at all (not
// only when the pointers are equal), the result is computed into a scratch buffer first.
// This avoids creating aliasing `&[f64]` / `&mut [f64]` slices.
//
// Errors (as JS strings): null pointers, `period == 0 || period > len`, and any
// indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    scalar: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cfo_into"));
    }

    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    // SAFETY: the JS caller guarantees `in_ptr` points to `len` initialised f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    let params = CfoParams {
        period: Some(period),
        scalar: Some(scalar),
    };
    let input = CfoInput::from_slice(data, params);

    // Half-open byte ranges [start, start + bytes) of input and output.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start < out_start.saturating_add(bytes)
        && out_start < in_start.saturating_add(bytes);

    if overlaps {
        let mut temp = vec![0.0; len];
        cfo_into_slice(&mut temp, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the JS caller guarantees `out_ptr` is valid for `len` writes. The input
        // slice is no longer read past this point, so the mutable view does not race it.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is valid for `len` writes and does not overlap the input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        cfo_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2230
// WASM export: runs the CFO sweep over `len` f64 values at `in_ptr`.
//
// Writes `rows * len` values to `out_ptr`, where `rows` is the number of
// (period, scalar) combinations, and returns `rows`.
//
// The output size is overflow-checked. If the output range overlaps the input range, the
// sweep is computed into a scratch buffer first, which avoids aliasing `&`/`&mut` slices.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cfo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    scalar_start: f64,
    scalar_end: f64,
    scalar_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cfo_batch_into"));
    }

    let sweep = CfoBatchRange {
        period: (period_start, period_end, period_step),
        scalar: (scalar_start, scalar_end, scalar_step),
    };

    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow in cfo_batch_into"))?;

    // Map the detected batch kernel to its per-row SIMD kernel (scalar otherwise).
    let simd = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // SAFETY: the JS caller guarantees `in_ptr` points to `len` initialised f64 values.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    let f64_size = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(len.saturating_mul(f64_size));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(f64_size));
    let overlaps = in_start < out_end && out_start < in_end;

    if overlaps {
        let mut temp = vec![0.0; total];
        cfo_batch_inner_into(data, &sweep, simd, false, &mut temp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `total` writes (caller contract). The input slice
        // is not read after this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&temp);
    } else {
        // SAFETY: `out_ptr` is valid for `total` writes and does not overlap the input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        cfo_batch_inner_into(data, &sweep, simd, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}