1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23
24#[cfg(not(target_arch = "wasm32"))]
25use rayon::prelude::*;
26
27use std::convert::AsRef;
28use std::error::Error;
29use std::mem::MaybeUninit;
30use thiserror::Error;
31
32#[cfg(all(feature = "python", feature = "cuda"))]
33use crate::cuda::cuda_available;
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::moving_averages::CudaCoraWave;
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::cuda::moving_averages::DeviceArrayF32;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use crate::utilities::dlpack_cuda::DeviceArrayF32Py;
40
41impl<'a> AsRef<[f64]> for CoraWaveInput<'a> {
42 #[inline(always)]
43 fn as_ref(&self) -> &[f64] {
44 match &self.data {
45 CoraWaveData::Slice(slice) => slice,
46 CoraWaveData::Candles { candles, source } => source_type(candles, source),
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
52pub enum CoraWaveData<'a> {
53 Candles {
54 candles: &'a Candles,
55 source: &'a str,
56 },
57 Slice(&'a [f64]),
58}
59
/// Result of a single-series CoRa Wave computation.
#[derive(Debug, Clone)]
pub struct CoraWaveOutput {
    /// One value per input sample; the warm-up prefix is NaN.
    pub values: Vec<f64>,
}
64
/// Tunable CoRa Wave parameters. A `None` field falls back to the default used by
/// the accessors on `CoraWaveInput` (`period = 20`, `r_multi = 2.0`, `smooth = true`).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct CoraWaveParams {
    /// Window length of the geometrically weighted average.
    pub period: Option<usize>,
    /// Multiplier on the weight growth rate; validated as finite and non-negative.
    pub r_multi: Option<f64>,
    /// Whether to apply the secondary WMA of length `round(sqrt(period))`.
    pub smooth: Option<bool>,
}
75
76impl Default for CoraWaveParams {
77 fn default() -> Self {
78 Self {
79 period: Some(20),
80 r_multi: Some(2.0),
81 smooth: Some(true),
82 }
83 }
84}
85
86#[derive(Debug, Clone)]
87pub struct CoraWaveInput<'a> {
88 pub data: CoraWaveData<'a>,
89 pub params: CoraWaveParams,
90}
91
92impl<'a> CoraWaveInput<'a> {
93 #[inline]
94 pub fn from_candles(c: &'a Candles, s: &'a str, p: CoraWaveParams) -> Self {
95 Self {
96 data: CoraWaveData::Candles {
97 candles: c,
98 source: s,
99 },
100 params: p,
101 }
102 }
103
104 #[inline]
105 pub fn from_slice(sl: &'a [f64], p: CoraWaveParams) -> Self {
106 Self {
107 data: CoraWaveData::Slice(sl),
108 params: p,
109 }
110 }
111
112 #[inline]
113 pub fn with_default_candles(c: &'a Candles) -> Self {
114 Self::from_candles(c, "close", CoraWaveParams::default())
115 }
116
117 #[inline]
118 pub fn get_period(&self) -> usize {
119 self.params.period.unwrap_or(20)
120 }
121
122 #[inline]
123 pub fn get_r_multi(&self) -> f64 {
124 self.params.r_multi.unwrap_or(2.0)
125 }
126
127 #[inline]
128 pub fn get_smooth(&self) -> bool {
129 self.params.smooth.unwrap_or(true)
130 }
131}
132
133#[derive(Copy, Clone, Debug)]
134pub struct CoraWaveBuilder {
135 period: Option<usize>,
136 r_multi: Option<f64>,
137 smooth: Option<bool>,
138 kernel: Kernel,
139}
140
141impl Default for CoraWaveBuilder {
142 fn default() -> Self {
143 Self {
144 period: None,
145 r_multi: None,
146 smooth: None,
147 kernel: Kernel::Auto,
148 }
149 }
150}
151
152impl CoraWaveBuilder {
153 #[inline(always)]
154 pub fn new() -> Self {
155 Self::default()
156 }
157
158 #[inline(always)]
159 pub fn period(mut self, val: usize) -> Self {
160 self.period = Some(val);
161 self
162 }
163
164 #[inline(always)]
165 pub fn r_multi(mut self, val: f64) -> Self {
166 self.r_multi = Some(val);
167 self
168 }
169
170 #[inline(always)]
171 pub fn smooth(mut self, val: bool) -> Self {
172 self.smooth = Some(val);
173 self
174 }
175
176 #[inline(always)]
177 pub fn kernel(mut self, k: Kernel) -> Self {
178 self.kernel = k;
179 self
180 }
181
182 #[inline(always)]
183 pub fn apply(self, c: &Candles) -> Result<CoraWaveOutput, CoraWaveError> {
184 let p = CoraWaveParams {
185 period: self.period,
186 r_multi: self.r_multi,
187 smooth: self.smooth,
188 };
189 let i = CoraWaveInput::from_candles(c, "close", p);
190 cora_wave_with_kernel(&i, self.kernel)
191 }
192
193 #[inline(always)]
194 pub fn apply_slice(self, d: &[f64]) -> Result<CoraWaveOutput, CoraWaveError> {
195 let p = CoraWaveParams {
196 period: self.period,
197 r_multi: self.r_multi,
198 smooth: self.smooth,
199 };
200 let i = CoraWaveInput::from_slice(d, p);
201 cora_wave_with_kernel(&i, self.kernel)
202 }
203
204 #[inline(always)]
205 pub fn into_stream(self) -> Result<CoraWaveStream, CoraWaveError> {
206 let p = CoraWaveParams {
207 period: self.period,
208 r_multi: self.r_multi,
209 smooth: self.smooth,
210 };
211 CoraWaveStream::try_new(p)
212 }
213}
214
215#[derive(Debug, Error)]
216pub enum CoraWaveError {
217 #[error("cora_wave: Input data slice is empty.")]
218 EmptyInputData,
219
220 #[error("cora_wave: All values are NaN.")]
221 AllValuesNaN,
222
223 #[error("cora_wave: Invalid period: period = {period}, data length = {data_len}")]
224 InvalidPeriod { period: usize, data_len: usize },
225
226 #[error("cora_wave: Not enough valid data: needed = {needed}, valid = {valid}")]
227 NotEnoughValidData { needed: usize, valid: usize },
228
229 #[error("cora_wave: Invalid r_multi: {value}")]
230 InvalidRMulti { value: f64 },
231
232 #[error("cora_wave: Output length mismatch: expected = {expected}, got = {got}")]
233 OutputLengthMismatch { expected: usize, got: usize },
234
235 #[error("cora_wave: Invalid range: start = {start}, end = {end}, step = {step}")]
236 InvalidRange { start: f64, end: f64, step: f64 },
237
238 #[error("cora_wave: invalid kernel for batch: {0:?}")]
239 InvalidKernelForBatch(Kernel),
240
241 #[error("cora_wave: invalid input: {0}")]
242 InvalidInput(String),
243}
244
245#[inline]
246pub fn cora_wave(input: &CoraWaveInput) -> Result<CoraWaveOutput, CoraWaveError> {
247 cora_wave_with_kernel(input, Kernel::Auto)
248}
249
250pub fn cora_wave_with_kernel(
251 input: &CoraWaveInput,
252 kernel: Kernel,
253) -> Result<CoraWaveOutput, CoraWaveError> {
254 let (data, weights, inv_wsum, smooth_period, first, chosen) = cora_wave_prepare(input, kernel)?;
255 let period = weights.len();
256 let warm = first + period - 1 + smooth_period.saturating_sub(1);
257
258 let mut out = alloc_with_nan_prefix(data.len(), warm);
259 cora_wave_compute_into(
260 data,
261 &weights,
262 inv_wsum,
263 smooth_period,
264 first,
265 chosen,
266 &mut out,
267 );
268 Ok(CoraWaveOutput { values: out })
269}
270
271#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
272#[inline]
273pub fn cora_wave_into(input: &CoraWaveInput, out: &mut [f64]) -> Result<(), CoraWaveError> {
274 let (data, weights, inv_wsum, smooth_period, first, chosen) =
275 cora_wave_prepare(input, Kernel::Auto)?;
276
277 if out.len() != data.len() {
278 return Err(CoraWaveError::OutputLengthMismatch {
279 expected: data.len(),
280 got: out.len(),
281 });
282 }
283
284 let warm = first + weights.len() - 1 + smooth_period.saturating_sub(1);
285 let warm = warm.min(out.len());
286 if warm > 0 {
287 let qnan = f64::from_bits(0x7ff8_0000_0000_0000);
288 for v in &mut out[..warm] {
289 *v = qnan;
290 }
291 }
292
293 cora_wave_compute_into(data, &weights, inv_wsum, smooth_period, first, chosen, out);
294 Ok(())
295}
296
297#[inline]
298pub fn cora_wave_into_slice(
299 dst: &mut [f64],
300 input: &CoraWaveInput,
301 kern: Kernel,
302) -> Result<(), CoraWaveError> {
303 let (data, weights, inv_wsum, smooth_period, first, chosen) = cora_wave_prepare(input, kern)?;
304 if dst.len() != data.len() {
305 return Err(CoraWaveError::OutputLengthMismatch {
306 expected: data.len(),
307 got: dst.len(),
308 });
309 }
310 cora_wave_compute_into(data, &weights, inv_wsum, smooth_period, first, chosen, dst);
311
312 let warm = first + weights.len() - 1 + smooth_period.saturating_sub(1);
313 for v in &mut dst[..warm] {
314 *v = f64::NAN;
315 }
316 Ok(())
317}
318
319#[inline(always)]
320fn cora_wave_prepare<'a>(
321 input: &'a CoraWaveInput,
322 kernel: Kernel,
323) -> Result<(&'a [f64], Vec<f64>, f64, usize, usize, Kernel), CoraWaveError> {
324 let data: &[f64] = input.as_ref();
325 let len = data.len();
326 if len == 0 {
327 return Err(CoraWaveError::EmptyInputData);
328 }
329
330 let first = data
331 .iter()
332 .position(|x| !x.is_nan())
333 .ok_or(CoraWaveError::AllValuesNaN)?;
334
335 let period = input.get_period();
336 let r_multi = input.get_r_multi();
337 let smooth = input.get_smooth();
338
339 if period == 0 || period > len {
340 return Err(CoraWaveError::InvalidPeriod {
341 period,
342 data_len: len,
343 });
344 }
345 if len - first < period {
346 return Err(CoraWaveError::NotEnoughValidData {
347 needed: period,
348 valid: len - first,
349 });
350 }
351 if r_multi < 0.0 || !r_multi.is_finite() {
352 return Err(CoraWaveError::InvalidRMulti { value: r_multi });
353 }
354
355 let mut weights = Vec::with_capacity(period);
356 let inv_sum: f64;
357 if period == 1 {
358 weights.push(1.0);
359 inv_sum = 1.0;
360 } else {
361 let start_wt = 0.01;
362 let end_wt = period as f64;
363 let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
364 let base = 1.0 + r * r_multi;
365
366 let mut sum = 0.0;
367
368 let mut w = start_wt * base;
369 for _ in 0..period {
370 weights.push(w);
371 sum += w;
372 w *= base;
373 }
374 inv_sum = 1.0 / sum;
375 }
376
377 let smooth_period = if smooth {
378 ((period as f64).sqrt().round() as usize).max(1)
379 } else {
380 1
381 };
382 let chosen = match kernel {
383 Kernel::Auto => detect_best_kernel(),
384 k => k,
385 };
386
387 Ok((data, weights, inv_sum, smooth_period, first, chosen))
388}
389
390#[inline(always)]
391fn cora_wave_compute_into(
392 data: &[f64],
393 weights: &[f64],
394 inv_wsum: f64,
395 smooth_period: usize,
396 first: usize,
397 kernel: Kernel,
398 out: &mut [f64],
399) {
400 unsafe {
401 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
402 {
403 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
404 cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out);
405 return;
406 }
407 }
408 match kernel {
409 Kernel::Scalar | Kernel::ScalarBatch => {
410 cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out)
411 }
412 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
413 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
414 cora_wave_scalar_with_weights(data, weights, inv_wsum, smooth_period, first, out)
415 }
416 _ => unreachable!(),
417 }
418 }
419}
420
/// Computes the CoRa Wave for precomputed window weights, writing into `out`.
///
/// * `weights` – window weights, oldest sample first. With two or more weights the
///   O(1) sliding update assumes a geometric progression
///   (`weights[j + 1] = weights[j] * base`), as built by `cora_wave_prepare`.
/// * `inv_wsum` – `1 / sum(weights)`.
/// * `smooth_period` – length of the trailing linear WMA (weights `1..=m`) applied
///   to the raw wave. `0` and `1` both mean "no smoothing".
/// * `first_val` – index of the first non-NaN input.
///
/// Only indices from `first_val + weights.len() - 1 + (smooth_period - 1)` onward
/// are written. The warm-up prefix is left to the caller.
///
/// # Panics
/// Panics if `out` is shorter than `data`. Earlier versions wrote out of bounds
/// through `get_unchecked_mut` in that case.
#[inline]
pub fn cora_wave_scalar_with_weights(
    data: &[f64],
    weights: &[f64],
    inv_wsum: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let p = weights.len();
    if p == 0 || n == 0 {
        return;
    }
    assert!(
        out.len() >= n,
        "cora_wave: output buffer shorter than input ({} < {})",
        out.len(),
        n
    );
    let out = &mut out[..n];
    // A zero-length smoother would index an empty ring buffer and divide by zero.
    // Treat it as "no smoothing".
    let m = smooth_period.max(1);

    if p == 1 {
        cora_wave_single_weight_into(data, weights[0] * inv_wsum, m, first_val, out);
        return;
    }

    // Index of the first complete window; nothing to do if it lies past the data.
    let warm0 = match first_val.checked_add(p - 1) {
        Some(w) if w < n => w,
        _ => return,
    };

    // Sliding update for geometric weights:
    // S' = S / base - (w0 / base) * x_old + w_last * x_new.
    let w0 = weights[0];
    let inv_r = w0 / weights[1];
    let a_old = w0 * inv_r;
    let w_last = weights[p - 1];

    let mut sum = cora_wave_window_dot(&data[first_val..=warm0], weights);

    if m == 1 {
        out[warm0] = sum * inv_wsum;
        for i in warm0 + 1..n {
            sum = (sum * inv_r) - a_old * data[i - p] + w_last * data[i];
            out[i] = sum * inv_wsum;
        }
        return;
    }

    // The first smoothed output needs m raw values.
    let warm_total = warm0.saturating_add(m - 1);
    if warm_total >= n {
        return;
    }
    let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;

    // Fill the smoothing window with the first m raw wave values (oldest at 0).
    let mut ring = vec![0.0f64; m];
    ring[0] = sum * inv_wsum;
    for i in warm0 + 1..=warm_total {
        sum = (sum * inv_r) - a_old * data[i - p] + w_last * data[i];
        ring[i - warm0] = sum * inv_wsum;
    }
    out[warm_total] = cora_wave_ring_wma(&ring, 0, wma_sum);

    // Steady state: overwrite the oldest raw value, then `head` points at the new oldest.
    let mut head = 0usize;
    for i in warm_total + 1..n {
        sum = (sum * inv_r) - a_old * data[i - p] + w_last * data[i];
        ring[head] = sum * inv_wsum;
        head = (head + 1) % m;
        out[i] = cora_wave_ring_wma(&ring, head, wma_sum);
    }
}

/// Weighted sum of `window` with `weights` (equal lengths), using four interleaved
/// FMA accumulators plus a sequential tail.
#[inline(always)]
fn cora_wave_window_dot(window: &[f64], weights: &[f64]) -> f64 {
    let xs = window.chunks_exact(4);
    let ws = weights.chunks_exact(4);
    let (x_tail, w_tail) = (xs.remainder(), ws.remainder());
    let mut acc = [0.0f64; 4];
    for (x, w) in xs.zip(ws) {
        acc[0] = x[0].mul_add(w[0], acc[0]);
        acc[1] = x[1].mul_add(w[1], acc[1]);
        acc[2] = x[2].mul_add(w[2], acc[2]);
        acc[3] = x[3].mul_add(w[3], acc[3]);
    }
    let mut s = (acc[0] + acc[1]) + (acc[2] + acc[3]);
    for (x, w) in x_tail.iter().zip(w_tail) {
        s = x.mul_add(*w, s);
    }
    s
}

/// Linear WMA over the ring buffer read oldest-first from `head`
/// (oldest weight 1, newest weight `ring.len()`), divided by `wma_sum`.
#[inline(always)]
fn cora_wave_ring_wma(ring: &[f64], head: usize, wma_sum: f64) -> f64 {
    let m = ring.len();
    let mut acc = 0.0;
    for k in 0..m {
        acc += ring[(head + k) % m] * ((k + 1) as f64);
    }
    acc / wma_sum
}

/// Single-weight case (`period == 1`): the raw wave is `x * scale`, where
/// `scale = weights[0] * inv_wsum`, optionally smoothed by an `m`-long WMA.
#[inline]
fn cora_wave_single_weight_into(
    data: &[f64],
    scale: f64,
    m: usize,
    first_val: usize,
    out: &mut [f64],
) {
    let n = data.len();
    if first_val >= n {
        return;
    }
    if m == 1 {
        for (dst, &x) in out[first_val..n].iter_mut().zip(&data[first_val..]) {
            *dst = x * scale;
        }
        return;
    }

    let warm_total = first_val.saturating_add(m - 1);
    if warm_total >= n {
        return;
    }
    let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;
    let mut ring = vec![0.0f64; m];
    let mut head = 0usize;
    for (i, &x) in data.iter().enumerate().skip(first_val) {
        ring[head] = x * scale;
        head = (head + 1) % m;
        if i >= warm_total {
            out[i] = cora_wave_ring_wma(&ring, head, wma_sum);
        }
    }
}
647
648#[inline]
649pub fn cora_wave_scalar(
650 data: &[f64],
651 period: usize,
652 r_multi: f64,
653 smooth_period: usize,
654 first_val: usize,
655 out: &mut [f64],
656) {
657 if period == 1 {
658 cora_wave_scalar_with_weights(data, &[1.0], 1.0, smooth_period, first_val, out);
659 return;
660 }
661 let start_wt = 0.01;
662 let end_wt = period as f64;
663 let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
664 let base = 1.0 + r * r_multi;
665
666 let mut weights = Vec::with_capacity(period);
667 let mut weight_sum = 0.0;
668 for j in 0..period {
669 let w = start_wt * base.powi((j + 1) as i32);
670 weights.push(w);
671 weight_sum += w;
672 }
673
674 cora_wave_scalar_with_weights(
675 data,
676 &weights,
677 1.0 / weight_sum,
678 smooth_period,
679 first_val,
680 out,
681 );
682}
683
/// WebAssembly SIMD128 entry point. There is no vectorised implementation yet,
/// so this forwards to `cora_wave_scalar`. The previously imported
/// `core::arch::wasm32` intrinsics were never used.
///
/// # Safety
/// Declared `unsafe` for parity with the other SIMD entry points. The body only
/// calls the safe `cora_wave_scalar`.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn cora_wave_simd128(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out);
}
698
/// AVX2/FMA entry point. There is no AVX2-specific path yet, so the scalar routine
/// runs unchanged.
///
/// # Safety
/// Compiled with `avx2,fma` enabled; the caller must ensure the CPU supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn cora_wave_avx2(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out)
}
711
/// AVX-512/FMA entry point. There is no AVX-512-specific path yet, so the scalar
/// routine runs unchanged.
///
/// # Safety
/// Compiled with `avx512f,fma` enabled; the caller must ensure the CPU supports them.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn cora_wave_avx512(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth_period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    cora_wave_scalar(data, period, r_multi, smooth_period, first_val, out)
}
724
/// Incremental CoRa Wave calculator: one `update` call per new sample.
#[derive(Debug, Clone)]
pub struct CoraWaveStream {
    // --- configuration ---
    /// Window length.
    period: usize,
    /// Weight growth multiplier.
    r_multi: f64,
    /// Whether smoothing was requested.
    smooth: bool,
    /// Smoothing WMA length (same value as `m`).
    smooth_period: usize,

    // --- geometric-weight recurrence coefficients ---
    /// Ratio between consecutive weights.
    base: f64,
    /// `1 / base`.
    inv_R: f64,
    /// Coefficient of the sample leaving the window.
    a_old: f64,
    /// Weight of the newest sample.
    w_last: f64,
    /// `1 / sum(weights)`.
    inv_wsum: f64,

    // --- input window ---
    /// Last `period` inputs (circular).
    ring_x: Vec<f64>,
    /// Next write slot in `ring_x`.
    head_x: usize,
    /// Number of samples seen.
    idx: usize,
    /// Whether `S` holds a full-window sum yet.
    have_S: bool,
    /// Running weighted sum of the window.
    S: f64,

    // --- smoothing state ---
    /// Smoothing WMA length.
    m: usize,
    /// `m * (m + 1) / 2`, the WMA weight total.
    wma_sum: f64,
    /// Recent raw wave values (circular).
    ring_y: Vec<f64>,
    /// Next write slot / oldest value in `ring_y`.
    head_y: usize,
    /// Raw values collected so far (saturates at `m`).
    y_count: usize,

    /// Plain running sum of `ring_y` (O(1) WMA path only).
    Ssum_y: f64,
    /// Running linear-weighted sum of `ring_y` (O(1) WMA path only).
    Wsum_y: f64,

    /// Selects the O(1) running-sum WMA instead of the O(m) recomputation.
    fast_smooth: bool,
}
755
756impl CoraWaveStream {
757 #[inline]
758 pub fn try_new(params: CoraWaveParams) -> Result<Self, CoraWaveError> {
759 let period = params.period.unwrap_or(20);
760 let r_multi = params.r_multi.unwrap_or(2.0);
761 let smooth = params.smooth.unwrap_or(true);
762
763 if period == 0 {
764 return Err(CoraWaveError::InvalidPeriod {
765 period,
766 data_len: 0,
767 });
768 }
769 if r_multi < 0.0 || !r_multi.is_finite() {
770 return Err(CoraWaveError::InvalidRMulti { value: r_multi });
771 }
772
773 let m = if smooth {
774 ((period as f64).sqrt().round() as usize).max(1)
775 } else {
776 1
777 };
778
779 let p = period;
780 let start_wt = 0.01_f64;
781
782 let end_wt = p as f64;
783 let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
784 let base = 1.0 + r * r_multi;
785 let inv_R = 1.0 / base;
786
787 let a_old = start_wt;
788
789 let base_pow_p = if (base - 1.0).abs() < 1e-16 {
790 1.0
791 } else {
792 base.powi(p as i32)
793 };
794 let w_last = a_old * base_pow_p;
795
796 let weight_sum = if (base - 1.0).abs() < 1e-16 {
797 a_old * (p as f64)
798 } else {
799 a_old * base * (base_pow_p - 1.0) / (base - 1.0)
800 };
801 let inv_wsum = 1.0 / weight_sum;
802
803 let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;
804
805 const FAST_WMA_O1_DEFAULT: bool = false;
806
807 Ok(Self {
808 period: p,
809 r_multi,
810 smooth,
811 smooth_period: m,
812 base,
813 inv_R,
814 a_old,
815 w_last,
816 inv_wsum,
817 ring_x: vec![0.0; p],
818 head_x: 0,
819 idx: 0,
820 have_S: false,
821 S: 0.0,
822 m,
823 wma_sum,
824 ring_y: vec![0.0; m.max(1)],
825 head_y: 0,
826 y_count: 0,
827 Ssum_y: 0.0,
828 Wsum_y: 0.0,
829 fast_smooth: FAST_WMA_O1_DEFAULT,
830 })
831 }
832
833 #[inline]
834 pub fn update(&mut self, x_new: f64) -> Option<f64> {
835 let pos = self.head_x;
836 let x_old = self.ring_x[pos];
837 self.ring_x[pos] = x_new;
838 self.head_x = (pos + 1) % self.period;
839 self.idx += 1;
840
841 if !self.have_S {
842 if self.idx < self.period {
843 return None;
844 }
845
846 let mut S = 0.0;
847 let mut w = self.a_old * self.base;
848 let mut i = 0usize;
849 while i < self.period {
850 let xi = self.ring_x[(self.head_x + i) % self.period];
851 S = xi.mul_add(w, S);
852 w *= self.base;
853 i += 1;
854 }
855 self.S = S;
856 self.have_S = true;
857
858 let y = self.S * self.inv_wsum;
859 if self.m == 1 {
860 return Some(y);
861 }
862
863 self.ring_y[self.y_count] = y;
864 self.y_count += 1;
865
866 self.head_y = self.y_count % self.m;
867 if self.fast_smooth {
868 self.Ssum_y += y;
869 self.Wsum_y += (self.y_count as f64) * y;
870 }
871 return None;
872 }
873
874 self.S = (self.S * self.inv_R) - self.a_old * x_old + self.w_last * x_new;
875 let y = self.S * self.inv_wsum;
876
877 if self.m == 1 {
878 return Some(y);
879 }
880
881 if !self.fast_smooth {
882 if self.y_count < self.m {
883 self.ring_y[self.head_y] = y;
884 self.head_y = (self.head_y + 1) % self.m;
885 self.y_count += 1;
886 if self.y_count < self.m {
887 return None;
888 }
889
890 let mut acc = 0.0;
891 let mut k = 0usize;
892 while k < self.m {
893 let idx = (self.head_y + k) % self.m;
894 let v = self.ring_y[idx];
895 acc = v.mul_add((k + 1) as f64, acc);
896 k += 1;
897 }
898 return Some(acc / self.wma_sum);
899 } else {
900 self.ring_y[self.head_y] = y;
901 self.head_y = (self.head_y + 1) % self.m;
902 let mut acc = 0.0;
903 let mut k = 0usize;
904 while k < self.m {
905 let idx = (self.head_y + k) % self.m;
906 let v = self.ring_y[idx];
907 acc = v.mul_add((k + 1) as f64, acc);
908 k += 1;
909 }
910 return Some(acc / self.wma_sum);
911 }
912 } else {
913 if self.y_count < self.m {
914 self.ring_y[self.y_count] = y;
915 self.y_count += 1;
916 self.Ssum_y += y;
917 self.Wsum_y += (self.y_count as f64) * y;
918 if self.y_count < self.m {
919 return None;
920 }
921
922 self.head_y = 0;
923 return Some(self.Wsum_y / self.wma_sum);
924 }
925
926 let y_old = self.ring_y[self.head_y];
927
928 self.Wsum_y = self.Wsum_y - self.Ssum_y + (self.m as f64) * y;
929 self.ring_y[self.head_y] = y;
930 self.Ssum_y = self.Ssum_y + y - y_old;
931 self.head_y = (self.head_y + 1) % self.m;
932
933 Some(self.Wsum_y / self.wma_sum)
934 }
935 }
936}
937
/// Parameter sweep for the batch API. Each axis is `(start, end, step)`; a zero
/// step (or `start == end`) yields only `start`.
#[derive(Clone, Debug)]
pub struct CoraWaveBatchRange {
    /// Window lengths to evaluate.
    pub period: (usize, usize, usize),
    /// Weight growth multipliers to evaluate.
    pub r_multi: (f64, f64, f64),
    /// Smoothing switch shared by every combination.
    pub smooth: bool,
}
944
945impl Default for CoraWaveBatchRange {
946 fn default() -> Self {
947 Self {
948 period: (20, 20, 0),
949 r_multi: (2.0, 2.249, 0.001),
950 smooth: true,
951 }
952 }
953}
954
955#[derive(Clone, Debug)]
956pub struct CoraWaveBatchOutput {
957 pub values: Vec<f64>,
958 pub combos: Vec<CoraWaveParams>,
959 pub rows: usize,
960 pub cols: usize,
961}
962
963impl CoraWaveBatchOutput {
964 pub fn row_for_params(&self, p: &CoraWaveParams) -> Option<usize> {
965 self.combos.iter().position(|c| {
966 c.period.unwrap_or(20) == p.period.unwrap_or(20)
967 && (c.r_multi.unwrap_or(2.0) - p.r_multi.unwrap_or(2.0)).abs() < 1e-12
968 && c.smooth.unwrap_or(true) == p.smooth.unwrap_or(true)
969 })
970 }
971
972 pub fn values_for(&self, p: &CoraWaveParams) -> Option<&[f64]> {
973 self.row_for_params(p).map(|row| {
974 let start = row * self.cols;
975 &self.values[start..start + self.cols]
976 })
977 }
978}
979
980#[inline(always)]
981fn axis_usize((s, e, t): (usize, usize, usize)) -> Result<Vec<usize>, CoraWaveError> {
982 if t == 0 || s == e {
983 return Ok(vec![s]);
984 }
985 let mut v = Vec::new();
986 if s < e {
987 v = (s..=e).step_by(t).collect();
988 } else if s > e {
989 let mut x = s;
990 loop {
991 v.push(x);
992 if x <= e {
993 break;
994 }
995 if x < t {
996 break;
997 }
998 let next = x - t;
999 if next < e {
1000 break;
1001 }
1002 x = next;
1003 }
1004
1005 if *v.last().unwrap_or(&s) != e && s >= e {}
1006 }
1007 if v.is_empty() {
1008 return Err(CoraWaveError::InvalidRange {
1009 start: s as f64,
1010 end: e as f64,
1011 step: t as f64,
1012 });
1013 }
1014 Ok(v)
1015}
1016#[inline(always)]
1017fn axis_f64((s, e, t): (f64, f64, f64)) -> Result<Vec<f64>, CoraWaveError> {
1018 if t.abs() < 1e-12 || (s - e).abs() < 1e-12 {
1019 return Ok(vec![s]);
1020 }
1021 let step = t.abs();
1022 let mut v = Vec::new();
1023 if s <= e {
1024 let mut x = s;
1025 while x <= e + 1e-12 {
1026 v.push(x);
1027 x += step;
1028 }
1029 } else {
1030 let mut x = s;
1031 while x >= e - 1e-12 {
1032 v.push(x);
1033 x -= step;
1034 }
1035 }
1036 if v.is_empty() {
1037 return Err(CoraWaveError::InvalidRange {
1038 start: s,
1039 end: e,
1040 step: t,
1041 });
1042 }
1043 Ok(v)
1044}
1045#[inline(always)]
1046fn expand_grid_cw(r: &CoraWaveBatchRange) -> Result<Vec<CoraWaveParams>, CoraWaveError> {
1047 let periods = axis_usize(r.period)?;
1048 let mults = axis_f64(r.r_multi)?;
1049 if periods.is_empty() || mults.is_empty() {
1050 return Err(CoraWaveError::InvalidRange {
1051 start: r.period.0 as f64,
1052 end: r.period.1 as f64,
1053 step: r.period.2 as f64,
1054 });
1055 }
1056 let cap = periods
1057 .len()
1058 .checked_mul(mults.len())
1059 .ok_or_else(|| CoraWaveError::InvalidInput("periods*mults overflow".into()))?;
1060 let mut out = Vec::with_capacity(cap);
1061 for &p in &periods {
1062 for &m in &mults {
1063 out.push(CoraWaveParams {
1064 period: Some(p),
1065 r_multi: Some(m),
1066 smooth: Some(r.smooth),
1067 });
1068 }
1069 }
1070 Ok(out)
1071}
1072
1073#[inline(always)]
1074pub fn cora_wave_batch_slice(
1075 data: &[f64],
1076 sweep: &CoraWaveBatchRange,
1077 kern: Kernel,
1078) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1079 cora_wave_batch_inner(data, sweep, kern, false)
1080}
1081
1082#[inline(always)]
1083pub fn cora_wave_batch_par_slice(
1084 data: &[f64],
1085 sweep: &CoraWaveBatchRange,
1086 kern: Kernel,
1087) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1088 cora_wave_batch_inner(data, sweep, kern, true)
1089}
1090
1091pub fn cora_wave_batch_with_kernel(
1092 data: &[f64],
1093 sweep: &CoraWaveBatchRange,
1094 k: Kernel,
1095) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1096 let kernel = match k {
1097 Kernel::Auto => detect_best_batch_kernel(),
1098 other if other.is_batch() => other,
1099 _ => return Err(CoraWaveError::InvalidKernelForBatch(k)),
1100 };
1101 let simd = match kernel {
1102 Kernel::Avx512Batch => Kernel::Avx512,
1103 Kernel::Avx2Batch => Kernel::Avx2,
1104 Kernel::ScalarBatch => Kernel::Scalar,
1105 _ => unreachable!(),
1106 };
1107 cora_wave_batch_par_slice(data, sweep, simd)
1108}
1109
1110#[inline(always)]
1111fn cora_wave_batch_inner(
1112 data: &[f64],
1113 sweep: &CoraWaveBatchRange,
1114 kern: Kernel,
1115 parallel: bool,
1116) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1117 let combos = expand_grid_cw(sweep)?;
1118 if combos.is_empty() {
1119 return Err(CoraWaveError::InvalidRange {
1120 start: sweep.period.0 as f64,
1121 end: sweep.period.1 as f64,
1122 step: sweep.period.2 as f64,
1123 });
1124 }
1125
1126 let cols = data.len();
1127 if cols == 0 {
1128 return Err(CoraWaveError::AllValuesNaN);
1129 }
1130
1131 let first = data
1132 .iter()
1133 .position(|x| !x.is_nan())
1134 .ok_or(CoraWaveError::AllValuesNaN)?;
1135 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1136 if cols - first < max_p {
1137 return Err(CoraWaveError::NotEnoughValidData {
1138 needed: max_p,
1139 valid: cols - first,
1140 });
1141 }
1142
1143 let rows = combos.len();
1144 let _total = rows
1145 .checked_mul(cols)
1146 .ok_or_else(|| CoraWaveError::InvalidInput("rows*cols overflow".into()))?;
1147 let mut buf_mu = make_uninit_matrix(rows, cols);
1148
1149 let warms: Vec<usize> = combos
1150 .iter()
1151 .map(|c| {
1152 let p = c.period.unwrap();
1153 let sp = if c.smooth.unwrap_or(true) {
1154 ((p as f64).sqrt().round() as usize).max(1)
1155 } else {
1156 1
1157 };
1158 first + p - 1 + sp.saturating_sub(1)
1159 })
1160 .collect();
1161 init_matrix_prefixes(&mut buf_mu, cols, &warms);
1162
1163 let flat_len = rows
1164 .checked_mul(max_p)
1165 .ok_or_else(|| CoraWaveError::InvalidInput("rows*max_period overflow".into()))?;
1166 let mut flat_w = vec![0.0f64; flat_len];
1167 let mut inv_sums = vec![0.0f64; rows];
1168
1169 for (row, prm) in combos.iter().enumerate() {
1170 let p = prm.period.unwrap();
1171 let r_multi = prm.r_multi.unwrap();
1172
1173 if p == 1 {
1174 flat_w[row * max_p] = 1.0;
1175 inv_sums[row] = 1.0;
1176 } else {
1177 let start_wt = 0.01;
1178 let end_wt = p as f64;
1179 let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
1180 let base = 1.0 + r * r_multi;
1181
1182 let mut sum = 0.0;
1183 for j in 0..p {
1184 let w = start_wt * base.powi((j + 1) as i32);
1185 flat_w[row * max_p + j] = w;
1186 sum += w;
1187 }
1188 inv_sums[row] = 1.0 / sum;
1189 }
1190 }
1191
1192 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1193 let out_uninit: &mut [MaybeUninit<f64>] =
1194 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr(), guard.len()) };
1195
1196 let actual = match kern {
1197 Kernel::Auto => detect_best_batch_kernel(),
1198 k => k,
1199 };
1200
1201 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1202 let p = combos[row].period.unwrap();
1203 let sp = if combos[row].smooth.unwrap_or(true) {
1204 ((p as f64).sqrt().round() as usize).max(1)
1205 } else {
1206 1
1207 };
1208 let w_ptr = flat_w[row * max_p..].as_ptr();
1209 let inv = inv_sums[row];
1210
1211 let dst = unsafe { core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols) };
1212
1213 match actual {
1214 Kernel::Scalar
1215 | Kernel::ScalarBatch
1216 | Kernel::Avx2
1217 | Kernel::Avx2Batch
1218 | Kernel::Avx512
1219 | Kernel::Avx512Batch => unsafe {
1220 cora_wave_row_scalar_with_weights(data, first, p, w_ptr, inv, sp, dst)
1221 },
1222 _ => unreachable!(),
1223 }
1224 };
1225
1226 #[cfg(not(target_arch = "wasm32"))]
1227 {
1228 use rayon::prelude::*;
1229 if parallel {
1230 out_uninit
1231 .par_chunks_mut(cols)
1232 .enumerate()
1233 .for_each(|(row, slice)| do_row(row, slice));
1234 } else {
1235 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1236 do_row(row, slice);
1237 }
1238 }
1239 }
1240 #[cfg(target_arch = "wasm32")]
1241 {
1242 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
1243 do_row(row, slice);
1244 }
1245 }
1246
1247 let values = unsafe {
1248 Vec::from_raw_parts(
1249 guard.as_mut_ptr() as *mut f64,
1250 guard.len(),
1251 guard.capacity(),
1252 )
1253 };
1254
1255 Ok(CoraWaveBatchOutput {
1256 values,
1257 combos,
1258 rows,
1259 cols,
1260 })
1261}
1262
1263#[inline(always)]
1264pub fn cora_wave_batch_inner_into(
1265 data: &[f64],
1266 sweep: &CoraWaveBatchRange,
1267 kern: Kernel,
1268 parallel: bool,
1269 out: &mut [f64],
1270) -> Result<Vec<CoraWaveParams>, CoraWaveError> {
1271 let combos = expand_grid_cw(sweep)?;
1272 if combos.is_empty() {
1273 return Err(CoraWaveError::InvalidRange {
1274 start: sweep.period.0 as f64,
1275 end: sweep.period.1 as f64,
1276 step: sweep.period.2 as f64,
1277 });
1278 }
1279
1280 let cols = data.len();
1281 let rows = combos.len();
1282 if cols == 0 {
1283 return Err(CoraWaveError::AllValuesNaN);
1284 }
1285 let expected = rows
1286 .checked_mul(cols)
1287 .ok_or_else(|| CoraWaveError::InvalidInput("rows*cols overflow".into()))?;
1288 if out.len() != expected {
1289 return Err(CoraWaveError::OutputLengthMismatch {
1290 expected,
1291 got: out.len(),
1292 });
1293 }
1294
1295 let first = data
1296 .iter()
1297 .position(|x| !x.is_nan())
1298 .ok_or(CoraWaveError::AllValuesNaN)?;
1299 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1300 if cols - first < max_p {
1301 return Err(CoraWaveError::NotEnoughValidData {
1302 needed: max_p,
1303 valid: cols - first,
1304 });
1305 }
1306
1307 let warms: Vec<usize> = combos
1308 .iter()
1309 .map(|c| {
1310 let p = c.period.unwrap();
1311 let sp = if c.smooth.unwrap_or(true) {
1312 ((p as f64).sqrt().round() as usize).max(1)
1313 } else {
1314 1
1315 };
1316 first + p - 1 + sp.saturating_sub(1)
1317 })
1318 .collect();
1319
1320 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1321 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1322 };
1323 init_matrix_prefixes(out_mu, cols, &warms);
1324
1325 let flat_len = rows
1326 .checked_mul(max_p)
1327 .ok_or_else(|| CoraWaveError::InvalidInput("rows*max_period overflow".into()))?;
1328 let mut flat_w = vec![0.0f64; flat_len];
1329 let mut inv_sums = vec![0.0f64; rows];
1330 for (row, prm) in combos.iter().enumerate() {
1331 let p = prm.period.unwrap();
1332 let r_multi = prm.r_multi.unwrap();
1333
1334 if p == 1 {
1335 flat_w[row * max_p] = 1.0;
1336 inv_sums[row] = 1.0;
1337 } else {
1338 let start_wt = 0.01;
1339 let end_wt = p as f64;
1340 let r = (end_wt / start_wt).powf(1.0 / (p as f64 - 1.0)) - 1.0;
1341 let base = 1.0 + r * r_multi;
1342
1343 let mut sum = 0.0;
1344 for j in 0..p {
1345 let w = start_wt * base.powi((j + 1) as i32);
1346 flat_w[row * max_p + j] = w;
1347 sum += w;
1348 }
1349 inv_sums[row] = 1.0 / sum;
1350 }
1351 }
1352
1353 let actual = match kern {
1354 Kernel::Auto => detect_best_batch_kernel(),
1355 k => k,
1356 };
1357
1358 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1359 let p = combos[row].period.unwrap();
1360 let sp = if combos[row].smooth.unwrap_or(true) {
1361 ((p as f64).sqrt().round() as usize).max(1)
1362 } else {
1363 1
1364 };
1365 let w_ptr = flat_w[row * max_p..].as_ptr();
1366 let inv = inv_sums[row];
1367
1368 let dst: &mut [f64] =
1369 unsafe { core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, cols) };
1370 match actual {
1371 Kernel::Scalar
1372 | Kernel::ScalarBatch
1373 | Kernel::Avx2
1374 | Kernel::Avx2Batch
1375 | Kernel::Avx512
1376 | Kernel::Avx512Batch => unsafe {
1377 cora_wave_row_scalar_with_weights(data, first, p, w_ptr, inv, sp, dst)
1378 },
1379 _ => unreachable!(),
1380 }
1381 };
1382
1383 #[cfg(not(target_arch = "wasm32"))]
1384 {
1385 if parallel {
1386 use rayon::prelude::*;
1387 out_mu
1388 .par_chunks_mut(cols)
1389 .enumerate()
1390 .for_each(|(row, slice)| do_row(row, slice));
1391 } else {
1392 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1393 do_row(row, slice);
1394 }
1395 }
1396 }
1397 #[cfg(target_arch = "wasm32")]
1398 {
1399 for (row, slice) in out_mu.chunks_mut(cols).enumerate() {
1400 do_row(row, slice);
1401 }
1402 }
1403
1404 Ok(combos)
1405}
1406
/// Computes one CoRa Wave output row for a single parameter combination.
///
/// The raw value at index `i` is `inv_wsum * sum_j w[j] * data[i - period + 1 + j]` for
/// `j in 0..period`. The first window is evaluated directly (4-way unrolled FMA). Every
/// later window is updated in O(1) with
/// `S' = S * (w[0] / w[1]) - (w[0] * w[0] / w[1]) * x_old + w[period - 1] * x_new`.
/// That update is exact only when the weights are geometric (`w[j + 1] / w[j]`
/// constant), which is how the batch weight builder in this file constructs them.
///
/// When `smooth_period > 1`, the raw values are additionally passed through a linearly
/// weighted moving average of length `m = smooth_period`. The oldest value in the ring
/// gets weight 1 and the newest gets weight `m`, normalised by `m * (m + 1) / 2`.
///
/// Only indices from the first computed position onward are written. That position is
/// `first + period - 1`, plus `smooth_period - 1` when smoothing. Earlier entries of
/// `out` are left untouched, so the caller must pre-fill the warm-up prefix.
///
/// NOTE(review): the window sum is carried forward by the recurrence. A NaN that enters
/// the window after `first` therefore makes every later output of the row NaN.
///
/// # Safety
///
/// * `out.len() >= data.len()`: outputs are written with `get_unchecked_mut` up to
///   index `data.len() - 1`.
/// * When `period >= 2`, `w_ptr` must be valid for reading `period` consecutive `f64`s.
///   It is not read when `period == 1`.
/// * Assumes `make_uninit_matrix(1, m)` yields at least `m` slots (TODO confirm).
#[inline(always)]
unsafe fn cora_wave_row_scalar_with_weights(
    data: &[f64],
    first: usize,
    period: usize,
    w_ptr: *const f64,
    inv_wsum: f64,
    smooth_period: usize,
    out: &mut [f64],
) {
    let n = data.len();
    let p = period;
    if p == 0 || n == 0 {
        return;
    }

    // ---- No smoothing: output is the normalised weighted window sum itself. ----
    if smooth_period == 1 {
        if p == 1 {
            // Single-weight window: just scale each value from the first valid index.
            let warm0 = first;
            let mut i = warm0;
            while i < n {
                *out.get_unchecked_mut(i) = *data.get_unchecked(i) * inv_wsum;
                i += 1;
            }
            return;
        }

        let w0 = *w_ptr.add(0);
        let w1 = *w_ptr.add(1);
        // 1 / R for the geometric ratio R = w[1] / w[0].
        let inv_R = w0 / w1;
        // Weight the outgoing sample carries after the window is divided by R.
        let a_old = w0 * inv_R;
        let w_last = *w_ptr.add(p - 1);

        let warm0 = first + p - 1;
        if warm0 >= n {
            return;
        }
        let start0 = warm0 + 1 - p;

        // First window: direct dot product with four independent accumulators.
        let mut acc0 = 0.0;
        let mut acc1 = 0.0;
        let mut acc2 = 0.0;
        let mut acc3 = 0.0;
        let mut j = 0usize;
        let end4 = p & !3usize;
        let xptr = data.as_ptr().add(start0);

        while j < end4 {
            let x0 = *xptr.add(j);
            let x1 = *xptr.add(j + 1);
            let x2 = *xptr.add(j + 2);
            let x3 = *xptr.add(j + 3);

            let y0 = *w_ptr.add(j);
            let y1 = *w_ptr.add(j + 1);
            let y2 = *w_ptr.add(j + 2);
            let y3 = *w_ptr.add(j + 3);

            acc0 = x0.mul_add(y0, acc0);
            acc1 = x1.mul_add(y1, acc1);
            acc2 = x2.mul_add(y2, acc2);
            acc3 = x3.mul_add(y3, acc3);

            j += 4;
        }
        let mut S = (acc0 + acc1) + (acc2 + acc3);
        // Remainder when `p` is not a multiple of 4.
        while j < p {
            let x = *xptr.add(j);
            let y = *w_ptr.add(j);
            S = x.mul_add(y, S);
            j += 1;
        }

        *out.get_unchecked_mut(warm0) = S * inv_wsum;

        // Slide the window one sample at a time with the O(1) geometric recurrence.
        let mut i = warm0;
        while i + 1 < n {
            let x_old = *data.get_unchecked(i + 1 - p);
            let x_new = *data.get_unchecked(i + 1);
            S = (S * inv_R) - a_old * x_old + w_last * x_new;
            *out.get_unchecked_mut(i + 1) = S * inv_wsum;
            i += 1;
        }
        return;
    }

    // ---- Smoothing: raw values feed a length-m WMA kept in a ring buffer. ----
    let m = smooth_period;
    // Sum of the WMA weights 1 + 2 + ... + m.
    let wma_sum = (m as f64) * ((m as f64) + 1.0) * 0.5;

    if p == 1 {
        // Raw values are the input samples themselves.
        // NOTE(review): `inv_wsum` is not applied on this path; the visible callers pass
        // 1.0 for period 1, so the result is the same.
        let warm0 = first;
        if warm0 >= n {
            return;
        }

        let mut ring_mu: Vec<MaybeUninit<f64>> = make_uninit_matrix(1, m);
        let mut fill = 0usize;

        // Fill the ring with the first m samples (slot 0 = oldest).
        let warm_total = warm0 + m - 1;
        let mut i = warm0;
        while i <= warm_total && i < n {
            ring_mu
                .get_unchecked_mut(fill)
                .write(*data.get_unchecked(i));
            fill += 1;
            i += 1;
        }
        if warm_total >= n {
            return;
        }

        // Ssum: plain sum of the ring; Wsum: weighted sum (oldest weight 1 .. newest m).
        let mut Ssum = 0.0;
        let mut Wsum = 0.0;
        for k in 0..m {
            let v = *ring_mu.get_unchecked(k).assume_init_ref();
            Ssum += v;
            Wsum += v * ((k + 1) as f64);
        }
        let mut head = 0usize;
        let mut t = warm_total;
        *out.get_unchecked_mut(t) = Wsum / wma_sum;

        while t + 1 < n {
            let y_old = *ring_mu.get_unchecked(head).assume_init_ref();
            let y_new = *data.get_unchecked(t + 1);

            // Every retained sample loses one unit of weight (subtract the old plain sum,
            // which also drops `y_old` completely); the new sample enters with weight m.
            Wsum = Wsum - Ssum + (m as f64) * y_new;

            ring_mu.get_unchecked_mut(head).write(y_new);
            Ssum = Ssum + y_new - y_old;
            // `head` always indexes the oldest slot.
            head = (head + 1) % m;

            *out.get_unchecked_mut(t + 1) = Wsum / wma_sum;
            t += 1;
        }
        return;
    }

    // General case: geometric window (as above) followed by the WMA smoother.
    let w0 = *w_ptr.add(0);
    let w1 = *w_ptr.add(1);
    // 1 / R for the geometric ratio R = w[1] / w[0].
    let inv_R = w0 / w1;
    // Weight the outgoing sample carries after the window is divided by R.
    let a_old = w0 * inv_R;
    let w_last = *w_ptr.add(p - 1);

    let warm0 = first + p - 1;
    if warm0 >= n {
        return;
    }
    let start0 = warm0 + 1 - p;

    // First window: direct dot product with four independent accumulators.
    let mut acc0 = 0.0;
    let mut acc1 = 0.0;
    let mut acc2 = 0.0;
    let mut acc3 = 0.0;
    let mut j = 0usize;
    let end4 = p & !3usize;
    let xptr = data.as_ptr().add(start0);
    while j < end4 {
        let x0 = *xptr.add(j);
        let x1 = *xptr.add(j + 1);
        let x2 = *xptr.add(j + 2);
        let x3 = *xptr.add(j + 3);

        let y0 = *w_ptr.add(j);
        let y1 = *w_ptr.add(j + 1);
        let y2 = *w_ptr.add(j + 2);
        let y3 = *w_ptr.add(j + 3);

        acc0 = x0.mul_add(y0, acc0);
        acc1 = x1.mul_add(y1, acc1);
        acc2 = x2.mul_add(y2, acc2);
        acc3 = x3.mul_add(y3, acc3);

        j += 4;
    }
    let mut S = (acc0 + acc1) + (acc2 + acc3);
    while j < p {
        let x = *xptr.add(j);
        let y = *w_ptr.add(j);
        S = x.mul_add(y, S);
        j += 1;
    }

    let mut ring_mu: Vec<MaybeUninit<f64>> = make_uninit_matrix(1, m);
    let mut fill = 0usize;

    // Seed the ring with the first raw value (slot 0 = oldest).
    let mut y = S * inv_wsum;
    ring_mu.get_unchecked_mut(fill).write(y);
    fill += 1;

    // Produce the next m - 1 raw values to fill the ring before emitting output.
    let warm_total = warm0 + m - 1;
    let mut i = warm0;
    while i + 1 <= warm_total && i + 1 < n {
        let x_old = *data.get_unchecked(i + 1 - p);
        let x_new = *data.get_unchecked(i + 1);
        S = (S * inv_R) - a_old * x_old + w_last * x_new;
        y = S * inv_wsum;
        ring_mu.get_unchecked_mut(fill).write(y);
        fill += 1;
        i += 1;
    }
    if warm_total >= n {
        return;
    }

    // Ssum: plain sum of the ring; Wsum: weighted sum (oldest weight 1 .. newest m).
    let mut Ssum = 0.0;
    let mut Wsum = 0.0;
    for k in 0..m {
        let v = *ring_mu.get_unchecked(k).assume_init_ref();
        Ssum += v;
        Wsum += v * ((k + 1) as f64);
    }
    let mut head = 0usize;
    *out.get_unchecked_mut(warm_total) = Wsum / wma_sum;

    // Here `i == warm_total`; continue sliding both the raw window and the WMA.
    while i + 1 < n {
        let x_old = *data.get_unchecked(i + 1 - p);
        let x_new = *data.get_unchecked(i + 1);
        S = (S * inv_R) - a_old * x_old + w_last * x_new;
        let y_new = S * inv_wsum;

        // Shift WMA weights down by one and add the new raw value with weight m.
        Wsum = Wsum - Ssum + (m as f64) * y_new;

        let y_old = *ring_mu.get_unchecked(head).assume_init_ref();
        ring_mu.get_unchecked_mut(head).write(y_new);
        Ssum = Ssum + y_new - y_old;
        head = (head + 1) % m;

        *out.get_unchecked_mut(i + 1) = Wsum / wma_sum;
        i += 1;
    }
}
1639
1640#[derive(Clone, Debug, Default)]
1641pub struct CoraWaveBatchBuilder {
1642 range: CoraWaveBatchRange,
1643 kernel: Kernel,
1644}
1645
1646impl CoraWaveBatchBuilder {
1647 pub fn new() -> Self {
1648 Self::default()
1649 }
1650
1651 pub fn kernel(mut self, k: Kernel) -> Self {
1652 self.kernel = k;
1653 self
1654 }
1655
1656 #[inline]
1657 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1658 self.range.period = (start, end, step);
1659 self
1660 }
1661
1662 #[inline]
1663 pub fn period_static(mut self, val: usize) -> Self {
1664 self.range.period = (val, val, 0);
1665 self
1666 }
1667
1668 #[inline]
1669 pub fn r_multi_range(mut self, start: f64, end: f64, step: f64) -> Self {
1670 self.range.r_multi = (start, end, step);
1671 self
1672 }
1673
1674 #[inline]
1675 pub fn r_multi_static(mut self, val: f64) -> Self {
1676 self.range.r_multi = (val, val, 0.0);
1677 self
1678 }
1679
1680 #[inline]
1681 pub fn smooth(mut self, val: bool) -> Self {
1682 self.range.smooth = val;
1683 self
1684 }
1685
1686 pub fn apply_slice(self, data: &[f64]) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1687 cora_wave_batch_with_kernel(data, &self.range, self.kernel)
1688 }
1689
1690 pub fn apply_candles(
1691 self,
1692 c: &Candles,
1693 src: &str,
1694 ) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1695 let slice = source_type(c, src);
1696 self.apply_slice(slice)
1697 }
1698
1699 pub fn with_default_candles(c: &Candles) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1700 CoraWaveBatchBuilder::new()
1701 .kernel(Kernel::Auto)
1702 .apply_candles(c, "close")
1703 }
1704
1705 pub fn with_default_slice(
1706 data: &[f64],
1707 k: Kernel,
1708 ) -> Result<CoraWaveBatchOutput, CoraWaveError> {
1709 CoraWaveBatchBuilder::new().kernel(k).apply_slice(data)
1710 }
1711}
1712
/// Computes CoRa Wave for a 1-D float64 array.
///
/// The GIL is released during the computation. Raises `ValueError` if the parameters
/// or the data are rejected.
#[cfg(feature = "python")]
#[pyfunction(name = "cora_wave")]
#[pyo3(signature = (data, period, r_multi, smooth, kernel=None))]
pub fn cora_wave_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    r_multi: f64,
    smooth: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = CoraWaveInput::from_slice(
        series,
        CoraWaveParams {
            period: Some(period),
            r_multi: Some(r_multi),
            smooth: Some(smooth),
        },
    );

    let computed = py.allow_threads(|| cora_wave_with_kernel(&input, chosen));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1739
/// Incremental CoRa Wave calculator exposed to Python; values are fed one at a time.
#[cfg(feature = "python")]
#[pyclass(name = "CoraWaveStream")]
pub struct CoraWaveStreamPy {
    stream: CoraWaveStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl CoraWaveStreamPy {
    /// Creates a stream for the given `period`, `r_multi` and `smooth` settings.
    ///
    /// Raises `ValueError` if `CoraWaveStream::try_new` rejects the parameters.
    #[new]
    fn new(period: usize, r_multi: f64, smooth: bool) -> PyResult<Self> {
        CoraWaveStream::try_new(CoraWaveParams {
            period: Some(period),
            r_multi: Some(r_multi),
            smooth: Some(smooth),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns the stream's current output, or `None` when the
    /// underlying stream yields no value.
    fn update(&mut self, value: f64) -> Option<f64> {
        let Self { stream } = self;
        stream.update(value)
    }
}
1765
/// Runs a CoRa Wave parameter sweep over a 1-D float64 array.
///
/// Returns a dict with:
/// * `values`: a `(rows, cols)` array, one row per combination.
/// * `periods` and `r_multis`: the parameters of each row.
/// * `smooth`: the smoothing flag.
///
/// Raises `ValueError` if the sweep is invalid, if `rows * cols` overflows, or if the
/// computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "cora_wave_batch")]
#[pyo3(signature = (data, period_range, r_multi_range, smooth=true, kernel=None))]
pub fn cora_wave_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    r_multi_range: (f64, f64, f64),
    smooth: bool,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let sweep = CoraWaveBatchRange {
        period: period_range,
        r_multi: r_multi_range,
        smooth,
    };

    let combos = expand_grid_cw(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();
    // Checked so an oversized sweep raises instead of panicking/wrapping while sizing
    // the output array.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("cora_wave_batch: rows*cols overflow"))?;

    // SAFETY: the array is created uninitialized. This relies on
    // `cora_wave_batch_inner_into` writing every element (NaN warm-up prefixes plus
    // computed values) before the array is exposed. On error the array is never returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    let kern = validate_kernel(kernel, true)?;
    let combos = py
        .allow_threads(|| {
            let kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Batch kernels map to their per-row counterparts; any other kernel is
            // already a row kernel and is passed through instead of panicking.
            let simd = match kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            cora_wave_batch_inner_into(slice_in, &sweep, simd, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "r_multis",
        combos
            .iter()
            .map(|p| p.r_multi.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("smooth", smooth)?;
    Ok(dict)
}
1830
/// CUDA version of `cora_wave_batch`: runs the parameter sweep on GPU `device_id`.
///
/// Returns the device-resident result together with a metadata dict holding
/// `periods`, `r_multis` and `smooth`. Raises `ValueError` when CUDA is unavailable or
/// the device call fails.
///
/// NOTE(review): the metadata combinations come from the local `combos_for_py`, not
/// from `expand_grid_cw` or from `CudaCoraWave` itself. Confirm it enumerates the same
/// combinations, in the same order, as the device kernel (including the float stepping
/// of `r_multi_range`); otherwise rows and metadata can disagree.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cora_wave_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, r_multi_range=(2.0,2.0,0.0), smooth=true, device_id=0))]
pub fn cora_wave_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    r_multi_range: (f64, f64, f64),
    smooth: bool,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, Bound<'py, PyDict>)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice_in = data_f32.as_slice()?;
    let sweep = CoraWaveBatchRange {
        period: period_range,
        r_multi: r_multi_range,
        smooth,
    };

    // Expands the sweep into (period, r_multi) combinations, period-major, for the
    // returned metadata.
    fn combos_for_py(sweep: &CoraWaveBatchRange) -> Vec<CoraWaveParams> {
        let (ps, pe, pt) = sweep.period;
        let periods: Vec<usize> = if pt == 0 || ps == pe {
            vec![ps]
        } else if ps <= pe {
            (ps..=pe).step_by(pt).collect()
        } else {
            // Descending sweep: step down from `ps` by `pt`, never going below `pe`
            // and never underflowing.
            let mut v = Vec::new();
            let mut x = ps;
            loop {
                v.push(x);
                if x <= pe {
                    break;
                }
                if x < pt {
                    break;
                }
                let next = x - pt;
                if next < pe {
                    break;
                }
                x = next;
            }
            v
        };
        let (ms, me, mt) = sweep.r_multi;
        let mut mults: Vec<f64> = vec![];
        // Direction follows start/end; the sign of the step is ignored. Values are
        // produced by repeated addition, so they may drift by rounding.
        if mt.abs() < 1e-12 || (ms - me).abs() < 1e-12 {
            mults.push(ms);
        } else if ms <= me {
            let mut x = ms;
            let step = mt.abs();
            while x <= me + 1e-12 {
                mults.push(x);
                x += step;
            }
        } else {
            let mut x = ms;
            let step = mt.abs();
            while x >= me - 1e-12 {
                mults.push(x);
                x -= step;
            }
        }
        let mut out = Vec::with_capacity(periods.len().saturating_mul(mults.len()));
        for &p in &periods {
            for &m in &mults {
                out.push(CoraWaveParams {
                    period: Some(p),
                    r_multi: Some(m),
                    smooth: Some(sweep.smooth),
                });
            }
        }
        out
    }

    // Device work runs with the GIL released. The CUDA context handle is kept so the
    // returned device array stays valid.
    let (inner, ctx_arc, dev_id, combos) = py.allow_threads(|| {
        let cuda =
            CudaCoraWave::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let out = cuda
            .cora_wave_batch_dev(slice_in, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<
            (
                DeviceArrayF32,
                std::sync::Arc<cust::context::Context>,
                u32,
                Vec<CoraWaveParams>,
            ),
            PyErr,
        >((out, ctx, dev, combos_for_py(&sweep)))
    })?;

    let dict = PyDict::new(py);
    use numpy::PyArrayMethods;
    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let r_multis: Vec<f64> = combos.iter().map(|c| c.r_multi.unwrap()).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("r_multis", r_multis.into_pyarray(py))?;
    dict.set_item("smooth", smooth)?;
    Ok((
        DeviceArrayF32Py {
            inner,
            _ctx: Some(ctx_arc),
            device_id: Some(dev_id),
        },
        dict,
    ))
}
1944
/// CUDA CoRa Wave over many series that share one parameter set.
///
/// `data_tm_f32` is a flattened matrix of `cols` series by `rows` time steps. It is
/// presumably time-major, as the name says; the exact layout is defined by
/// `CudaCoraWave::cora_wave_multi_series_one_param_time_major_dev` (TODO confirm).
/// Returns the device-resident result. Raises `ValueError` when CUDA is unavailable or
/// the device call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "cora_wave_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, cols, rows, period, r_multi=2.0, smooth=true, device_id=0))]
pub fn cora_wave_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: PyReadonlyArray1<'py, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    r_multi: f64,
    smooth: bool,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = data_tm_f32.as_slice()?;
    let params = CoraWaveParams {
        period: Some(period),
        r_multi: Some(r_multi),
        smooth: Some(smooth),
    };
    // Device work runs with the GIL released. The context handle is returned so the
    // device array stays valid.
    let (inner, ctx_arc, dev_id) = py.allow_threads(|| {
        let cuda =
            CudaCoraWave::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev = cuda.device_id();
        let out = cuda
            .cora_wave_multi_series_one_param_time_major_dev(slice, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<(DeviceArrayF32, std::sync::Arc<cust::context::Context>, u32), PyErr>((out, ctx, dev))
    })?;
    Ok(DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx_arc),
        device_id: Some(dev_id),
    })
}
1983
/// Computes CoRa Wave for `data` and returns a freshly allocated output vector.
///
/// Errors from the computation are returned to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_js(
    data: &[f64],
    period: usize,
    r_multi: f64,
    smooth: bool,
) -> Result<Vec<f64>, JsValue> {
    let input = CoraWaveInput::from_slice(
        data,
        CoraWaveParams {
            period: Some(period),
            r_multi: Some(r_multi),
            smooth: Some(smooth),
        },
    );

    let mut values = vec![0.0; data.len()];
    let status = cora_wave_into_slice(&mut values, &input, detect_best_kernel());
    if let Err(e) = status {
        return Err(JsValue::from_str(&e.to_string()));
    }
    Ok(values)
}
2005
/// Reserves uninitialized space for `len` `f64` values in wasm linear memory and
/// returns a pointer to it.
///
/// The allocation is intentionally leaked until it is handed back to `cora_wave_free`
/// with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop suppresses the Vec destructor, so the buffer outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2014
/// Releases a buffer previously returned by `cora_wave_alloc(len)`.
///
/// `len` must be the exact value that was passed to `cora_wave_alloc`. A null pointer
/// is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY (caller contract): `ptr` is the pointer of a `Vec<f64>` with capacity `len`
    // that `cora_wave_alloc` leaked. The length is 0 because that memory may never have
    // been initialized, and `Vec::from_raw_parts` requires the first `length` elements
    // to be initialized. `f64` has no destructor, so only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
2022
/// Computes CoRa Wave from a raw input buffer into a raw output buffer.
///
/// Both pointers must reference at least `len` valid `f64`s. When `in_ptr == out_ptr`,
/// the result is first computed into a scratch buffer and then copied back.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    r_multi: f64,
    smooth: bool,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to cora_wave_into"));
    }
    if period == 0 || period > len {
        return Err(JsValue::from_str("Invalid period"));
    }

    let params = CoraWaveParams {
        period: Some(period),
        r_multi: Some(r_multi),
        smooth: Some(smooth),
    };
    let kernel = detect_best_kernel();
    let aliased = in_ptr == out_ptr;

    // SAFETY: the caller guarantees both pointers are valid for `len` f64 values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = CoraWaveInput::from_slice(data, params);

        if aliased {
            let mut scratch = vec![0.0; len];
            cora_wave_into_slice(&mut scratch, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let dst = std::slice::from_raw_parts_mut(out_ptr, len);
            cora_wave_into_slice(dst, &input, kernel)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
2065
/// Sweep configuration accepted by the JS `cora_wave_batch` entry point.
///
/// Field names are part of the JS-facing (serde) format.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CoraWaveBatchConfig {
    /// `(start, end, step)` for the period sweep.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the `r_multi` sweep.
    pub r_multi_range: (f64, f64, f64),
    /// Smoothing flag applied to every combination.
    pub smooth: bool,
}

/// Result object returned to JS by `cora_wave_batch`.
///
/// Field names and order are part of the JS-facing (serde) format.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct CoraWaveBatchJsOutput {
    /// Flattened `rows x cols` output, one row per entry of `combos`.
    pub values: Vec<f64>,
    /// Parameter set used for each row.
    pub combos: Vec<CoraWaveParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Length of the input series.
    pub cols: usize,
}
2082
/// Runs a CoRa Wave parameter sweep described by a JS config object.
///
/// The config has the shape `{ period_range, r_multi_range, smooth }`. The result has
/// the shape `{ values, combos, rows, cols }`, where `values` holds one row per
/// combination.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = cora_wave_batch)]
pub fn cora_wave_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let CoraWaveBatchConfig {
        period_range,
        r_multi_range,
        smooth,
    } = serde_wasm_bindgen::from_value::<CoraWaveBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = CoraWaveBatchRange {
        period: period_range,
        r_multi: r_multi_range,
        smooth,
    };
    let CoraWaveBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = cora_wave_batch_with_kernel(data, &sweep, detect_best_batch_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = CoraWaveBatchJsOutput {
        values,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2104
/// Runs a CoRa Wave parameter sweep from a raw input buffer into a raw, row-major
/// output buffer, and returns the number of rows written.
///
/// `out_ptr` must reference at least `rows * len` writable `f64`s, where `rows` is the
/// number of combinations the sweep expands to. The input and output regions may
/// overlap, including `in_ptr == out_ptr`. In that case the input is copied to a scratch
/// buffer before any output is written, so the computation never reads overwritten
/// input and no shared and mutable slice alias the same memory.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn cora_wave_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
    rmulti_start: f64,
    rmulti_end: f64,
    rmulti_step: f64,
    smooth: bool,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to cora_wave_batch_into",
        ));
    }
    let sweep = CoraWaveBatchRange {
        period: (period_start, period_end, period_step),
        r_multi: (rmulti_start, rmulti_end, rmulti_step),
        smooth,
    };
    let combos = expand_grid_cw(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in cora_wave_batch_into"))?;

    // Byte ranges of both buffers, used only to detect overlap.
    let elem = std::mem::size_of::<f64>();
    let in_start = in_ptr as usize;
    let in_end = in_start.saturating_add(cols.saturating_mul(elem));
    let out_start = out_ptr as usize;
    let out_end = out_start.saturating_add(total.saturating_mul(elem));
    let overlaps = in_start < out_end && out_start < in_end;

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and `out_ptr` for
    // `rows * len` writes. When the regions overlap, the shared input view is dropped
    // (after copying) before the mutable output view is created.
    unsafe {
        let scratch: Vec<f64>;
        let data: &[f64] = if overlaps {
            scratch = std::slice::from_raw_parts(in_ptr, len).to_vec();
            &scratch
        } else {
            std::slice::from_raw_parts(in_ptr, len)
        };
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        cora_wave_batch_inner_into(data, &sweep, detect_best_batch_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
2141
/// Precomputed CoRa Wave weights for repeatedly filtering raw buffers with one fixed
/// parameter set.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
)]
pub struct CoraWaveContext {
    /// Weights `0.01 * base^(j + 1)` for `j in 0..period`, or `[1.0]` when `period == 1`.
    weights: Vec<f64>,
    /// Reciprocal of the sum of `weights`.
    inv_norm: f64,
    /// Window length.
    period: usize,
    /// Length of the trailing smoothing window (1 disables smoothing).
    smooth_period: usize,
    /// Kernel detected at construction (not read by the methods below).
    kernel: Kernel,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl CoraWaveContext {
    /// Builds a context for `period`, `r_multi` and `smooth`.
    ///
    /// Fails when `period == 0`, or when `r_multi` is not finite or is negative.
    #[wasm_bindgen(constructor)]
    #[deprecated(
        since = "1.0.0",
        note = "For weight reuse patterns, use the fast/unsafe API with persistent buffers"
    )]
    pub fn new(period: usize, r_multi: f64, smooth: bool) -> Result<CoraWaveContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if !r_multi.is_finite() || r_multi < 0.0 {
            return Err(JsValue::from_str(&format!("Invalid r_multi: {}", r_multi)));
        }
        // Smoothing window: round(sqrt(period)), at least 1; 1 when smoothing is off.
        let smooth_period = if smooth {
            ((period as f64).sqrt().round() as usize).max(1)
        } else {
            1
        };

        if period == 1 {
            return Ok(CoraWaveContext {
                weights: vec![1.0],
                inv_norm: 1.0,
                period,
                smooth_period,
                kernel: detect_best_kernel(),
            });
        }

        // `r` is the per-step growth rate that takes `start_wt` to `end_wt` in
        // `period - 1` steps; `r_multi` scales that rate.
        let start_wt = 0.01;
        let end_wt = period as f64;
        let r = (end_wt / start_wt).powf(1.0 / (period as f64 - 1.0)) - 1.0;
        let base = 1.0 + r * r_multi;

        let mut weights = Vec::with_capacity(period);
        let mut norm = 0.0;
        for j in 0..period {
            let w = start_wt * base.powi((j + 1) as i32);
            weights.push(w);
            norm += w;
        }

        Ok(CoraWaveContext {
            weights,
            inv_norm: 1.0 / norm,
            period,
            smooth_period,
            kernel: detect_best_kernel(),
        })
    }

    /// Filters `len` values from `in_ptr` into `out_ptr` and NaN-fills the warm-up
    /// prefix.
    ///
    /// Both pointers must be valid for `len` f64 values. They may overlap, including
    /// in-place use (`in_ptr == out_ptr`): the input is then copied before any output is
    /// written.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if len < self.period {
            return Err(JsValue::from_str("Data length less than period"));
        }
        if in_ptr.is_null() || out_ptr.is_null() {
            return Err(JsValue::from_str("null pointer passed to update_into"));
        }

        // Byte ranges of both buffers, used only to detect overlap.
        let bytes = len.saturating_mul(std::mem::size_of::<f64>());
        let in_start = in_ptr as usize;
        let out_start = out_ptr as usize;
        let overlaps = in_start < out_start.saturating_add(bytes)
            && out_start < in_start.saturating_add(bytes);

        // SAFETY: the caller guarantees both pointers are valid for `len` f64 values.
        // When the regions overlap, the input is copied out before the mutable output
        // view exists, so no shared and mutable slice alias the same memory.
        unsafe {
            let scratch: Vec<f64>;
            let data: &[f64] = if overlaps {
                scratch = std::slice::from_raw_parts(in_ptr, len).to_vec();
                &scratch
            } else {
                std::slice::from_raw_parts(in_ptr, len)
            };
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);

            cora_wave_scalar_with_weights(
                data,
                &self.weights,
                self.inv_norm,
                self.smooth_period,
                first,
                out,
            );

            // Mask leading NaNs, the window fill and the smoothing lag.
            let warm = first + self.period - 1 + self.smooth_period.saturating_sub(1);
            out[..warm.min(len)].fill(f64::NAN);
        }
        Ok(())
    }

    /// Warm-up length (window fill plus smoothing lag), excluding any leading NaNs in
    /// the input.
    pub fn get_warmup_period(&self) -> usize {
        self.period - 1 + self.smooth_period.saturating_sub(1)
    }
}
2248
2249#[cfg(test)]
2250mod tests {
2251 use super::*;
2252 use crate::skip_if_unsupported;
2253 use crate::utilities::data_loader::read_candles_from_csv;
2254 #[cfg(feature = "proptest")]
2255 use proptest::prelude::*;
2256 use std::error::Error;
2257
2258 #[test]
2259 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
2260 fn test_cora_wave_into_matches_api() {
2261 let mut data = vec![f64::NAN; 5];
2262 for i in 0..256 {
2263 let x = (i as f64 * 0.03141592653589793).sin() * 10.0 + 50.0;
2264 data.push(x);
2265 }
2266
2267 let input = CoraWaveInput::from_slice(&data, CoraWaveParams::default());
2268
2269 let baseline = cora_wave(&input).expect("baseline cora_wave() failed");
2270 let mut out = vec![0.0; data.len()];
2271 cora_wave_into(&input, &mut out).expect("cora_wave_into() failed");
2272
2273 assert_eq!(baseline.values.len(), out.len());
2274 for (i, (&a, &b)) in baseline.values.iter().zip(&out).enumerate() {
2275 let ok = if a.is_nan() && b.is_nan() {
2276 true
2277 } else if a == b {
2278 true
2279 } else {
2280 (a - b).abs() <= 1e-12
2281 };
2282 assert!(ok, "mismatch at index {}: {} vs {}", i, a, b);
2283 }
2284 }
2285
2286 fn check_cora_wave_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2287 skip_if_unsupported!(kernel, test_name);
2288 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2289 let candles = read_candles_from_csv(file_path)?;
2290
2291 let input = CoraWaveInput::from_candles(&candles, "close", CoraWaveParams::default());
2292 let result = cora_wave_with_kernel(&input, kernel)?;
2293
2294 let expected_last_five = [
2295 59248.63632114,
2296 59251.74238978,
2297 59203.36944998,
2298 59171.14999178,
2299 59053.74201623,
2300 ];
2301
2302 let start = result.values.len().saturating_sub(5);
2303 for (i, &val) in result.values[start..].iter().enumerate() {
2304 let diff = (val - expected_last_five[i]).abs();
2305 assert!(
2306 diff < 0.01,
2307 "[{}] CoRa Wave {:?} mismatch at idx {}: got {}, expected {}",
2308 test_name,
2309 kernel,
2310 i,
2311 val,
2312 expected_last_five[i]
2313 );
2314 }
2315 Ok(())
2316 }
2317
2318 fn check_cora_wave_partial_params(
2319 test_name: &str,
2320 kernel: Kernel,
2321 ) -> Result<(), Box<dyn Error>> {
2322 skip_if_unsupported!(kernel, test_name);
2323 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2324 let candles = read_candles_from_csv(file_path)?;
2325
2326 let default_params = CoraWaveParams {
2327 period: None,
2328 r_multi: None,
2329 smooth: None,
2330 };
2331 let input = CoraWaveInput::from_candles(&candles, "close", default_params);
2332 let output = cora_wave_with_kernel(&input, kernel)?;
2333 assert_eq!(output.values.len(), candles.close.len());
2334
2335 Ok(())
2336 }
2337
2338 fn check_cora_wave_default_candles(
2339 test_name: &str,
2340 kernel: Kernel,
2341 ) -> Result<(), Box<dyn Error>> {
2342 skip_if_unsupported!(kernel, test_name);
2343 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2344 let candles = read_candles_from_csv(file_path)?;
2345
2346 let input = CoraWaveInput::with_default_candles(&candles);
2347 match input.data {
2348 CoraWaveData::Candles { source, .. } => assert_eq!(source, "close"),
2349 _ => panic!("Expected CoraWaveData::Candles"),
2350 }
2351 let output = cora_wave_with_kernel(&input, kernel)?;
2352 assert_eq!(output.values.len(), candles.close.len());
2353
2354 Ok(())
2355 }
2356
2357 fn check_cora_wave_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2358 skip_if_unsupported!(kernel, test_name);
2359 let input_data = [10.0, 20.0, 30.0];
2360 let params = CoraWaveParams {
2361 period: Some(0),
2362 r_multi: None,
2363 smooth: None,
2364 };
2365 let input = CoraWaveInput::from_slice(&input_data, params);
2366 let res = cora_wave_with_kernel(&input, kernel);
2367 assert!(
2368 res.is_err(),
2369 "[{}] CoRa Wave should fail with zero period",
2370 test_name
2371 );
2372 Ok(())
2373 }
2374
2375 fn check_cora_wave_period_exceeds_length(
2376 test_name: &str,
2377 kernel: Kernel,
2378 ) -> Result<(), Box<dyn Error>> {
2379 skip_if_unsupported!(kernel, test_name);
2380 let data_small = [10.0, 20.0, 30.0];
2381 let params = CoraWaveParams {
2382 period: Some(10),
2383 r_multi: None,
2384 smooth: None,
2385 };
2386 let input = CoraWaveInput::from_slice(&data_small, params);
2387 let res = cora_wave_with_kernel(&input, kernel);
2388 assert!(
2389 res.is_err(),
2390 "[{}] CoRa Wave should fail with period exceeding length",
2391 test_name
2392 );
2393 Ok(())
2394 }
2395
2396 fn check_cora_wave_very_small_dataset(
2397 test_name: &str,
2398 kernel: Kernel,
2399 ) -> Result<(), Box<dyn Error>> {
2400 skip_if_unsupported!(kernel, test_name);
2401 let single_point = [42.0];
2402 let params = CoraWaveParams::default();
2403 let input = CoraWaveInput::from_slice(&single_point, params);
2404 let res = cora_wave_with_kernel(&input, kernel);
2405 assert!(
2406 res.is_err(),
2407 "[{}] CoRa Wave should fail with insufficient data",
2408 test_name
2409 );
2410 Ok(())
2411 }
2412
2413 fn check_cora_wave_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2414 skip_if_unsupported!(kernel, test_name);
2415 let empty: [f64; 0] = [];
2416 let params = CoraWaveParams::default();
2417 let input = CoraWaveInput::from_slice(&empty, params);
2418 let res = cora_wave_with_kernel(&input, kernel);
2419 assert!(
2420 res.is_err(),
2421 "[{}] CoRa Wave should fail with empty input",
2422 test_name
2423 );
2424 Ok(())
2425 }
2426
2427 fn check_cora_wave_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2428 skip_if_unsupported!(kernel, test_name);
2429 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2430 let params = CoraWaveParams::default();
2431 let input = CoraWaveInput::from_slice(&nan_data, params);
2432 let res = cora_wave_with_kernel(&input, kernel);
2433 assert!(
2434 res.is_err(),
2435 "[{}] CoRa Wave should fail with all NaN values",
2436 test_name
2437 );
2438 Ok(())
2439 }
2440
2441 fn check_cora_wave_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2442 skip_if_unsupported!(kernel, test_name);
2443
2444 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2445 let c = read_candles_from_csv(file)?;
2446
2447 let params = CoraWaveParams::default();
2448 let input = CoraWaveInput::from_candles(&c, "close", params.clone());
2449 let batch = cora_wave_with_kernel(&input, kernel)?.values;
2450
2451 let mut stream = CoraWaveStream::try_new(params)?;
2452 let mut streamed = Vec::with_capacity(c.close.len());
2453 for &px in &c.close {
2454 match stream.update(px) {
2455 Some(v) => streamed.push(v),
2456 None => streamed.push(f64::NAN),
2457 }
2458 }
2459
2460 assert_eq!(batch.len(), streamed.len());
2461 for (i, (&b, &s)) in batch.iter().zip(&streamed).enumerate() {
2462 if b.is_nan() && s.is_nan() {
2463 continue;
2464 }
2465 let d = (b - s).abs();
2466 assert!(
2467 d < 1e-9,
2468 "[{}] streaming mismatch at {}: {} vs {}",
2469 test_name,
2470 i,
2471 b,
2472 s
2473 );
2474 }
2475 Ok(())
2476 }
2477
2478 fn check_cora_wave_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2479 skip_if_unsupported!(kernel, test_name);
2480
2481 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2482 let c = read_candles_from_csv(file)?;
2483
2484 let out = cora_wave_with_kernel(&CoraWaveInput::with_default_candles(&c), kernel)?.values;
2485 if out.len() > 240 {
2486 for (i, &v) in out[240..].iter().enumerate() {
2487 assert!(!v.is_nan(), "[{}] unexpected NaN at {}", test_name, 240 + i);
2488 }
2489 }
2490 Ok(())
2491 }
2492
2493 fn check_cora_wave_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2494 skip_if_unsupported!(kernel, test_name);
2495
2496 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2497 let c = read_candles_from_csv(file)?;
2498
2499 let first = cora_wave_with_kernel(
2500 &CoraWaveInput::from_candles(&c, "close", CoraWaveParams::default()),
2501 kernel,
2502 )?;
2503
2504 let second_in = CoraWaveInput::from_slice(&first.values, CoraWaveParams::default());
2505 let second = cora_wave_with_kernel(&second_in, kernel)?;
2506 let second_ref = cora_wave_with_kernel(&second_in, Kernel::Scalar)?;
2507
2508 assert_eq!(second.values.len(), first.values.len());
2509 for (i, (&a, &b)) in second.values.iter().zip(&second_ref.values).enumerate() {
2510 if a.is_nan() && b.is_nan() {
2511 continue;
2512 }
2513 assert!(
2514 (a - b).abs() < 1e-9,
2515 "[{}] reinput mismatch at {}: {} vs {}",
2516 test_name,
2517 i,
2518 a,
2519 b
2520 );
2521 }
2522 Ok(())
2523 }
2524
2525 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2526 skip_if_unsupported!(kernel, test);
2527
2528 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2529 let c = read_candles_from_csv(file)?;
2530
2531 let out = CoraWaveBatchBuilder::new()
2532 .kernel(kernel)
2533 .apply_candles(&c, "close")?;
2534 let def = CoraWaveParams::default();
2535 let row = out.values_for(&def).expect("default row missing");
2536 assert_eq!(row.len(), c.close.len());
2537
2538 let expected = [
2539 59248.63632114,
2540 59251.74238978,
2541 59203.36944998,
2542 59171.14999178,
2543 59053.74201623,
2544 ];
2545 let start = row.len() - 5;
2546 for (i, &v) in row[start..].iter().enumerate() {
2547 assert!(
2548 (v - expected[i]).abs() < 0.01,
2549 "[{test}] default-row mismatch at idx {i}: {v}"
2550 );
2551 }
2552 Ok(())
2553 }
2554
2555 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2556 skip_if_unsupported!(kernel, test);
2557
2558 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2559 let c = read_candles_from_csv(file)?;
2560
2561 let out = CoraWaveBatchBuilder::new()
2562 .kernel(kernel)
2563 .period_range(20, 60, 1)
2564 .r_multi_range(1.0, 3.0, 0.25)
2565 .apply_candles(&c, "close")?;
2566
2567 let expected = 41 * 9;
2568 assert_eq!(out.combos.len(), expected);
2569 assert_eq!(out.rows, expected);
2570 assert_eq!(out.cols, c.close.len());
2571 Ok(())
2572 }
2573
2574 #[cfg(debug_assertions)]
2575 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2576 skip_if_unsupported!(kernel, test);
2577
2578 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2579 let c = read_candles_from_csv(file)?;
2580
2581 let out = CoraWaveBatchBuilder::new()
2582 .kernel(kernel)
2583 .period_range(20, 24, 2)
2584 .r_multi_range(1.5, 2.5, 0.5)
2585 .apply_candles(&c, "close")?;
2586
2587 for (idx, &v) in out.values.iter().enumerate() {
2588 if v.is_nan() {
2589 continue;
2590 }
2591 let bits = v.to_bits();
2592 if bits == 0x11111111_11111111
2593 || bits == 0x22222222_22222222
2594 || bits == 0x33333333_33333333
2595 {
2596 let row = idx / out.cols;
2597 let col = idx % out.cols;
2598 let combo = &out.combos[row];
2599 panic!(
2600 "[{}] poison value at row {} col {} (idx {}) params: period={}, r_multi={}, smooth={}",
2601 test, row, col, idx,
2602 combo.period.unwrap_or(20),
2603 combo.r_multi.unwrap_or(2.0),
2604 combo.smooth.unwrap_or(true),
2605 );
2606 }
2607 }
2608 Ok(())
2609 }
2610
2611 #[cfg(not(debug_assertions))]
2612 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2613 Ok(())
2614 }
2615
2616 #[cfg(debug_assertions)]
2617 fn check_cora_wave_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2618 skip_if_unsupported!(kernel, test);
2619 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2620 let c = read_candles_from_csv(file)?;
2621 let out = cora_wave_with_kernel(&CoraWaveInput::with_default_candles(&c), kernel)?.values;
2622 for (i, &v) in out.iter().enumerate() {
2623 if v.is_nan() {
2624 continue;
2625 }
2626 let b = v.to_bits();
2627 assert!(
2628 !(b == 0x11111111_11111111 || b == 0x22222222_22222222 || b == 0x33333333_33333333),
2629 "[{test}] poison at idx {i}"
2630 );
2631 }
2632 Ok(())
2633 }
2634
2635 #[cfg(not(debug_assertions))]
2636 fn check_cora_wave_no_poison(_: &str, _: Kernel) -> Result<(), Box<dyn Error>> {
2637 Ok(())
2638 }
2639
2640 macro_rules! generate_all_cora_wave_tests {
2641 ($($test_fn:ident),*) => {
2642 paste::paste! {
2643 $(
2644 #[test]
2645 fn [<$test_fn _scalar_f64>]() {
2646 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2647 }
2648 )*
2649 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2650 $(
2651 #[test]
2652 fn [<$test_fn _avx2_f64>]() {
2653 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2654 }
2655 )*
2656 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2657 $(
2658 #[test]
2659 fn [<$test_fn _avx512_f64>]() {
2660 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2661 }
2662 )*
2663 }
2664 };
2665 }
2666
2667 generate_all_cora_wave_tests!(
2668 check_cora_wave_accuracy,
2669 check_cora_wave_partial_params,
2670 check_cora_wave_default_candles,
2671 check_cora_wave_zero_period,
2672 check_cora_wave_period_exceeds_length,
2673 check_cora_wave_very_small_dataset,
2674 check_cora_wave_empty_input,
2675 check_cora_wave_all_nan,
2676 check_cora_wave_nan_handling,
2677 check_cora_wave_streaming,
2678 check_cora_wave_reinput,
2679 check_cora_wave_no_poison
2680 );
2681
2682 macro_rules! gen_batch_tests {
2683 ($fn_name:ident) => {
2684 paste::paste! {
2685 #[test] fn [<$fn_name _scalar>]() { let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch); }
2686 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2687 #[test] fn [<$fn_name _avx2>]() { let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch); }
2688 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2689 #[test] fn [<$fn_name _avx512>]() { let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch); }
2690 #[test] fn [<$fn_name _auto_detect>]() { let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto); }
2691 }
2692 };
2693 }
2694
2695 gen_batch_tests!(check_batch_default_row);
2696 gen_batch_tests!(check_batch_sweep);
2697 gen_batch_tests!(check_batch_no_poison);
2698
2699 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2700 #[test]
2701 fn test_cora_wave_simd128_correctness() {
2702 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2703 let params = CoraWaveParams::default();
2704 let input = CoraWaveInput::from_slice(&data, params);
2705 let scalar = cora_wave_with_kernel(&input, Kernel::Scalar).unwrap();
2706 let simd = cora_wave_with_kernel(&input, Kernel::Scalar).unwrap();
2707 assert_eq!(scalar.values.len(), simd.values.len());
2708 for (i, (a, b)) in scalar.values.iter().zip(simd.values.iter()).enumerate() {
2709 assert!((a - b).abs() < 1e-10, "SIMD128 mismatch at {i}: {a} vs {b}");
2710 }
2711 }
2712
2713 #[cfg(feature = "proptest")]
2714 proptest! {
2715 #[test]
2716 fn test_cora_wave_no_panic(data: Vec<f64>, period in 1usize..100) {
2717 let params = CoraWaveParams {
2718 period: Some(period),
2719 r_multi: Some(2.0),
2720 smooth: Some(true),
2721 };
2722 let input = CoraWaveInput::from_slice(&data, params);
2723 let _ = cora_wave(&input);
2724 }
2725
2726 #[test]
2727 fn test_cora_wave_length_preservation(size in 10usize..100) {
2728 let data: Vec<f64> = (0..size).map(|i| i as f64).collect();
2729 let params = CoraWaveParams::default();
2730 let input = CoraWaveInput::from_slice(&data, params);
2731
2732 if let Ok(output) = cora_wave(&input) {
2733 prop_assert_eq!(output.values.len(), size);
2734 }
2735 }
2736 }
2737}