1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21impl<'a> AsRef<[f64]> for CmoInput<'a> {
22 #[inline(always)]
23 fn as_ref(&self) -> &[f64] {
24 match &self.data {
25 CmoData::Slice(slice) => slice,
26 CmoData::Candles { candles, source } => source_type(candles, source),
27 }
28 }
29}
30
/// Source of the price series a `CmoInput` reads.
#[derive(Debug, Clone)]
pub enum CmoData<'a> {
    /// A named column of a candle set, resolved with `source_type`
    /// (e.g. `"close"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice of values.
    Slice(&'a [f64]),
}

/// Output of a single-series CMO run: one value per input element.
#[derive(Debug, Clone)]
pub struct CmoOutput {
    pub values: Vec<f64>,
}

/// Parameters of the CMO indicator.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CmoParams {
    /// Lookback / smoothing period. `None` falls back to 14 when read through
    /// `CmoInput::get_period`.
    pub period: Option<usize>,
}

impl Default for CmoParams {
    // Default period is 14.
    fn default() -> Self {
        Self { period: Some(14) }
    }
}

/// Input to the CMO functions: where the data comes from plus the parameters.
#[derive(Debug, Clone)]
pub struct CmoInput<'a> {
    pub data: CmoData<'a>,
    pub params: CmoParams,
}
65
66impl<'a> CmoInput<'a> {
67 #[inline]
68 pub fn from_candles(c: &'a Candles, s: &'a str, p: CmoParams) -> Self {
69 Self {
70 data: CmoData::Candles {
71 candles: c,
72 source: s,
73 },
74 params: p,
75 }
76 }
77 #[inline]
78 pub fn from_slice(sl: &'a [f64], p: CmoParams) -> Self {
79 Self {
80 data: CmoData::Slice(sl),
81 params: p,
82 }
83 }
84 #[inline]
85 pub fn with_default_candles(c: &'a Candles) -> Self {
86 Self::from_candles(c, "close", CmoParams::default())
87 }
88 #[inline]
89 pub fn get_period(&self) -> usize {
90 self.params.period.unwrap_or(14)
91 }
92 #[inline]
93 pub fn data_len(&self) -> usize {
94 match &self.data {
95 CmoData::Slice(slice) => slice.len(),
96 CmoData::Candles { candles, .. } => candles.close.len(),
97 }
98 }
99}
100
101#[derive(Copy, Clone, Debug)]
102pub struct CmoBuilder {
103 period: Option<usize>,
104 kernel: Kernel,
105}
106
107impl Default for CmoBuilder {
108 fn default() -> Self {
109 Self {
110 period: None,
111 kernel: Kernel::Auto,
112 }
113 }
114}
115
116impl CmoBuilder {
117 #[inline(always)]
118 pub fn new() -> Self {
119 Self::default()
120 }
121 #[inline(always)]
122 pub fn period(mut self, n: usize) -> Self {
123 self.period = Some(n);
124 self
125 }
126 #[inline(always)]
127 pub fn kernel(mut self, k: Kernel) -> Self {
128 self.kernel = k;
129 self
130 }
131
132 #[inline(always)]
133 pub fn apply(self, c: &Candles) -> Result<CmoOutput, CmoError> {
134 let p = CmoParams {
135 period: self.period,
136 };
137 let i = CmoInput::from_candles(c, "close", p);
138 cmo_with_kernel(&i, self.kernel)
139 }
140
141 #[inline(always)]
142 pub fn apply_slice(self, d: &[f64]) -> Result<CmoOutput, CmoError> {
143 let p = CmoParams {
144 period: self.period,
145 };
146 let i = CmoInput::from_slice(d, p);
147 cmo_with_kernel(&i, self.kernel)
148 }
149
150 #[inline(always)]
151 pub fn into_stream(self) -> Result<CmoStream, CmoError> {
152 let p = CmoParams {
153 period: self.period,
154 };
155 CmoStream::try_new(p)
156 }
157}
158
/// Errors returned by the CMO functions in this module.
#[derive(Debug, Error)]
pub enum CmoError {
    /// The input series has no elements.
    #[error("cmo: Empty data provided.")]
    EmptyData,

    /// `period` is 0 or larger than the input length. `CmoStream::try_new`
    /// reports this with `data_len: 0`.
    #[error("cmo: Invalid period: period={period}, data_len={data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Every input value is NaN.
    #[error("cmo: All values are NaN.")]
    AllValuesNaN,

    /// Fewer than `needed` (period + 1) values remain from the first non-NaN
    /// value onward.
    #[error("cmo: Not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A batch period sweep expanded to nothing or its size computation
    /// overflowed.
    #[error("cmo: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to a batch entry point.
    #[error("cmo: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// A caller-provided output buffer does not have the expected length.
    #[error("cmo: Output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
}
186
187#[inline]
188pub fn cmo(input: &CmoInput) -> Result<CmoOutput, CmoError> {
189 cmo_with_kernel(input, Kernel::Auto)
190}
191
192#[inline(always)]
193fn cmo_prepare<'a>(
194 input: &'a CmoInput,
195 k: Kernel,
196) -> Result<(&'a [f64], usize, usize, Kernel), CmoError> {
197 let data: &[f64] = input.as_ref();
198 let len = data.len();
199 if len == 0 {
200 return Err(CmoError::EmptyData);
201 }
202 let period = input.get_period();
203 if period == 0 || period > len {
204 return Err(CmoError::InvalidPeriod {
205 period,
206 data_len: len,
207 });
208 }
209 let first = data
210 .iter()
211 .position(|x| !x.is_nan())
212 .ok_or(CmoError::AllValuesNaN)?;
213 if len - first <= period {
214 return Err(CmoError::NotEnoughValidData {
215 needed: period + 1,
216 valid: len - first,
217 });
218 }
219 let mut chosen = match k {
220 Kernel::Auto => Kernel::Scalar,
221 other => other,
222 };
223
224 if chosen.is_batch() {
225 chosen = match chosen {
226 Kernel::Avx512Batch => Kernel::Avx512,
227 Kernel::Avx2Batch => Kernel::Avx2,
228 Kernel::ScalarBatch => Kernel::Scalar,
229 _ => chosen,
230 };
231 }
232 Ok((data, period, first, chosen))
233}
234
235#[inline(always)]
236fn cmo_compute_into(data: &[f64], period: usize, first: usize, kernel: Kernel, out: &mut [f64]) {
237 unsafe {
238 match kernel {
239 Kernel::Scalar | Kernel::ScalarBatch => cmo_scalar(data, period, first, out),
240 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
241 Kernel::Avx2 | Kernel::Avx2Batch => cmo_avx2(data, period, first, out),
242 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
243 Kernel::Avx512 | Kernel::Avx512Batch => cmo_avx512(data, period, first, out),
244 _ => unreachable!(),
245 }
246 }
247}
248
249pub fn cmo_with_kernel(input: &CmoInput, kernel: Kernel) -> Result<CmoOutput, CmoError> {
250 let (data, period, first, chosen) = cmo_prepare(input, kernel)?;
251 let mut out = alloc_with_nan_prefix(data.len(), first + period);
252 cmo_compute_into(data, period, first, chosen, &mut out);
253 Ok(CmoOutput { values: out })
254}
255
256#[inline]
257pub fn cmo_into_slice(dst: &mut [f64], input: &CmoInput, kern: Kernel) -> Result<(), CmoError> {
258 let (data, period, first, chosen) = cmo_prepare(input, kern)?;
259 if dst.len() != data.len() {
260 return Err(CmoError::OutputLengthMismatch {
261 expected: data.len(),
262 got: dst.len(),
263 });
264 }
265 cmo_compute_into(data, period, first, chosen, dst);
266 let warmup_end = first + period;
267 for v in &mut dst[..warmup_end] {
268 *v = f64::NAN;
269 }
270 Ok(())
271}
272
273#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
274#[inline]
275pub fn cmo_into(input: &CmoInput, out: &mut [f64]) -> Result<(), CmoError> {
276 let (data, period, first, chosen) = cmo_prepare(input, Kernel::Auto)?;
277
278 if out.len() != data.len() {
279 return Err(CmoError::OutputLengthMismatch {
280 expected: data.len(),
281 got: out.len(),
282 });
283 }
284
285 let warmup_end = first + period;
286 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
287 let warm = warmup_end.min(out.len());
288 for v in &mut out[..warm] {
289 *v = qnan;
290 }
291
292 cmo_compute_into(data, period, first, chosen, out);
293
294 Ok(())
295}
296
/// Portable scalar CMO kernel.
///
/// Price change `i` is `data[i] - data[i - 1]`, split into an up-move and a
/// down-move. The first `period` changes after `first_valid` are averaged to
/// seed the gain/loss averages. After that each average is updated as
/// `(avg * (period - 1) + x) / period`. Each output is
/// `100 * (avg_gain - avg_loss) / (avg_gain + avg_loss)`, or `0` when both
/// averages are zero.
///
/// Only `out[first_valid + period..]` is written; earlier entries are left
/// untouched.
///
/// # Panics
/// Panics if `first_valid` is out of bounds for `data` or `out` is too short
/// for a written index. `period` must be at least 1.
#[inline]
pub fn cmo_scalar(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    let mut last = data[first_valid];
    let seed_idx = first_valid + period;

    let inv_period = 1.0 / (period as f64);
    let period_m1 = (period - 1) as f64;

    // Oscillator value for a pair of averages.
    let osc = |up: f64, down: f64| -> f64 {
        let total = up + down;
        if total != 0.0 {
            100.0 * ((up - down) / total)
        } else {
            0.0
        }
    };

    // These hold running sums until the seed index, averages afterwards.
    let mut avg_up = 0.0;
    let mut avg_down = 0.0;

    for (idx, &price) in data.iter().enumerate().skip(first_valid + 1) {
        let delta = price - last;
        last = price;

        let magnitude = delta.abs();
        let up = 0.5 * (delta + magnitude);
        let down = 0.5 * (magnitude - delta);

        if idx < seed_idx {
            avg_up += up;
            avg_down += down;
        } else if idx == seed_idx {
            avg_up += up;
            avg_down += down;
            avg_up *= inv_period;
            avg_down *= inv_period;
            out[idx] = osc(avg_up, avg_down);
        } else {
            avg_up = (avg_up * period_m1 + up) * inv_period;
            avg_down = (avg_down * period_m1 + down) * inv_period;
            out[idx] = osc(avg_up, avg_down);
        }
    }
}
348
/// Checks the preconditions the unchecked SIMD kernels rely on: a non-zero
/// period, `out` exactly as long as `data`, and a seed index
/// `first_valid + period` inside `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
fn cmo_simd_args_ok(data: &[f64], period: usize, first_valid: usize, out: &[f64]) -> bool {
    period > 0
        && out.len() == data.len()
        && first_valid
            .checked_add(period)
            .map_or(false, |seed| seed < data.len())
}

/// AVX-512 CMO kernel for one series (same output contract as `cmo_scalar`).
///
/// This function is safe to call. If the CPU does not report `avx512f`, or
/// the arguments violate the bounds the unchecked SIMD code needs, it
/// delegates to `cmo_scalar`. The scalar kernel then either computes the
/// result or panics with a normal bounds error, so it never executes
/// unsupported instructions or reads or writes out of bounds.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cmo_avx512(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    if !std::is_x86_feature_detected!("avx512f")
        || !cmo_simd_args_ok(data, period, first_valid, out)
    {
        return cmo_scalar(data, period, first_valid, out);
    }
    // SAFETY: avx512f support and the kernel's bounds preconditions were
    // verified just above.
    unsafe {
        if period <= 32 {
            cmo_avx512_short(data, period, first_valid, out)
        } else {
            cmo_avx512_long(data, period, first_valid, out)
        }
    }
}

/// AVX2 CMO kernel for one series (same output contract as `cmo_scalar`).
///
/// Delegates to `cmo_scalar` when the CPU lacks `avx2` or the arguments do
/// not satisfy the unchecked kernel's bounds, so it is sound to call from
/// safe code.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn cmo_avx2(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    if !std::is_x86_feature_detected!("avx2")
        || !cmo_simd_args_ok(data, period, first_valid, out)
    {
        return cmo_scalar(data, period, first_valid, out);
    }
    // SAFETY: avx2 support and the kernel's bounds preconditions were
    // verified just above.
    unsafe { cmo_avx2_impl(data, period, first_valid, out) }
}
366
/// AVX-512 kernel entry for periods up to 32; it currently forwards to
/// `cmo_avx512_impl`, the same implementation used for longer periods.
///
/// # Safety
/// The CPU must support `avx512f`, `out.len()` must equal `data.len()`,
/// `period` must be at least 1, and `first_valid + period` must be a valid
/// index into `data`. The underlying kernel accesses memory without bounds
/// checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cmo_avx512_short(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cmo_avx512_impl(data, period, first_valid, out)
}

/// AVX-512 kernel entry for periods above 32; it currently forwards to
/// `cmo_avx512_impl`, the same implementation used for short periods.
///
/// # Safety
/// The CPU must support `avx512f`, `out.len()` must equal `data.len()`,
/// `period` must be at least 1, and `first_valid + period` must be a valid
/// index into `data`. The underlying kernel accesses memory without bounds
/// checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn cmo_avx512_long(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    cmo_avx512_impl(data, period, first_valid, out)
}
378
/// AVX2 implementation of the CMO kernel (same output contract as
/// `cmo_scalar`: writes `out[first_valid + period..]` only).
///
/// The seed window of `period` price changes is summed four changes per
/// iteration with 256-bit vectors, and the remainder is summed scalar. The
/// smoothing pass after the seed is sequential because each value depends on
/// the previous one. The seed sum is reassociated across lanes, so results
/// can differ from `cmo_scalar` in the last bits.
///
/// # Safety
/// Requires AVX2 on the running CPU. Memory accesses are unchecked:
/// `out.len()` must equal `data.len()` (only debug-asserted), `period` must be
/// at least 1, and `first_valid + period` must be a valid index into `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn cmo_avx2_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    debug_assert!(out.len() == data.len());

    // Horizontal sum of the four f64 lanes of a 256-bit vector.
    #[inline(always)]
    unsafe fn hsum256_pd(v: __m256d) -> f64 {
        let hi = _mm256_extractf128_pd(v, 1);
        let lo = _mm256_castpd256_pd128(v);
        let sum2 = _mm_add_pd(lo, hi);
        let hi64 = _mm_unpackhi_pd(sum2, sum2);
        _mm_cvtsd_f64(_mm_add_sd(sum2, hi64))
    }

    let len = data.len();
    // Change `i` is `data[i] - data[i - 1]`; the seed covers changes
    // `start..=init_end`, and the first output lands at `init_end`.
    let start = first_valid + 1;
    let init_end = first_valid + period;

    let inv_period = 1.0 / (period as f64);
    let period_m1 = (period - 1) as f64;

    let mut acc_gain_v = _mm256_setzero_pd();
    let mut acc_loss_v = _mm256_setzero_pd();
    let half_v = _mm256_set1_pd(0.5);
    // Clearing the sign bit yields |x| in every lane.
    let abs_mask = _mm256_castsi256_pd(_mm256_set1_epi64x(0x7FFF_FFFF_FFFF_FFFFu64 as i64));

    let mut sum_gain = 0.0f64;
    let mut sum_loss = 0.0f64;

    let mut i = start;
    // Vector loop over changes i..i+3: loads data[i..i+4] and data[i-1..i+3].
    // The condition keeps the highest loaded index (i + 3) inside the seed window.
    while i + 3 <= init_end {
        let curr_v = _mm256_loadu_pd(data.as_ptr().add(i));
        let prev_v = _mm256_loadu_pd(data.as_ptr().add(i - 1));
        let diff_v = _mm256_sub_pd(curr_v, prev_v);

        // gain = (|d| + d) / 2, loss = (|d| - d) / 2
        let ad_v = _mm256_and_pd(diff_v, abs_mask);
        let gain_v = _mm256_mul_pd(_mm256_add_pd(ad_v, diff_v), half_v);
        let loss_v = _mm256_mul_pd(_mm256_sub_pd(ad_v, diff_v), half_v);

        acc_gain_v = _mm256_add_pd(acc_gain_v, gain_v);
        acc_loss_v = _mm256_add_pd(acc_loss_v, loss_v);

        i += 4;
    }

    sum_gain += hsum256_pd(acc_gain_v);
    sum_loss += hsum256_pd(acc_loss_v);

    // Price preceding change `i`. When the vector loop did not run,
    // i == start and data[first_valid] is the same element as data[i - 1].
    let mut prev = if i == start {
        *data.get_unchecked(first_valid)
    } else {
        *data.get_unchecked(i - 1)
    };

    // Scalar tail of the seed window.
    while i <= init_end {
        let curr = *data.get_unchecked(i);
        let diff = curr - prev;
        prev = curr;

        let ad = diff.abs();
        sum_gain += 0.5 * (ad + diff);
        sum_loss += 0.5 * (ad - diff);
        i += 1;
    }

    // Seed averages and the first output value.
    let mut avg_gain = sum_gain * inv_period;
    let mut avg_loss = sum_loss * inv_period;
    {
        let sum_gl = avg_gain + avg_loss;
        *out.get_unchecked_mut(init_end) = if sum_gl != 0.0 {
            100.0 * ((avg_gain - avg_loss) / sum_gl)
        } else {
            0.0
        };
    }

    // Sequential smoothing: avg = (avg * (period - 1) + x) / period.
    while i < len {
        let curr = *data.get_unchecked(i);
        let diff = curr - prev;
        prev = curr;

        let ad = diff.abs();
        let gain = 0.5 * (ad + diff);
        let loss = 0.5 * (ad - diff);

        avg_gain *= period_m1;
        avg_loss *= period_m1;
        avg_gain += gain;
        avg_loss += loss;
        avg_gain *= inv_period;
        avg_loss *= inv_period;

        let sum_gl = avg_gain + avg_loss;
        *out.get_unchecked_mut(i) = if sum_gl != 0.0 {
            100.0 * ((avg_gain - avg_loss) / sum_gl)
        } else {
            0.0
        };

        i += 1;
    }
}
483
/// AVX-512 implementation of the CMO kernel (same output contract as
/// `cmo_scalar`: writes `out[first_valid + period..]` only).
///
/// The seed window is summed eight price changes per iteration with 512-bit
/// vectors, and the remainder is summed scalar. The smoothing pass after the
/// seed is sequential. The seed sum is reassociated across lanes, so results
/// can differ from `cmo_scalar` in the last bits.
///
/// # Safety
/// Requires `avx512f` on the running CPU. Memory accesses are unchecked:
/// `out.len()` must equal `data.len()` (only debug-asserted), `period` must be
/// at least 1, and `first_valid + period` must be a valid index into `data`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
unsafe fn cmo_avx512_impl(data: &[f64], period: usize, first_valid: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    debug_assert!(out.len() == data.len());

    // Horizontal sum of the four f64 lanes of a 256-bit vector.
    #[inline(always)]
    unsafe fn hsum256_pd(v: __m256d) -> f64 {
        let hi = _mm256_extractf128_pd(v, 1);
        let lo = _mm256_castpd256_pd128(v);
        let sum2 = _mm_add_pd(lo, hi);
        let hi64 = _mm_unpackhi_pd(sum2, sum2);
        _mm_cvtsd_f64(_mm_add_sd(sum2, hi64))
    }

    // Horizontal sum of the eight f64 lanes of a 512-bit vector
    // (fold the two 256-bit halves, then reduce).
    #[inline(always)]
    unsafe fn hsum512_pd(v: __m512d) -> f64 {
        let lo256 = _mm512_castpd512_pd256(v);
        let hi256 = _mm512_extractf64x4_pd(v, 1);
        hsum256_pd(_mm256_add_pd(lo256, hi256))
    }

    let len = data.len();
    // Change `i` is `data[i] - data[i - 1]`; the seed covers changes
    // `start..=init_end`, and the first output lands at `init_end`.
    let start = first_valid + 1;
    let init_end = first_valid + period;

    let inv_period = 1.0 / (period as f64);
    let period_m1 = (period - 1) as f64;

    let mut acc_gain_v = _mm512_setzero_pd();
    let mut acc_loss_v = _mm512_setzero_pd();
    let half_v = _mm512_set1_pd(0.5);
    // Integer mask that clears the sign bit, giving |x| per lane.
    let abs_mask_i = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFFu64 as i64);

    let mut sum_gain = 0.0f64;
    let mut sum_loss = 0.0f64;

    let mut i = start;
    // Vector loop over changes i..i+7: loads data[i..i+8] and data[i-1..i+7].
    // The condition keeps the highest loaded index (i + 7) inside the seed window.
    while i + 7 <= init_end {
        let curr_v = _mm512_loadu_pd(data.as_ptr().add(i));
        let prev_v = _mm512_loadu_pd(data.as_ptr().add(i - 1));
        let diff_v = _mm512_sub_pd(curr_v, prev_v);

        // |d| via integer AND on the bit pattern.
        let diff_i = _mm512_castpd_si512(diff_v);
        let abs_i = _mm512_and_si512(diff_i, abs_mask_i);
        let ad_v = _mm512_castsi512_pd(abs_i);

        // gain = (|d| + d) / 2, loss = (|d| - d) / 2
        let gain_v = _mm512_mul_pd(_mm512_add_pd(ad_v, diff_v), half_v);
        let loss_v = _mm512_mul_pd(_mm512_sub_pd(ad_v, diff_v), half_v);

        acc_gain_v = _mm512_add_pd(acc_gain_v, gain_v);
        acc_loss_v = _mm512_add_pd(acc_loss_v, loss_v);

        i += 8;
    }

    sum_gain += hsum512_pd(acc_gain_v);
    sum_loss += hsum512_pd(acc_loss_v);

    // Price preceding change `i`. When the vector loop did not run,
    // i == start and data[first_valid] is the same element as data[i - 1].
    let mut prev = if i == start {
        *data.get_unchecked(first_valid)
    } else {
        *data.get_unchecked(i - 1)
    };

    // Scalar tail of the seed window.
    while i <= init_end {
        let curr = *data.get_unchecked(i);
        let diff = curr - prev;
        prev = curr;

        let ad = diff.abs();
        sum_gain += 0.5 * (ad + diff);
        sum_loss += 0.5 * (ad - diff);
        i += 1;
    }

    // Seed averages and the first output value.
    let mut avg_gain = sum_gain * inv_period;
    let mut avg_loss = sum_loss * inv_period;
    {
        let sum_gl = avg_gain + avg_loss;
        *out.get_unchecked_mut(init_end) = if sum_gl != 0.0 {
            100.0 * ((avg_gain - avg_loss) / sum_gl)
        } else {
            0.0
        };
    }

    // Sequential smoothing: avg = (avg * (period - 1) + x) / period.
    while i < len {
        let curr = *data.get_unchecked(i);
        let diff = curr - prev;
        prev = curr;

        let ad = diff.abs();
        let gain = 0.5 * (ad + diff);
        let loss = 0.5 * (ad - diff);

        avg_gain *= period_m1;
        avg_loss *= period_m1;
        avg_gain += gain;
        avg_loss += loss;
        avg_gain *= inv_period;
        avg_loss *= inv_period;

        let sum_gl = avg_gain + avg_loss;
        *out.get_unchecked_mut(i) = if sum_gl != 0.0 {
            100.0 * ((avg_gain - avg_loss) / sum_gl)
        } else {
            0.0
        };

        i += 1;
    }
}
598
599#[inline(always)]
600pub fn cmo_batch_with_kernel(
601 data: &[f64],
602 sweep: &CmoBatchRange,
603 k: Kernel,
604) -> Result<CmoBatchOutput, CmoError> {
605 let kernel = match k {
606 Kernel::Auto => Kernel::ScalarBatch,
607 other if other.is_batch() => other,
608 _ => return Err(CmoError::InvalidKernelForBatch(k)),
609 };
610 let simd = match kernel {
611 Kernel::Avx512Batch => Kernel::Avx512,
612 Kernel::Avx2Batch => Kernel::Avx2,
613 Kernel::ScalarBatch => Kernel::Scalar,
614 _ => unreachable!(),
615 };
616 cmo_batch_par_slice(data, sweep, simd)
617}
618
619#[derive(Clone, Debug)]
620pub struct CmoBatchRange {
621 pub period: (usize, usize, usize),
622}
623
624impl Default for CmoBatchRange {
625 fn default() -> Self {
626 Self {
627 period: (14, 263, 1),
628 }
629 }
630}
631
632#[derive(Clone, Debug, Default)]
633pub struct CmoBatchBuilder {
634 range: CmoBatchRange,
635 kernel: Kernel,
636}
637
638impl CmoBatchBuilder {
639 pub fn new() -> Self {
640 Self::default()
641 }
642 pub fn kernel(mut self, k: Kernel) -> Self {
643 self.kernel = k;
644 self
645 }
646 #[inline]
647 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
648 self.range.period = (start, end, step);
649 self
650 }
651 #[inline]
652 pub fn period_static(mut self, p: usize) -> Self {
653 self.range.period = (p, p, 0);
654 self
655 }
656 pub fn apply_slice(self, data: &[f64]) -> Result<CmoBatchOutput, CmoError> {
657 cmo_batch_with_kernel(data, &self.range, self.kernel)
658 }
659 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<CmoBatchOutput, CmoError> {
660 CmoBatchBuilder::new().kernel(k).apply_slice(data)
661 }
662 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<CmoBatchOutput, CmoError> {
663 let slice = source_type(c, src);
664 self.apply_slice(slice)
665 }
666 pub fn with_default_candles(c: &Candles) -> Result<CmoBatchOutput, CmoError> {
667 CmoBatchBuilder::new()
668 .kernel(Kernel::Auto)
669 .apply_candles(c, "close")
670 }
671}
672
673#[derive(Clone, Debug)]
674pub struct CmoBatchOutput {
675 pub values: Vec<f64>,
676 pub combos: Vec<CmoParams>,
677 pub rows: usize,
678 pub cols: usize,
679}
680
681impl CmoBatchOutput {
682 pub fn row_for_params(&self, p: &CmoParams) -> Option<usize> {
683 self.combos
684 .iter()
685 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
686 }
687 pub fn values_for(&self, p: &CmoParams) -> Option<&[f64]> {
688 self.row_for_params(p).map(|row| {
689 let start = row * self.cols;
690 &self.values[start..start + self.cols]
691 })
692 }
693}
694
695#[inline(always)]
696fn expand_grid(r: &CmoBatchRange) -> Vec<CmoParams> {
697 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
698 if step == 0 || start == end {
699 return vec![start];
700 }
701 let mut vals = Vec::new();
702 if start < end {
703 let mut x = start;
704 while x <= end {
705 vals.push(x);
706 let next = x.saturating_add(step);
707 if next == x {
708 break;
709 }
710 x = next;
711 }
712 } else {
713 let mut x = start;
714 loop {
715 vals.push(x);
716 if x <= end {
717 break;
718 }
719 let next = x.saturating_sub(step);
720 if next >= x {
721 break;
722 }
723 x = next;
724 }
725 }
726 vals
727 }
728 let periods = axis_usize(r.period);
729 let mut out = Vec::with_capacity(periods.len());
730 for &p in &periods {
731 out.push(CmoParams { period: Some(p) });
732 }
733 out
734}
735
736#[inline(always)]
737pub fn cmo_batch_slice(
738 data: &[f64],
739 sweep: &CmoBatchRange,
740 kern: Kernel,
741) -> Result<CmoBatchOutput, CmoError> {
742 cmo_batch_inner(data, sweep, kern, false)
743}
744
745#[inline(always)]
746pub fn cmo_batch_par_slice(
747 data: &[f64],
748 sweep: &CmoBatchRange,
749 kern: Kernel,
750) -> Result<CmoBatchOutput, CmoError> {
751 cmo_batch_inner(data, sweep, kern, true)
752}
753
754#[inline(always)]
755fn cmo_batch_inner(
756 data: &[f64],
757 sweep: &CmoBatchRange,
758 kern: Kernel,
759 parallel: bool,
760) -> Result<CmoBatchOutput, CmoError> {
761 let combos = expand_grid(sweep);
762 if combos.is_empty() {
763 return Err(CmoError::InvalidRange {
764 start: sweep.period.0,
765 end: sweep.period.1,
766 step: sweep.period.2,
767 });
768 }
769 let first = data
770 .iter()
771 .position(|x| !x.is_nan())
772 .ok_or(CmoError::AllValuesNaN)?;
773 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
774 let _ = combos
775 .len()
776 .checked_mul(max_p)
777 .ok_or(CmoError::InvalidRange {
778 start: sweep.period.0,
779 end: sweep.period.1,
780 step: sweep.period.2,
781 })?;
782 if data.len() - first <= max_p {
783 return Err(CmoError::NotEnoughValidData {
784 needed: max_p + 1,
785 valid: data.len() - first,
786 });
787 }
788 let rows = combos.len();
789 let cols = data.len();
790 let _expected = rows.checked_mul(cols).ok_or(CmoError::InvalidRange {
791 start: sweep.period.0,
792 end: sweep.period.1,
793 step: sweep.period.2,
794 })?;
795
796 let mut buf_mu = make_uninit_matrix(rows, cols);
797
798 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
799
800 init_matrix_prefixes(&mut buf_mu, cols, &warm);
801
802 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
803 let out: &mut [f64] = unsafe {
804 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
805 };
806
807 let len = data.len();
808 let start = first + 1;
809 let mut gains = vec![0.0f64; len];
810 let mut losses = vec![0.0f64; len];
811 for i in start..len {
812 let diff = data[i] - data[i - 1];
813 let ad = diff.abs();
814 gains[i] = 0.5 * (ad + diff);
815 losses[i] = 0.5 * (ad - diff);
816 }
817 let mut pg = vec![0.0f64; len + 1];
818 let mut pl = vec![0.0f64; len + 1];
819 for i in 0..len {
820 pg[i + 1] = pg[i] + gains[i];
821 pl[i + 1] = pl[i] + losses[i];
822 }
823
824 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
825 let period = combos[row].period.unwrap();
826 cmo_row_from_gl(&gains, &losses, &pg, &pl, first, period, out_row);
827 };
828
829 if parallel {
830 #[cfg(not(target_arch = "wasm32"))]
831 {
832 out.par_chunks_mut(cols)
833 .enumerate()
834 .for_each(|(row, slice)| do_row(row, slice));
835 }
836
837 #[cfg(target_arch = "wasm32")]
838 {
839 for (row, slice) in out.chunks_mut(cols).enumerate() {
840 do_row(row, slice);
841 }
842 }
843 } else {
844 for (row, slice) in out.chunks_mut(cols).enumerate() {
845 do_row(row, slice);
846 }
847 }
848
849 let values = unsafe {
850 Vec::from_raw_parts(
851 buf_guard.as_mut_ptr() as *mut f64,
852 buf_guard.len(),
853 buf_guard.capacity(),
854 )
855 };
856
857 Ok(CmoBatchOutput {
858 values,
859 combos,
860 rows,
861 cols,
862 })
863}
864
865#[inline(always)]
866fn cmo_batch_inner_into(
867 data: &[f64],
868 sweep: &CmoBatchRange,
869 kern: Kernel,
870 parallel: bool,
871 out: &mut [f64],
872) -> Result<Vec<CmoParams>, CmoError> {
873 let combos = expand_grid(sweep);
874 if combos.is_empty() {
875 return Err(CmoError::InvalidRange {
876 start: sweep.period.0,
877 end: sweep.period.1,
878 step: sweep.period.2,
879 });
880 }
881 let first = data
882 .iter()
883 .position(|x| !x.is_nan())
884 .ok_or(CmoError::AllValuesNaN)?;
885 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
886 let _ = combos
887 .len()
888 .checked_mul(max_p)
889 .ok_or(CmoError::InvalidRange {
890 start: sweep.period.0,
891 end: sweep.period.1,
892 step: sweep.period.2,
893 })?;
894 if data.len() - first <= max_p {
895 return Err(CmoError::NotEnoughValidData {
896 needed: max_p + 1,
897 valid: data.len() - first,
898 });
899 }
900 let cols = data.len();
901 let rows = combos.len();
902 let expected = rows.checked_mul(cols).ok_or(CmoError::InvalidRange {
903 start: sweep.period.0,
904 end: sweep.period.1,
905 step: sweep.period.2,
906 })?;
907 if out.len() != expected {
908 return Err(CmoError::OutputLengthMismatch {
909 expected,
910 got: out.len(),
911 });
912 }
913
914 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
915 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
916 };
917 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
918 init_matrix_prefixes(out_mu, cols, &warm);
919
920 let len = data.len();
921 let start = first + 1;
922 let mut gains = vec![0.0f64; len];
923 let mut losses = vec![0.0f64; len];
924 for i in start..len {
925 let diff = data[i] - data[i - 1];
926 let ad = diff.abs();
927 gains[i] = 0.5 * (ad + diff);
928 losses[i] = 0.5 * (ad - diff);
929 }
930 let mut pg = vec![0.0f64; len + 1];
931 let mut pl = vec![0.0f64; len + 1];
932 for i in 0..len {
933 pg[i + 1] = pg[i] + gains[i];
934 pl[i + 1] = pl[i] + losses[i];
935 }
936
937 let do_row = |row: usize, row_mu: &mut [MaybeUninit<f64>]| unsafe {
938 let period = combos[row].period.unwrap();
939 let row_dst: &mut [f64] =
940 std::slice::from_raw_parts_mut(row_mu.as_mut_ptr() as *mut f64, row_mu.len());
941 cmo_row_from_gl(&gains, &losses, &pg, &pl, first, period, row_dst);
942 };
943
944 if parallel {
945 #[cfg(not(target_arch = "wasm32"))]
946 {
947 out_mu
948 .par_chunks_mut(cols)
949 .enumerate()
950 .for_each(|(r, row_mu)| do_row(r, row_mu));
951 }
952 #[cfg(target_arch = "wasm32")]
953 {
954 for (r, row_mu) in out_mu.chunks_mut(cols).enumerate() {
955 do_row(r, row_mu);
956 }
957 }
958 } else {
959 for (r, row_mu) in out_mu.chunks_mut(cols).enumerate() {
960 do_row(r, row_mu);
961 }
962 }
963
964 Ok(combos)
965}
966
967#[inline]
968pub fn cmo_batch_into_slice(
969 out: &mut [f64],
970 data: &[f64],
971 sweep: &CmoBatchRange,
972 k: Kernel,
973) -> Result<Vec<CmoParams>, CmoError> {
974 let kernel = match k {
975 Kernel::Auto => detect_best_batch_kernel(),
976 other if other.is_batch() => other,
977 _ => return Err(CmoError::InvalidKernelForBatch(k)),
978 };
979 let simd = match kernel {
980 Kernel::Avx512Batch => Kernel::Avx512,
981 Kernel::Avx2Batch => Kernel::Avx2,
982 Kernel::ScalarBatch => Kernel::Scalar,
983 _ => unreachable!(),
984 };
985 cmo_batch_inner_into(data, sweep, simd, true, out)
986}
987
988#[inline(always)]
989unsafe fn cmo_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
990 cmo_scalar(data, period, first, out);
991}
992
993#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
994#[inline(always)]
995unsafe fn cmo_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
996 cmo_avx2(data, period, first, out);
997}
998
999#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1000#[inline(always)]
1001unsafe fn cmo_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1002 if period <= 32 {
1003 cmo_row_avx512_short(data, first, period, out);
1004 } else {
1005 cmo_row_avx512_long(data, first, period, out);
1006 }
1007}
1008
1009#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1010#[inline(always)]
1011unsafe fn cmo_row_avx512_short(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1012 cmo_avx512_short(data, period, first, out)
1013}
1014
1015#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1016#[inline(always)]
1017unsafe fn cmo_row_avx512_long(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
1018 cmo_avx512_long(data, period, first, out)
1019}
1020
/// Computes one CMO row from precomputed per-step gains and losses.
///
/// `pg` / `pl` are prefix sums of `gains` / `losses`
/// (`pg[k] = gains[0] + ... + gains[k - 1]`). The seed averages over changes
/// `first + 1 ..= first + period` come from the prefix sums. Later values use
/// `avg = (avg * (period - 1) + x) / period`. Only `out[first + period..]` is
/// written.
///
/// # Safety
/// `first + period` must be less than `out.len()`, and `gains` and `losses`
/// must each hold at least `out.len()` elements; those accesses are
/// unchecked. `pg`/`pl` are indexed with bounds checks. `period` should be
/// at least 1.
#[inline(always)]
unsafe fn cmo_row_from_gl(
    gains: &[f64],
    losses: &[f64],
    pg: &[f64],
    pl: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    let n = out.len();
    let seed = first + period;
    let inv_period = 1.0 / (period as f64);
    let period_m1 = (period - 1) as f64;

    // Oscillator value for a pair of averages.
    let score = |up: f64, down: f64| -> f64 {
        let total = up + down;
        if total != 0.0 {
            100.0 * ((up - down) / total)
        } else {
            0.0
        }
    };

    let lo = first + 1;
    let mut avg_gain = (pg[seed + 1] - pg[lo]) * inv_period;
    let mut avg_loss = (pl[seed + 1] - pl[lo]) * inv_period;
    *out.get_unchecked_mut(seed) = score(avg_gain, avg_loss);

    for i in (seed + 1)..n {
        avg_gain = (avg_gain * period_m1 + *gains.get_unchecked(i)) * inv_period;
        avg_loss = (avg_loss * period_m1 + *losses.get_unchecked(i)) * inv_period;
        *out.get_unchecked_mut(i) = score(avg_gain, avg_loss);
    }
}
1072
1073#[derive(Debug, Clone)]
1074pub struct CmoStream {
1075 period: usize,
1076 inv_period: f64,
1077 avg_gain: f64,
1078 avg_loss: f64,
1079 prev: f64,
1080 head: usize,
1081 started: bool,
1082 filled: bool,
1083}
1084
1085impl CmoStream {
1086 pub fn try_new(params: CmoParams) -> Result<Self, CmoError> {
1087 let period = params.period.unwrap_or(14);
1088 if period == 0 {
1089 return Err(CmoError::InvalidPeriod {
1090 period,
1091 data_len: 0,
1092 });
1093 }
1094 Ok(Self {
1095 period,
1096 inv_period: 1.0 / (period as f64),
1097 avg_gain: 0.0,
1098 avg_loss: 0.0,
1099 prev: 0.0,
1100 head: 0,
1101 started: false,
1102 filled: false,
1103 })
1104 }
1105 #[inline(always)]
1106 pub fn update(&mut self, value: f64) -> Option<f64> {
1107 if !self.started {
1108 self.prev = value;
1109 self.started = true;
1110 return None;
1111 }
1112
1113 let diff = value - self.prev;
1114 self.prev = value;
1115
1116 let ad = diff.abs();
1117 let gain = 0.5 * (ad + diff);
1118 let loss = 0.5 * (ad - diff);
1119
1120 if !self.filled {
1121 self.avg_gain += gain;
1122 self.avg_loss += loss;
1123 self.head += 1;
1124
1125 if self.head == self.period {
1126 self.avg_gain *= self.inv_period;
1127 self.avg_loss *= self.inv_period;
1128 self.filled = true;
1129
1130 let denom = self.avg_gain + self.avg_loss;
1131 return Some(if denom != 0.0 {
1132 100.0 * (self.avg_gain - self.avg_loss) / denom
1133 } else {
1134 0.0
1135 });
1136 }
1137 return None;
1138 }
1139
1140 let ip = self.inv_period;
1141 self.avg_gain = (gain - self.avg_gain).mul_add(ip, self.avg_gain);
1142 self.avg_loss = (loss - self.avg_loss).mul_add(ip, self.avg_loss);
1143
1144 let denom = self.avg_gain + self.avg_loss;
1145 Some(if denom != 0.0 {
1146 100.0 * (self.avg_gain - self.avg_loss) / denom
1147 } else {
1148 0.0
1149 })
1150 }
1151}
1152
1153#[cfg(feature = "python")]
1154use crate::utilities::kernel_validation::validate_kernel;
1155#[cfg(feature = "python")]
1156use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
1157#[cfg(feature = "python")]
1158use pyo3::exceptions::PyValueError;
1159#[cfg(feature = "python")]
1160use pyo3::prelude::*;
1161#[cfg(feature = "python")]
1162use pyo3::types::PyDict;
1163
1164#[cfg(all(feature = "python", feature = "cuda"))]
1165use crate::cuda::oscillators::CudaCmo;
1166#[cfg(all(feature = "python", feature = "cuda"))]
1167use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
1168
// Python entry point: CMO over a 1-D float64 array. Validation errors are
// raised as ValueError; the computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "cmo")]
#[pyo3(signature = (data, period=None, kernel=None))]
pub fn cmo_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = CmoInput::from_slice(slice_in, CmoParams { period });

    let computed = py.allow_threads(|| cmo_with_kernel(&input, kern));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(output.values.into_pyarray(py))
}
1190
// Python wrapper around `CmoStream`.
#[cfg(feature = "python")]
#[pyclass(name = "CmoStream")]
pub struct CmoStreamPy {
    stream: CmoStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl CmoStreamPy {
    // Construction errors (e.g. period 0) surface as ValueError.
    #[new]
    fn new(period: Option<usize>) -> PyResult<Self> {
        CmoStream::try_new(CmoParams { period })
            .map(|stream| CmoStreamPy { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Returns None during warm-up, otherwise the latest CMO value.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1212
/// Python binding `cmo_batch`: evaluate CMO for every period produced by
/// `period_range` (`(start, end, step)`) and return a dict with
/// `"values"` reshaped to `(rows, len(data))` and the matching `"periods"`.
///
/// Raises `ValueError` when `rows * cols` overflows or the batch computation
/// fails.
#[cfg(feature = "python")]
#[pyfunction(name = "cmo_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn cmo_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let prices = data.as_slice()?;
    let sweep = CmoBatchRange {
        period: period_range,
    };

    // Size the output matrix up front from the parameter grid.
    let n_rows = expand_grid(&sweep).len();
    let n_cols = prices.len();
    let total = match n_rows.checked_mul(n_cols) {
        Some(count) => count,
        None => {
            return Err(PyValueError::new_err(format!(
                "cmo_batch: size overflow for rows={} cols={}",
                n_rows, n_cols
            )))
        }
    };

    // NOTE(review): the array starts uninitialised; this relies on
    // `cmo_batch_inner_into` writing every element before it reaches Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let requested = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(requested, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                requested
            };
            // Translate the batch kernel into its per-row SIMD counterpart.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => unreachable!(),
            };
            cmo_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("values", out_arr.reshape((n_rows, n_cols))?)?;

    let periods: Vec<u64> = combos
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    result.set_item("periods", periods.into_pyarray(py))?;

    Ok(result)
}
1272
/// CUDA binding `cmo_cuda_batch_dev`: run the CMO period sweep on the GPU over
/// `f32` prices, returning the device-resident result plus a dict whose
/// `"periods"` entry lists the period of each parameter combination.
///
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an
/// error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cmo_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn cmo_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: numpy::PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, pyo3::types::PyDict>)> {
    use crate::cuda::cuda_available;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data_f32.as_slice()?;
    let sweep = CmoBatchRange {
        period: period_range,
    };

    // GPU work runs without the GIL; the context handle is captured so the
    // returned device array keeps it alive.
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaCmo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let context = cuda.context_arc();
        let device = cuda.device_id();
        let buffer = cuda
            .cmo_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok((buffer, context, device))
    })?;

    let periods: Vec<u64> = expand_grid(&sweep)
        .into_iter()
        .map(|combo| combo.period.unwrap_or(14) as u64)
        .collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    let device_array = DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx_arc),
        device_id: Some(dev_id),
    };
    Ok((device_array, meta))
}
1318
/// CUDA binding `cmo_cuda_many_series_one_param_dev`: compute CMO with a single
/// `period` over a 2-D `f32` array, returning the device-resident result.
///
/// `shape[0]` is read as `rows` and `shape[1]` as `cols`; they are passed to
/// the CUDA wrapper as `(cols, rows)`. NOTE(review): presumably this means
/// `(num_series, series_len)` for time-major input, so confirm against
/// `CudaCmo::cmo_many_series_one_param_time_major_dev`.
///
/// Raises `ValueError` if CUDA is unavailable or the CUDA wrapper reports an
/// error. `as_slice` raises for non-contiguous input.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cmo_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn cmo_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = CmoParams {
        period: Some(period),
    };
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda = CudaCmo::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx_arc = cuda.context_arc();
        let dev_id = cuda.device_id();
        // Pass the parameters by reference (the original text had a mangled
        // `&params` token here that does not lex).
        cuda.cmo_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
            .map(|inner| (inner, ctx_arc, dev_id))
    })?;
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx_arc),
        device_id: Some(dev_id),
    })
}
1353
/// Unit and property tests for the CMO indicator.
///
/// Per-kernel `#[test]` functions are generated by `generate_all_cmo_tests!`
/// and `gen_batch_tests!`. The generated functions return the checker's
/// `Result`, so an `Err` (e.g. from `cmo_with_kernel` or a missing CSV
/// fixture) fails the test instead of being silently discarded.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// `cmo_into` must reproduce `cmo_with_kernel(Auto)` on data with a NaN prefix.
    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    #[test]
    fn test_cmo_into_matches_api() -> Result<(), Box<dyn Error>> {
        let mut data = vec![f64::NAN; 3];
        data.extend((0..256).map(|i| {
            let x = i as f64;
            (x * 0.07).sin() * 5.0 + x * 0.1
        }));

        let input = CmoInput::from_slice(&data, CmoParams::default());

        let baseline = cmo_with_kernel(&input, Kernel::Auto)?.values;

        let mut out = vec![0.0; data.len()];
        cmo_into(&input, &mut out)?;

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a == b) || ((a - b).abs() <= 1e-12)
        }

        for i in 0..out.len() {
            assert!(
                eq_or_both_nan(baseline[i], out[i]),
                "mismatch at {}: baseline={} out={}",
                i,
                baseline[i],
                out[i]
            );
        }

        Ok(())
    }

    /// `period: None` on `close` and an explicit period on `hl2` both yield
    /// output as long as the candle series.
    fn check_cmo_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let default_params = CmoParams { period: None };
        let input = CmoInput::from_candles(&candles, "close", default_params);
        let output = cmo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        let params_10 = CmoParams { period: Some(10) };
        let input_10 = CmoInput::from_candles(&candles, "hl2", params_10);
        let output_10 = cmo_with_kernel(&input_10, kernel)?;
        assert_eq!(output_10.values.len(), candles.close.len());
        Ok(())
    }

    /// The last five CMO(14) values on `close` match reference values within 1e-6.
    fn check_cmo_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let params = CmoParams { period: Some(14) };
        let input = CmoInput::from_candles(&candles, "close", params);
        let cmo_result = cmo_with_kernel(&input, kernel)?;
        let expected_last_five = [
            -13.152504931406101,
            -14.649876201213106,
            -16.760170709240303,
            -14.274505732779227,
            -21.984038127126716,
        ];
        let start_idx = cmo_result.values.len() - 5;
        let last_five = &cmo_result.values[start_idx..];
        for (i, &actual) in last_five.iter().enumerate() {
            let expected = expected_last_five[i];
            assert!(
                (actual - expected).abs() < 1e-6,
                "[{}] CMO mismatch at final 5 index {}: expected {}, got {}",
                test_name,
                i,
                expected,
                actual
            );
        }
        Ok(())
    }

    /// `with_default_candles` selects the `close` source and yields full-length output.
    fn check_cmo_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let input = CmoInput::with_default_candles(&candles);
        match input.data {
            CmoData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected CmoData::Candles variant"),
        }
        let output = cmo_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    /// A period of zero must be rejected.
    fn check_cmo_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0];
        let params = CmoParams { period: Some(0) };
        let input = CmoInput::from_slice(&data, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for period=0",
            test_name
        );
        Ok(())
    }

    /// A period longer than the input must be rejected.
    fn check_cmo_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = [10.0, 20.0, 30.0];
        let params = CmoParams { period: Some(10) };
        let input = CmoInput::from_slice(&data, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for period>data.len()",
            test_name
        );
        Ok(())
    }

    /// A single data point with the default period must be rejected.
    fn check_cmo_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let single = [42.0];
        let params = CmoParams { period: Some(14) };
        let input = CmoInput::from_slice(&single, params);
        let result = cmo_with_kernel(&input, kernel);
        assert!(
            result.is_err(),
            "[{}] Expected error for insufficient data",
            test_name
        );
        Ok(())
    }

    /// Feeding CMO output back into CMO: after both warmups (14 + 14) every
    /// value must be non-NaN.
    fn check_cmo_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;
        let first_params = CmoParams { period: Some(14) };
        let first_input = CmoInput::from_candles(&candles, "close", first_params);
        let first_result = cmo_with_kernel(&first_input, kernel)?;
        let second_params = CmoParams { period: Some(14) };
        let second_input = CmoInput::from_slice(&first_result.values, second_params);
        let second_result = cmo_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());
        for i in 28..second_result.values.len() {
            assert!(
                !second_result.values[i].is_nan(),
                "[{}] Expected no NaN after index 28, found NaN at {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    /// Debug builds only: no output element may carry one of the poison bit
    /// patterns checked below (named after the allocation helpers).
    #[cfg(debug_assertions)]
    fn check_cmo_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_periods = vec![7, 14, 21, 28];

        for period in test_periods {
            let params = CmoParams {
                period: Some(period),
            };
            let input = CmoInput::from_candles(&candles, "close", params);
            let output = cmo_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();

                if bits == 0x11111111_11111111 {
                    panic!(
                        "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x22222222_22222222 {
                    panic!(
                        "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }

                if bits == 0x33333333_33333333 {
                    panic!(
                        "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period {}",
                        test_name, val, bits, i, period
                    );
                }
            }
        }

        Ok(())
    }

    /// Release builds: poison checks are disabled.
    #[cfg(not(debug_assertions))]
    fn check_cmo_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Expands each checker into one `#[test]` per kernel. Each SIMD variant
    // carries its own `cfg`: an attribute written before a `$(...)*` repetition
    // would only attach to the first generated item. The tests return the
    // checker's `Result` so errors fail the test rather than being dropped.
    macro_rules! generate_all_cmo_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        }
    }

    /// Property test covering several invariants:
    /// - warmup prefix is NaN (`first_valid + period`);
    /// - output stays within [-100, 100];
    /// - agreement with the scalar kernel within 1e-9 or 8 ULP;
    /// - sign behaviour for constant, monotone and strongly rising data.
    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_cmo_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        use proptest::prelude::*;
        skip_if_unsupported!(kernel, test_name);

        let strat = (1usize..=50).prop_flat_map(|period| {
            (
                prop::collection::vec(
                    (-1e6f64..1e6f64)
                        .prop_filter("finite and non-zero", |x| x.is_finite() && x.abs() > 1e-10),
                    (period + 1).max(2)..400,
                ),
                Just(period),
            )
        });

        proptest::test_runner::TestRunner::default()
            .run(&strat, |(data, period)| {
                let params = CmoParams {
                    period: Some(period),
                };
                let input = CmoInput::from_slice(&data, params);

                let CmoOutput { values: out } = cmo_with_kernel(&input, kernel).unwrap();
                let CmoOutput { values: ref_out } =
                    cmo_with_kernel(&input, Kernel::Scalar).unwrap();

                prop_assert_eq!(out.len(), data.len(), "Output length mismatch");

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warmup = first_valid + period;

                for i in 0..warmup.min(out.len()) {
                    prop_assert!(
                        out[i].is_nan(),
                        "Expected NaN during warmup at index {}, got {}",
                        i,
                        out[i]
                    );
                }

                if warmup < out.len() {
                    prop_assert!(
                        !out[warmup].is_nan(),
                        "Expected valid value at index {} (first after warmup), got NaN",
                        warmup
                    );
                }

                for i in warmup..data.len() {
                    let y = out[i];
                    let r = ref_out[i];

                    prop_assert!(
                        y.is_nan() || (y >= -100.0 - 1e-9 && y <= 100.0 + 1e-9),
                        "CMO value {} at index {} outside bounds [-100, 100]",
                        y,
                        i
                    );

                    if data[..=i].iter().all(|x| x.is_finite()) {
                        prop_assert!(
                            y.is_finite(),
                            "Expected finite output at index {}, got {}",
                            i,
                            y
                        );
                    }

                    let y_bits = y.to_bits();
                    let r_bits = r.to_bits();

                    if !y.is_finite() || !r.is_finite() {
                        prop_assert_eq!(
                            y_bits,
                            r_bits,
                            "Finite/NaN mismatch at index {}: {} vs {}",
                            i,
                            y,
                            r
                        );
                        continue;
                    }

                    let ulp_diff: u64 = y_bits.abs_diff(r_bits);
                    prop_assert!(
                        (y - r).abs() <= 1e-9 || ulp_diff <= 8,
                        "Kernel mismatch at index {}: {} vs {} (ULP={})",
                        i,
                        y,
                        r,
                        ulp_diff
                    );
                }

                if data.windows(2).all(|w| (w[0] - w[1]).abs() < 1e-12) && warmup < data.len() {
                    let cmo_val = out[warmup];
                    prop_assert!(
                        cmo_val.abs() <= 1e-9,
                        "Constant data should produce CMO of 0, got {} at index {}",
                        cmo_val,
                        warmup
                    );
                }

                let is_increasing = data.windows(2).all(|w| w[1] >= w[0] - 1e-10);
                if is_increasing && warmup < data.len() {
                    for i in warmup..data.len() {
                        prop_assert!(
                            out[i].is_nan() || out[i] >= -1e-6,
                            "Monotonically increasing data should produce non-negative CMO, got {} at index {}",
                            out[i],
                            i
                        );
                    }
                }

                let is_decreasing = data.windows(2).all(|w| w[1] <= w[0] + 1e-10);
                if is_decreasing && warmup < data.len() {
                    for i in warmup..data.len() {
                        prop_assert!(
                            out[i].is_nan() || out[i] <= 1e-6,
                            "Monotonically decreasing data should produce non-positive CMO, got {} at index {}",
                            out[i],
                            i
                        );
                    }
                }

                if period > 1 && warmup + 5 < data.len() {
                    let has_strong_gains = (warmup..data.len().min(warmup + 10))
                        .zip(warmup.saturating_sub(1)..data.len().saturating_sub(1).min(warmup + 9))
                        .all(|(i, j)| data[i] > data[j] * 1.1);

                    if has_strong_gains {
                        let last_idx = data.len() - 1;
                        prop_assert!(
                            out[last_idx].is_nan() || out[last_idx] >= 50.0,
                            "Strong gains should produce CMO > 50, got {} at index {}",
                            out[last_idx],
                            last_idx
                        );
                    }
                }

                Ok(())
            })
            .unwrap();

        Ok(())
    }

    generate_all_cmo_tests!(
        check_cmo_partial_params,
        check_cmo_accuracy,
        check_cmo_default_candles,
        check_cmo_zero_period,
        check_cmo_period_exceeds_length,
        check_cmo_very_small_dataset,
        check_cmo_reinput,
        check_cmo_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_cmo_tests!(check_cmo_property);

    /// The batch builder's output contains a full-length row for the default params.
    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);
        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;
        let output = CmoBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;
        let def = CmoParams::default();
        let row = output.values_for(&def).expect("default row missing");
        assert_eq!(row.len(), c.close.len());
        Ok(())
    }

    /// Debug builds only: no batch matrix element may carry a poison bit pattern.
    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let c = read_candles_from_csv(file)?;

        let output = CmoBatchBuilder::new()
            .kernel(kernel)
            .period_range(7, 28, 7)
            .apply_candles(&c, "close")?;

        for (idx, &val) in output.values.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            let row = idx / output.cols;
            let col = idx % output.cols;

            if bits == 0x11111111_11111111 {
                panic!(
                    "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x22222222_22222222 {
                panic!(
                    "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }

            if bits == 0x33333333_33333333 {
                panic!(
                    "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {})",
                    test, val, bits, row, col, idx
                );
            }
        }

        Ok(())
    }

    /// Release builds: poison checks are disabled.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    // Expands a batch checker into one `#[test]` per batch kernel plus `Auto`;
    // each test returns the checker's `Result` so errors are reported.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test] fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test] fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }
                #[test] fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1856
/// WASM binding: compute CMO over `data` with automatic kernel selection and
/// return a freshly allocated vector of the same length. Errors are returned
/// to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_js(data: &[f64], period: Option<usize>) -> Result<Vec<f64>, JsValue> {
    let input = CmoInput::from_slice(data, CmoParams { period });
    let mut values = vec![0.0; data.len()];

    match cmo_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
1869
/// WASM binding: compute CMO from `len` values at `in_ptr` into `len` slots at
/// `out_ptr`, using automatic kernel selection.
///
/// If the input and output regions overlap at all (not only when
/// `in_ptr == out_ptr`), results are computed into a temporary buffer and then
/// copied. This way a `&[f64]` and a `&mut [f64]` over the same memory never
/// coexist.
///
/// # Errors
/// Returns a JS string on a null pointer or when the CMO computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: Option<usize>,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte extent of each buffer, used to detect partial as well as exact aliasing.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = in_start == out_start
        || (in_start < out_start.saturating_add(span) && out_start < in_start.saturating_add(span));

    unsafe {
        // SAFETY: relies on the caller supplying `len` readable f64s at `in_ptr`
        // and `len` writable f64s at `out_ptr`; overlapping regions are never
        // borrowed shared and mutable at the same time.
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = CmoParams { period };
        let input = CmoInput::from_slice(data, params);

        if overlaps {
            let mut temp = vec![0.0; len];
            cmo_into_slice(&mut temp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&temp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            cmo_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1901
/// WASM helper: reserve an uninitialised buffer with room for `len` f64 values
/// and hand its raw pointer to JS. The allocation is intentionally leaked here;
/// `cmo_free(ptr, len)` reclaims it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` prevents the Vec destructor from freeing the buffer.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1910
/// WASM helper: release a buffer previously returned by `cmo_alloc(len)`.
/// A null pointer is ignored.
///
/// The Vec is rebuilt with length 0 and capacity `len`. `Vec::from_raw_parts`
/// requires the first `length` elements to be initialised, and a buffer from
/// `cmo_alloc` may never have been written. Capacity alone determines the
/// deallocation layout, so this frees exactly the same allocation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on `ptr`/`len` coming from `cmo_alloc(len)`, which
    // allocated with `Vec::<f64>::with_capacity(len)`; length 0 asserts nothing
    // about the buffer's contents.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1920
/// Configuration accepted by the `cmo_batch` WASM export, deserialized from a
/// JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CmoBatchConfig {
    /// Period sweep as `(start, end, step)`, forwarded to `CmoBatchRange::period`.
    pub period_range: (usize, usize, usize),
}
1926
/// Result object serialized back to JS by the `cmo_batch` WASM export; each
/// field is copied from the corresponding field of the Rust batch output.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CmoBatchJsOutput {
    /// Flattened output matrix values.
    pub values: Vec<f64>,
    /// Parameter combination for each row (presumably in row order — confirm).
    pub combos: Vec<CmoParams>,
    /// Number of rows reported by the batch computation.
    pub rows: usize,
    /// Number of columns reported by the batch computation.
    pub cols: usize,
}
1935
/// WASM binding `cmo_batch`: deserialize a `CmoBatchConfig`, run the period
/// sweep with automatic kernel selection, and serialize the result to JS.
///
/// Errors (bad config, computation failure, serialization failure) are
/// returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cmo_batch)]
pub fn cmo_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CmoBatchConfig { period_range } = serde_wasm_bindgen::from_value::<CmoBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CmoBatchRange {
        period: period_range,
    };
    let batch = cmo_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = CmoBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1960
/// WASM binding: run the CMO period sweep `(period_start, period_end,
/// period_step)` over `len` inputs at `in_ptr`. The `rows * len` output values
/// are written to `out_ptr`, and the row count is returned.
///
/// The output size is computed with overflow checking, matching `cmo_batch_py`.
/// If the input span overlaps the output span, results are staged in a
/// temporary buffer so the input is never read through memory that is being
/// written.
///
/// # Errors
/// Returns a JS string on a null pointer, a size overflow, or a failed batch
/// computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cmo_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cmo_batch_into"));
    }

    let sweep = CmoBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid(&sweep);
    let rows = combos.len();
    let cols = len;
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str(&format!(
            "cmo_batch_into: size overflow for rows={} cols={}",
            rows, cols
        ))
    })?;

    // Byte ranges of the input (`cols` values) and output (`total` values).
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let out_start = out_ptr as usize;
    let in_end = in_start.saturating_add(cols.saturating_mul(elem));
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    unsafe {
        // SAFETY: relies on the caller supplying `len` readable f64s at `in_ptr`
        // and `rows * len` writable f64s at `out_ptr`; overlapping regions are
        // never borrowed shared and mutable at the same time.
        let data = std::slice::from_raw_parts(in_ptr, len);

        if overlaps {
            let mut staged = vec![0.0; total];
            cmo_batch_inner_into(data, &sweep, detect_best_kernel(), false, &mut staged)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            out.copy_from_slice(&staged);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            cmo_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}