1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
8use core::arch::x86_64::*;
9#[cfg(not(target_arch = "wasm32"))]
10use rayon::prelude::*;
11use std::convert::AsRef;
12use std::error::Error;
13use std::mem::MaybeUninit;
14use thiserror::Error;
15
16#[cfg(feature = "python")]
17use crate::utilities::kernel_validation::validate_kernel;
18#[cfg(all(feature = "python", feature = "cuda"))]
19use numpy::PyUntypedArrayMethods;
20#[cfg(feature = "python")]
21use numpy::{IntoPyArray, PyArray1};
22#[cfg(feature = "python")]
23use pyo3::exceptions::PyValueError;
24#[cfg(feature = "python")]
25use pyo3::prelude::*;
26#[cfg(feature = "python")]
27use pyo3::types::{PyDict, PyList};
28
29#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
30use serde::{Deserialize, Serialize};
31#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
32use wasm_bindgen::prelude::*;
33
34#[cfg(all(feature = "python", feature = "cuda"))]
35use crate::cuda::{cuda_available, moving_averages::CudaDecycler};
36#[cfg(all(feature = "python", feature = "cuda"))]
37use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
38#[cfg(all(feature = "python", feature = "cuda"))]
39use cust::context::Context;
40#[cfg(all(feature = "python", feature = "cuda"))]
41use cust::memory::DeviceBuffer;
42#[cfg(all(feature = "python", feature = "cuda"))]
43use numpy::{PyReadonlyArray1, PyReadonlyArray2};
44#[cfg(all(feature = "python", feature = "cuda"))]
45use std::sync::Arc;
46
47#[cfg(all(feature = "python", feature = "cuda"))]
48use crate::cuda::moving_averages::decycler_wrapper::DeviceArrayF32Decycler as DeviceArrayF32Inner;
49
#[cfg(all(feature = "python", feature = "cuda"))]
// Python-visible handle to a decycler result that lives in CUDA device memory as `f32`.
// Exposed to Python as `ta_indicators.cuda.DecyclerDeviceArrayF32`. `unsendable` tells PyO3
// to confine the object to the thread that created it.
// (Plain `//` comments are used on purpose: `///` here would become the Python class docstring.)
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DecyclerDeviceArrayF32",
    unsendable
)]
pub struct DeviceArrayF32Py {
    // Device buffer plus its `rows`/`cols` shape, CUDA context, and device id.
    pub(crate) inner: DeviceArrayF32Inner,
}
59
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32Py {
    // CUDA Array Interface (version 3) description of the device matrix.
    // Shape is (rows, cols), element type "<f4" (little-endian f32), and byte strides are
    // row-major / C-contiguous. `data` is (device_pointer, false); per the CAI spec the boolean
    // is the read-only flag, so consumers are told they may write.
    // NOTE(review): no "stream" key is published, so consumers cannot synchronize on the
    // producing stream through this dict. Confirm the producing kernel has completed before
    // this is read.
    // (`//` rather than `///` so the Python-visible docstrings are unchanged.)
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        d.set_item("data", (self.inner.device_ptr() as usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // DLPack device tuple: device type 2 (DLPack's CUDA device type) and the owning device id.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Exports the matrix as a DLPack capsule, transferring ownership of the device buffer.
    //
    // - If `dl_device` is given as an `(i32, i32)` tuple that differs from this array's device,
    //   a ValueError is raised (no cross-device copy is implemented, whatever `copy` says).
    //   A `dl_device` that does not extract as `(i32, i32)` is ignored.
    // - `stream` is accepted but ignored.
    //   NOTE(review): the DLPack protocol expects the producer to order its work before the
    //   consumer's stream. Confirm the result is already synchronized at this point.
    // - The buffer is moved out and replaced by an empty 0x0 placeholder, so after a
    //   successful export this Python object reports shape (0, 0).
    #[pyo3(signature=(stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<pyo3::PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // Only the error message depends on `copy`; neither path copies.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        let _ = stream;

        // Zero-length placeholder so `self.inner` stays a valid value after the real buffer
        // is moved into the capsule below.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = self.inner.ctx.clone();
        let device_id = self.inner.device_id;
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Inner {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx,
                device_id,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
138
139impl<'a> AsRef<[f64]> for DecyclerInput<'a> {
140 #[inline(always)]
141 fn as_ref(&self) -> &[f64] {
142 match &self.data {
143 DecyclerData::Slice(slice) => slice,
144 DecyclerData::Candles { candles, source } => source_type(candles, source),
145 }
146 }
147}
148
/// Where the decycler reads its price series from.
#[derive(Debug, Clone)]
pub enum DecyclerData<'a> {
    /// A named series taken from candle data and resolved via `source_type`
    /// (e.g. `"close"`, `"hl2"`, `"hlc3"`).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A borrowed raw series.
    Slice(&'a [f64]),
}
157
/// Result of a single-series decycler run.
#[derive(Debug, Clone)]
pub struct DecyclerOutput {
    /// One value per input sample. Indices before `first + 2` (where `first` is the first
    /// non-NaN input) are NaN warm-up; later indices hold `input - high_pass`.
    pub values: Vec<f64>,
}
162
/// Tunable decycler parameters. `None` fields fall back to the defaults (125 and 0.707).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DecyclerParams {
    /// High-pass filter period; must be at least 2 and no longer than the input.
    pub hp_period: Option<usize>,
    /// Multiplier applied to the filter angle `2*pi*k / hp_period`; must be finite and > 0.
    pub k: Option<f64>,
}
172
173impl Default for DecyclerParams {
174 fn default() -> Self {
175 Self {
176 hp_period: Some(125),
177 k: Some(0.707),
178 }
179 }
180}
181
/// A price series paired with the parameters to run the decycler with.
#[derive(Debug, Clone)]
pub struct DecyclerInput<'a> {
    /// Series source (candles + field name, or a raw slice).
    pub data: DecyclerData<'a>,
    /// Filter parameters; unset fields use the defaults via the getters.
    pub params: DecyclerParams,
}
187
188impl<'a> DecyclerInput<'a> {
189 #[inline]
190 pub fn from_candles(c: &'a Candles, s: &'a str, p: DecyclerParams) -> Self {
191 Self {
192 data: DecyclerData::Candles {
193 candles: c,
194 source: s,
195 },
196 params: p,
197 }
198 }
199 #[inline]
200 pub fn from_slice(sl: &'a [f64], p: DecyclerParams) -> Self {
201 Self {
202 data: DecyclerData::Slice(sl),
203 params: p,
204 }
205 }
206 #[inline]
207 pub fn with_default_candles(c: &'a Candles) -> Self {
208 Self::from_candles(c, "close", DecyclerParams::default())
209 }
210 #[inline]
211 pub fn get_hp_period(&self) -> usize {
212 self.params.hp_period.unwrap_or(125)
213 }
214 #[inline]
215 pub fn get_k(&self) -> f64 {
216 self.params.k.unwrap_or(0.707)
217 }
218}
219
/// Fluent builder for one-off decycler runs or a streaming decycler.
#[derive(Copy, Clone, Debug)]
pub struct DecyclerBuilder {
    /// High-pass period; `None` uses the default.
    hp_period: Option<usize>,
    /// Angle multiplier; `None` uses the default.
    k: Option<f64>,
    /// Kernel passed to `decycler_with_kernel`.
    kernel: Kernel,
}
226
227impl Default for DecyclerBuilder {
228 fn default() -> Self {
229 Self {
230 hp_period: None,
231 k: None,
232 kernel: Kernel::Auto,
233 }
234 }
235}
236
237impl DecyclerBuilder {
238 #[inline(always)]
239 pub fn new() -> Self {
240 Self::default()
241 }
242 #[inline(always)]
243 pub fn hp_period(mut self, n: usize) -> Self {
244 self.hp_period = Some(n);
245 self
246 }
247 #[inline(always)]
248 pub fn k(mut self, x: f64) -> Self {
249 self.k = Some(x);
250 self
251 }
252 #[inline(always)]
253 pub fn kernel(mut self, k: Kernel) -> Self {
254 self.kernel = k;
255 self
256 }
257 #[inline(always)]
258 pub fn apply(self, c: &Candles) -> Result<DecyclerOutput, DecyclerError> {
259 let p = DecyclerParams {
260 hp_period: self.hp_period,
261 k: self.k,
262 };
263 let i = DecyclerInput::from_candles(c, "close", p);
264 decycler_with_kernel(&i, self.kernel)
265 }
266 #[inline(always)]
267 pub fn apply_slice(self, d: &[f64]) -> Result<DecyclerOutput, DecyclerError> {
268 let p = DecyclerParams {
269 hp_period: self.hp_period,
270 k: self.k,
271 };
272 let i = DecyclerInput::from_slice(d, p);
273 decycler_with_kernel(&i, self.kernel)
274 }
275 #[inline(always)]
276 pub fn into_stream(self) -> Result<DecyclerStream, DecyclerError> {
277 let p = DecyclerParams {
278 hp_period: self.hp_period,
279 k: self.k,
280 };
281 DecyclerStream::try_new(p)
282 }
283}
284
/// Errors produced by the decycler entry points, the stream constructor, and batch sweeps.
#[derive(Debug, Error)]
pub enum DecyclerError {
    /// The input series has zero length.
    #[error("decycler: Empty data provided.")]
    EmptyInputData,
    /// `hp_period` is below 2 or longer than the input. `data_len` is 0 when raised by
    /// `DecyclerStream::try_new`, which has no series to compare against.
    #[error("decycler: Invalid period: hp_period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` (= `hp_period`) samples remain from the first non-NaN value onward.
    #[error("decycler: not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Every input value is NaN.
    #[error("decycler: All values are NaN")]
    AllValuesNaN,
    /// `k` is NaN, infinite, or not strictly positive.
    #[error("decycler: Invalid k: k = {k}")]
    InvalidK { k: f64 },
    /// A caller-provided output buffer does not match the input length.
    #[error("decycler: output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep axis `(start, end, step)` produced no values
    /// (e.g. the step points away from `end`).
    #[error("decycler: invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: isize,
        end: isize,
        step: isize,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("decycler: invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Other invalid input, e.g. the batch matrix size `rows * cols` overflowing `usize`.
    #[error("decycler: invalid input: {0}")]
    InvalidInput(String),
}
310
/// Computes the decycler for `input` with automatic kernel selection (`Kernel::Auto`).
///
/// See `decycler_with_kernel` for validation rules and errors.
#[inline]
pub fn decycler(input: &DecyclerInput) -> Result<DecyclerOutput, DecyclerError> {
    decycler_with_kernel(input, Kernel::Auto)
}
315
316#[inline]
317pub fn decycler_into_slice(
318 out: &mut [f64],
319 input: &DecyclerInput,
320 kernel: Kernel,
321) -> Result<(), DecyclerError> {
322 let data: &[f64] = input.as_ref();
323
324 if data.is_empty() {
325 return Err(DecyclerError::EmptyInputData);
326 }
327 if out.len() != data.len() {
328 return Err(DecyclerError::OutputLengthMismatch {
329 expected: data.len(),
330 got: out.len(),
331 });
332 }
333
334 let hp_period = input.get_hp_period();
335 let k = input.get_k();
336
337 if hp_period < 2 || hp_period > data.len() {
338 return Err(DecyclerError::InvalidPeriod {
339 period: hp_period,
340 data_len: data.len(),
341 });
342 }
343 if !(k.is_finite()) || k <= 0.0 {
344 return Err(DecyclerError::InvalidK { k });
345 }
346
347 let first = data
348 .iter()
349 .position(|x| !x.is_nan())
350 .ok_or(DecyclerError::AllValuesNaN)?;
351 if data.len() - first < hp_period {
352 return Err(DecyclerError::NotEnoughValidData {
353 needed: hp_period,
354 valid: data.len() - first,
355 });
356 }
357
358 let chosen = match kernel {
359 Kernel::Auto => Kernel::Scalar,
360 other => other,
361 };
362
363 unsafe {
364 match chosen {
365 Kernel::Scalar | Kernel::ScalarBatch => {
366 decycler_scalar_into(data, hp_period, k, first, out)
367 }
368 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
369 Kernel::Avx2 | Kernel::Avx2Batch => {
370 decycler_scalar_into(data, hp_period, k, first, out)
371 }
372 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
373 Kernel::Avx512 | Kernel::Avx512Batch => {
374 decycler_scalar_into(data, hp_period, k, first, out)
375 }
376 _ => unreachable!(),
377 }
378 }?;
379
380 let warmup_period = first + 2;
381 for v in &mut out[..warmup_period] {
382 *v = f64::from_bits(0x7ff8_0000_0000_0000);
383 }
384
385 Ok(())
386}
387
/// Writes the decycler of `input` into `out` using `Kernel::Auto`.
///
/// Not compiled for `wasm32` builds with the `wasm` feature. `out.len()` must equal the input
/// length; see `decycler_into_slice` for the full contract and errors.
#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
#[inline]
pub fn decycler_into(input: &DecyclerInput, out: &mut [f64]) -> Result<(), DecyclerError> {
    decycler_into_slice(out, input, Kernel::Auto)
}
393
394#[inline]
395unsafe fn decycler_scalar_into(
396 data: &[f64],
397 hp_period: usize,
398 k: f64,
399 first: usize,
400 out: &mut [f64],
401) -> Result<(), DecyclerError> {
402 use std::f64::consts::PI;
403
404 let angle = (2.0 * PI * k) * (hp_period as f64).recip();
405 let (sin_val, cos_val) = angle.sin_cos();
406 const EPSILON: f64 = 1e-10;
407 let cos_safe = if cos_val.abs() < EPSILON {
408 EPSILON.copysign(cos_val)
409 } else {
410 cos_val
411 };
412 let alpha = 1.0 + ((sin_val - 1.0) / cos_safe);
413 let one_minus_alpha_half = 1.0 - alpha / 2.0;
414 let c = one_minus_alpha_half * one_minus_alpha_half;
415 let one_minus_alpha = 1.0 - alpha;
416 let one_minus_alpha_sq = one_minus_alpha * one_minus_alpha;
417
418 let mut hp_prev2 = data[first];
419 let mut hp_prev1 = data[first + 1];
420
421 for i in (first + 2)..data.len() {
422 let current = data[i];
423 let prev1 = data[i - 1];
424 let prev2 = data[i - 2];
425
426 let s0 = current * c;
427 let s1 = prev1.mul_add(-2.0 * c, s0);
428 let s2 = prev2.mul_add(c, s1);
429 let s3 = hp_prev1.mul_add(2.0 * one_minus_alpha, s2);
430 let hp_val = hp_prev2.mul_add(-one_minus_alpha_sq, s3);
431
432 hp_prev2 = hp_prev1;
433 hp_prev1 = hp_val;
434
435 out[i] = current - hp_val;
436 }
437 Ok(())
438}
439
440pub fn decycler_with_kernel(
441 input: &DecyclerInput,
442 kernel: Kernel,
443) -> Result<DecyclerOutput, DecyclerError> {
444 let data: &[f64] = match &input.data {
445 DecyclerData::Candles { candles, source } => source_type(candles, source),
446 DecyclerData::Slice(sl) => sl,
447 };
448
449 if data.is_empty() {
450 return Err(DecyclerError::EmptyInputData);
451 }
452 let hp_period = input.get_hp_period();
453 let k = input.get_k();
454 if hp_period < 2 || hp_period > data.len() {
455 return Err(DecyclerError::InvalidPeriod {
456 period: hp_period,
457 data_len: data.len(),
458 });
459 }
460 if !(k.is_finite()) || k <= 0.0 {
461 return Err(DecyclerError::InvalidK { k });
462 }
463 let first = data
464 .iter()
465 .position(|x| !x.is_nan())
466 .ok_or(DecyclerError::AllValuesNaN)?;
467 if data.len() - first < hp_period {
468 return Err(DecyclerError::NotEnoughValidData {
469 needed: hp_period,
470 valid: data.len() - first,
471 });
472 }
473
474 let chosen = match kernel {
475 Kernel::Auto => Kernel::Scalar,
476 other => other,
477 };
478
479 unsafe {
480 match chosen {
481 Kernel::Scalar | Kernel::ScalarBatch => decycler_scalar(data, hp_period, k, first),
482 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
483 Kernel::Avx2 | Kernel::Avx2Batch => decycler_avx2(data, hp_period, k, first),
484 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
485 Kernel::Avx512 | Kernel::Avx512Batch => decycler_avx512(data, hp_period, k, first),
486 _ => unreachable!(),
487 }
488 }
489}
490
491#[inline]
492pub fn decycler_scalar(
493 data: &[f64],
494 hp_period: usize,
495 k: f64,
496 first: usize,
497) -> Result<DecyclerOutput, DecyclerError> {
498 use std::f64::consts::PI;
499
500 let warmup_period = first + 2;
501 let mut out = alloc_with_nan_prefix(data.len(), warmup_period);
502
503 let mut hp_prev2 = 0.0;
504 let mut hp_prev1 = 0.0;
505
506 let angle = (2.0 * PI * k) * (hp_period as f64).recip();
507 let (sin_val, cos_val) = angle.sin_cos();
508
509 const EPSILON: f64 = 1e-10;
510 let cos_safe = if cos_val.abs() < EPSILON {
511 EPSILON.copysign(cos_val)
512 } else {
513 cos_val
514 };
515 let alpha = 1.0 + ((sin_val - 1.0) / cos_safe);
516 let one_minus_alpha_half = 1.0 - alpha / 2.0;
517 let c = one_minus_alpha_half * one_minus_alpha_half;
518 let one_minus_alpha = 1.0 - alpha;
519 let one_minus_alpha_sq = one_minus_alpha * one_minus_alpha;
520
521 if data.len() > first {
522 hp_prev2 = data[first];
523 }
524 if data.len() > (first + 1) {
525 hp_prev1 = data[first + 1];
526 }
527
528 for i in (first + 2)..data.len() {
529 let current = data[i];
530 let prev1 = data[i - 1];
531 let prev2 = data[i - 2];
532
533 let s0 = current * c;
534 let s1 = prev1.mul_add(-2.0 * c, s0);
535 let s2 = prev2.mul_add(c, s1);
536 let s3 = hp_prev1.mul_add(2.0 * one_minus_alpha, s2);
537 let hp_val = hp_prev2.mul_add(-one_minus_alpha_sq, s3);
538
539 hp_prev2 = hp_prev1;
540 hp_prev1 = hp_val;
541
542 out[i] = current - hp_val;
543 }
544 Ok(DecyclerOutput { values: out })
545}
546
/// AVX-512 entry point for the single-series decycler.
///
/// Currently delegates to `decycler_scalar` (the filter is a recursion over its own previous
/// outputs). Presumably kept for kernel-dispatch symmetry.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn decycler_avx512(
    data: &[f64],
    hp_period: usize,
    k: f64,
    first: usize,
) -> Result<DecyclerOutput, DecyclerError> {
    decycler_scalar(data, hp_period, k, first)
}
557
/// AVX2 entry point for the single-series decycler.
///
/// Currently delegates to `decycler_scalar`; presumably kept for kernel-dispatch symmetry.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn decycler_avx2(
    data: &[f64],
    hp_period: usize,
    k: f64,
    first: usize,
) -> Result<DecyclerOutput, DecyclerError> {
    decycler_scalar(data, hp_period, k, first)
}
568
/// AVX-512 variant intended for short periods; currently delegates to `decycler_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn decycler_avx512_short(
    data: &[f64],
    hp_period: usize,
    k: f64,
    first: usize,
) -> Result<DecyclerOutput, DecyclerError> {
    decycler_scalar(data, hp_period, k, first)
}
579
/// AVX-512 variant intended for long periods; currently delegates to `decycler_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn decycler_avx512_long(
    data: &[f64],
    hp_period: usize,
    k: f64,
    first: usize,
) -> Result<DecyclerOutput, DecyclerError> {
    decycler_scalar(data, hp_period, k, first)
}
590
591#[inline]
592pub fn decycler_batch_with_kernel(
593 data: &[f64],
594 sweep: &DecyclerBatchRange,
595 k: Kernel,
596) -> Result<DecyclerBatchOutput, DecyclerError> {
597 let kernel = match k {
598 Kernel::Auto => detect_best_batch_kernel(),
599 other if other.is_batch() => other,
600 other => return Err(DecyclerError::InvalidKernelForBatch(other)),
601 };
602 let simd = match kernel {
603 Kernel::Avx512Batch => Kernel::Avx512,
604 Kernel::Avx2Batch => Kernel::Avx2,
605 Kernel::ScalarBatch => Kernel::Scalar,
606 _ => unreachable!(),
607 };
608 decycler_batch_par_slice(data, sweep, simd)
609}
610
/// Incremental decycler that consumes one sample per `update` call with constant-size state.
///
/// Build with `DecyclerStream::try_new` or `DecyclerBuilder::into_stream`.
#[derive(Debug, Clone)]
pub struct DecyclerStream {
    // Validated parameters; stored but not read by `update`.
    hp_period: usize,
    k: f64,

    // Precomputed coefficients: `c = (1 - alpha/2)^2`, `oma2 = 2(1 - alpha)`,
    // `oma_sq_neg = -(1 - alpha)^2`, `c_neg2 = -2c`.
    c: f64,
    oma2: f64,
    oma_sq_neg: f64,
    c_neg2: f64,

    // Previous two inputs (`x1` is the most recent).
    x1: f64,
    x2: f64,
    // Previous two high-pass values (`hp1` is the most recent).
    hp1: f64,
    hp2: f64,

    // Number of non-NaN seed samples accepted so far (0, 1, then stays at 2).
    seen: u8,
}
628
629impl DecyclerStream {
630 pub fn try_new(params: DecyclerParams) -> Result<Self, DecyclerError> {
631 let hp_period = params.hp_period.unwrap_or(125);
632 let k = params.k.unwrap_or(0.707);
633 if hp_period < 2 {
634 return Err(DecyclerError::InvalidPeriod {
635 period: hp_period,
636 data_len: 0,
637 });
638 }
639 if !k.is_finite() || k <= 0.0 {
640 return Err(DecyclerError::InvalidK { k });
641 }
642
643 use std::f64::consts::PI;
644 let angle = (2.0 * PI * k) * (hp_period as f64).recip();
645 let (sin_val, cos_val) = angle.sin_cos();
646
647 const EPS: f64 = 1e-10;
648 let cos_safe = if cos_val.abs() < EPS {
649 EPS.copysign(cos_val)
650 } else {
651 cos_val
652 };
653 let alpha = 1.0 + (sin_val - 1.0) * (1.0 / cos_safe);
654
655 let oma = 1.0 - alpha;
656 let c = {
657 let t = 1.0 - 0.5 * alpha;
658 t * t
659 };
660
661 Ok(Self {
662 hp_period,
663 k,
664 c,
665 oma2: 2.0 * oma,
666 oma_sq_neg: -(oma * oma),
667 c_neg2: -2.0 * c,
668
669 x1: f64::NAN,
670 x2: f64::NAN,
671 hp1: f64::NAN,
672 hp2: f64::NAN,
673 seen: 0,
674 })
675 }
676
677 #[inline(always)]
678 pub fn update(&mut self, x: f64) -> Option<f64> {
679 if self.seen < 2 {
680 if x.is_nan() {
681 return None;
682 }
683 match self.seen {
684 0 => {
685 self.x1 = x;
686 self.hp1 = x;
687 self.seen = 1;
688 return None;
689 }
690 1 => {
691 self.x2 = self.x1;
692 self.x1 = x;
693 self.hp2 = self.hp1;
694 self.hp1 = x;
695 self.seen = 2;
696 return None;
697 }
698 _ => {}
699 }
700 }
701
702 let s0 = self.c * x;
703 let s1 = self.c_neg2.mul_add(self.x1, s0);
704 let s2 = self.c.mul_add(self.x2, s1);
705 let s3 = self.oma2.mul_add(self.hp1, s2);
706 let hp = self.oma_sq_neg.mul_add(self.hp2, s3);
707
708 let out = x - hp;
709
710 self.x2 = self.x1;
711 self.x1 = x;
712 self.hp2 = self.hp1;
713 self.hp1 = hp;
714
715 Some(out)
716 }
717}
718
/// Parameter sweep for batch runs; each axis is an inclusive `(start, end, step)` triple.
///
/// A zero step or `start == end` yields the single value `start` (see `expand_grid`).
#[derive(Clone, Debug)]
pub struct DecyclerBatchRange {
    /// High-pass periods to evaluate.
    pub hp_period: (usize, usize, usize),
    /// Angle multipliers to evaluate (compared with a 1e-12 tolerance).
    pub k: (f64, f64, f64),
}
724
725impl Default for DecyclerBatchRange {
726 fn default() -> Self {
727 Self {
728 hp_period: (125, 374, 1),
729 k: (0.707, 0.707, 0.0),
730 }
731 }
732}
733
/// Fluent builder for batch (parameter-sweep) decycler runs.
#[derive(Clone, Debug, Default)]
pub struct DecyclerBatchBuilder {
    /// Parameter grid to evaluate.
    range: DecyclerBatchRange,
    /// Kernel passed to `decycler_batch_with_kernel` (must be `Auto` or a batch kernel).
    kernel: Kernel,
}
739
740impl DecyclerBatchBuilder {
741 pub fn new() -> Self {
742 Self::default()
743 }
744 pub fn kernel(mut self, k: Kernel) -> Self {
745 self.kernel = k;
746 self
747 }
748 #[inline]
749 pub fn hp_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
750 self.range.hp_period = (start, end, step);
751 self
752 }
753 #[inline]
754 pub fn hp_period_static(mut self, p: usize) -> Self {
755 self.range.hp_period = (p, p, 0);
756 self
757 }
758 #[inline]
759 pub fn k_range(mut self, start: f64, end: f64, step: f64) -> Self {
760 self.range.k = (start, end, step);
761 self
762 }
763 #[inline]
764 pub fn k_static(mut self, v: f64) -> Self {
765 self.range.k = (v, v, 0.0);
766 self
767 }
768 pub fn apply_slice(self, data: &[f64]) -> Result<DecyclerBatchOutput, DecyclerError> {
769 decycler_batch_with_kernel(data, &self.range, self.kernel)
770 }
771 pub fn with_default_slice(
772 data: &[f64],
773 k: Kernel,
774 ) -> Result<DecyclerBatchOutput, DecyclerError> {
775 DecyclerBatchBuilder::new().kernel(k).apply_slice(data)
776 }
777 pub fn apply_candles(
778 self,
779 c: &Candles,
780 src: &str,
781 ) -> Result<DecyclerBatchOutput, DecyclerError> {
782 let slice = source_type(c, src);
783 self.apply_slice(slice)
784 }
785 pub fn with_default_candles(c: &Candles) -> Result<DecyclerBatchOutput, DecyclerError> {
786 DecyclerBatchBuilder::new()
787 .kernel(Kernel::Auto)
788 .apply_candles(c, "close")
789 }
790}
791
/// Result of a batch sweep: a row-major `rows x cols` matrix, one row per parameter combo.
#[derive(Clone, Debug)]
pub struct DecyclerBatchOutput {
    /// Row `r` occupies `values[r * cols..(r + 1) * cols]` and corresponds to `combos[r]`.
    pub values: Vec<f64>,
    /// Parameter combination for each row.
    pub combos: Vec<DecyclerParams>,
    /// Number of combinations.
    pub rows: usize,
    /// Input length.
    pub cols: usize,
}
799impl DecyclerBatchOutput {
800 pub fn row_for_params(&self, p: &DecyclerParams) -> Option<usize> {
801 self.combos.iter().position(|c| {
802 c.hp_period.unwrap_or(125) == p.hp_period.unwrap_or(125)
803 && (c.k.unwrap_or(0.707) - p.k.unwrap_or(0.707)).abs() < 1e-12
804 })
805 }
806 pub fn values_for(&self, p: &DecyclerParams) -> Option<&[f64]> {
807 self.row_for_params(p).map(|row| {
808 let start = row * self.cols;
809 &self.values[start..start + self.cols]
810 })
811 }
812}
813
814#[inline(always)]
815fn expand_grid(r: &DecyclerBatchRange) -> Result<Vec<DecyclerParams>, DecyclerError> {
816 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, DecyclerError> {
817 if step == 0 || start == end {
818 return Ok(vec![start]);
819 }
820 if start < end {
821 return Ok((start..=end).step_by(step).collect());
822 }
823
824 let mut vals = Vec::new();
825 let mut cur: isize = start as isize;
826 let end_i: isize = end as isize;
827 let step_i: isize = step as isize;
828 while cur >= end_i {
829 vals.push(cur as usize);
830 if step_i == 0 {
831 break;
832 }
833 cur = cur.saturating_sub(step_i);
834 if cur == std::isize::MIN {
835 break;
836 }
837 }
838 if vals.is_empty() {
839 return Err(DecyclerError::InvalidRange {
840 start: start as isize,
841 end: end as isize,
842 step: step as isize,
843 });
844 }
845 Ok(vals)
846 }
847 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, DecyclerError> {
848 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
849 return Ok(vec![start]);
850 }
851 let mut v = Vec::new();
852 if start < end && step > 0.0 {
853 let mut x = start;
854 while x <= end + 1e-12 {
855 v.push(x);
856 x += step;
857 }
858 } else if start > end && step < 0.0 {
859 let mut x = start;
860 while x >= end - 1e-12 {
861 v.push(x);
862 x += step;
863 }
864 }
865 if v.is_empty() {
866 return Err(DecyclerError::InvalidRange {
867 start: start as isize,
868 end: end as isize,
869 step: if step >= 0.0 {
870 step as isize
871 } else {
872 (step as i64) as isize
873 },
874 });
875 }
876 Ok(v)
877 }
878 let hp_periods = axis_usize(r.hp_period)?;
879 let ks = axis_f64(r.k)?;
880 let mut out = Vec::with_capacity(hp_periods.len().saturating_mul(ks.len()));
881 for &p in &hp_periods {
882 for &k in &ks {
883 out.push(DecyclerParams {
884 hp_period: Some(p),
885 k: Some(k),
886 });
887 }
888 }
889 Ok(out)
890}
891
/// Runs a batch sweep with `parallel = false` using the per-row kernel `kern`.
#[inline(always)]
pub fn decycler_batch_slice(
    data: &[f64],
    sweep: &DecyclerBatchRange,
    kern: Kernel,
) -> Result<DecyclerBatchOutput, DecyclerError> {
    decycler_batch_inner(data, sweep, kern, false)
}
900
/// Runs a batch sweep with `parallel = true` using the per-row kernel `kern`.
///
/// NOTE(review): presumably rows are processed with rayon when `parallel` is set. The
/// parallel path lives in `decycler_batch_inner_into`, which is not visible here.
#[inline(always)]
pub fn decycler_batch_par_slice(
    data: &[f64],
    sweep: &DecyclerBatchRange,
    kern: Kernel,
) -> Result<DecyclerBatchOutput, DecyclerError> {
    decycler_batch_inner(data, sweep, kern, true)
}
909
910#[inline(always)]
911fn decycler_batch_inner(
912 data: &[f64],
913 sweep: &DecyclerBatchRange,
914 kern: Kernel,
915 parallel: bool,
916) -> Result<DecyclerBatchOutput, DecyclerError> {
917 let combos = expand_grid(sweep)?;
918 let cols = data.len();
919 if cols == 0 {
920 return Err(DecyclerError::EmptyInputData);
921 }
922 let rows = combos.len();
923
924 let _total = rows
925 .checked_mul(cols)
926 .ok_or_else(|| DecyclerError::InvalidInput("rows*cols overflow".into()))?;
927
928 let mut buf_mu = make_uninit_matrix(rows, cols);
929
930 let first = data
931 .iter()
932 .position(|x| !x.is_nan())
933 .ok_or(DecyclerError::AllValuesNaN)?;
934 let warm: Vec<usize> = combos.iter().map(|_| first + 2).collect();
935 init_matrix_prefixes(&mut buf_mu, cols, &warm);
936
937 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
938 let out_f64: &mut [f64] =
939 unsafe { std::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
940
941 let combos = decycler_batch_inner_into(data, sweep, kern, parallel, out_f64)?;
942
943 let values = unsafe {
944 Vec::from_raw_parts(
945 guard.as_mut_ptr() as *mut f64,
946 guard.len(),
947 guard.capacity(),
948 )
949 };
950
951 Ok(DecyclerBatchOutput {
952 values,
953 combos,
954 rows,
955 cols,
956 })
957}
958
/// Scalar decycler for one batch row, writing `out[first + 2..data.len()]`.
///
/// The high-pass state is seeded with `data[first]` and `data[first + 1]`. Each step then
/// computes `hp = c*(x - 2*x1 + x2) + 2(1-a)*hp1 - (1-a)^2*hp2` via fused multiply-adds and
/// stores `x - hp`. The warm-up prefix of `out` is not touched.
///
/// # Safety
/// Indexing is bounds-checked. Callers must ensure `first + 1 < data.len()` and
/// `out.len() >= data.len()`, otherwise this panics.
#[inline(always)]
unsafe fn decycler_row_scalar(
    data: &[f64],
    first: usize,
    hp_period: usize,
    k: f64,
    out: &mut [f64],
) {
    use std::f64::consts::PI;
    const EPSILON: f64 = 1e-10;

    let theta = (2.0 * PI * k) * (hp_period as f64).recip();
    let (sin_t, cos_t) = theta.sin_cos();
    let denom = if cos_t.abs() < EPSILON {
        EPSILON.copysign(cos_t)
    } else {
        cos_t
    };
    let alpha = 1.0 + ((sin_t - 1.0) / denom);
    let half = 1.0 - alpha / 2.0;
    let gain = half * half;
    let decay = 1.0 - alpha;
    let decay_sq = decay * decay;

    let mut hp_lag2 = data[first];
    let mut hp_lag1 = data[first + 1];

    // Window starting at `base` is (data[base], data[base + 1], data[base + 2]).
    for (base, win) in data.windows(3).enumerate().skip(first) {
        let newest = win[2];
        let s0 = newest * gain;
        let s1 = win[1].mul_add(-2.0 * gain, s0);
        let s2 = win[0].mul_add(gain, s1);
        let s3 = hp_lag1.mul_add(2.0 * decay, s2);
        let hp = hp_lag2.mul_add(-decay_sq, s3);

        hp_lag2 = hp_lag1;
        hp_lag1 = hp;

        out[base + 2] = newest - hp;
    }
}
1003
/// AVX2 batch-row entry point; currently delegates to `decycler_row_scalar`.
///
/// # Safety
/// Same requirements as `decycler_row_scalar`: `first + 1 < data.len()` and
/// `out.len() >= data.len()`. Violations panic via bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
unsafe fn decycler_row_avx2(data: &[f64], first: usize, hp_period: usize, k: f64, out: &mut [f64]) {
    decycler_row_scalar(data, first, hp_period, k, out)
}
1009
/// AVX-512 batch-row entry point.
///
/// Routes periods up to 32 to the "short" variant and longer periods to the "long" variant,
/// then issues a store fence.
///
/// # Safety
/// Same requirements as `decycler_row_scalar`: `first + 1 < data.len()` and
/// `out.len() >= data.len()`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn decycler_row_avx512(
    data: &[f64],
    first: usize,
    hp_period: usize,
    k: f64,
    out: &mut [f64],
) {
    const SHORT_PERIOD_LIMIT: usize = 32;
    match hp_period {
        p if p > SHORT_PERIOD_LIMIT => decycler_row_avx512_long(data, first, hp_period, k, out),
        _ => decycler_row_avx512_short(data, first, hp_period, k, out),
    }
    _mm_sfence();
}
1026
/// AVX-512 batch-row variant for periods up to 32; currently delegates to `decycler_row_scalar`.
///
/// # Safety
/// Same requirements as `decycler_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn decycler_row_avx512_short(
    data: &[f64],
    first: usize,
    hp_period: usize,
    k: f64,
    out: &mut [f64],
) {
    decycler_row_scalar(data, first, hp_period, k, out)
}
1038
/// AVX-512 batch-row variant for periods above 32; currently delegates to `decycler_row_scalar`.
///
/// # Safety
/// Same requirements as `decycler_row_scalar`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn decycler_row_avx512_long(
    data: &[f64],
    first: usize,
    hp_period: usize,
    k: f64,
    out: &mut [f64],
) {
    decycler_row_scalar(data, first, hp_period, k, out)
}
1050
/// Public wrapper around the private `expand_grid`: the list of parameter combinations a
/// batch sweep will evaluate, in row order.
#[inline(always)]
pub fn expand_grid_decycler(r: &DecyclerBatchRange) -> Result<Vec<DecyclerParams>, DecyclerError> {
    expand_grid(r)
}
1055
1056#[cfg(test)]
1057mod tests {
1058 use super::*;
1059 use crate::skip_if_unsupported;
1060 use crate::utilities::data_loader::read_candles_from_csv;
1061 use crate::utilities::enums::Kernel;
1062
1063 #[test]
1064 fn test_decycler_into_matches_api() -> Result<(), Box<dyn Error>> {
1065 let n = 512usize;
1066 let data: Vec<f64> = (0..n)
1067 .map(|i| ((i as f64) * 0.037).sin() * 5.0 + 100.0)
1068 .collect();
1069 let params = DecyclerParams::default();
1070 let input = DecyclerInput::from_slice(&data, params);
1071
1072 let baseline = decycler(&input)?.values;
1073
1074 let mut out = vec![0.0; n];
1075 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1076 {
1077 decycler_into(&input, &mut out)?;
1078 }
1079 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1080 {
1081 decycler_into_slice(&mut out, &input, Kernel::Auto)?;
1082 }
1083
1084 assert_eq!(baseline.len(), out.len());
1085
1086 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1087 (a.is_nan() && b.is_nan()) || (a == b)
1088 }
1089
1090 for i in 0..n {
1091 assert!(
1092 eq_or_both_nan(baseline[i], out[i]),
1093 "Mismatch at index {}: baseline={} out={} (bits {:016X} vs {:016X})",
1094 i,
1095 baseline[i],
1096 out[i],
1097 baseline[i].to_bits(),
1098 out[i].to_bits()
1099 );
1100 }
1101
1102 Ok(())
1103 }
1104
1105 fn check_decycler_partial_params(
1106 test_name: &str,
1107 kernel: Kernel,
1108 ) -> Result<(), Box<dyn Error>> {
1109 skip_if_unsupported!(kernel, test_name);
1110 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1111 let candles = read_candles_from_csv(file_path)?;
1112 let default_params = DecyclerParams {
1113 hp_period: None,
1114 k: None,
1115 };
1116 let input_default = DecyclerInput::from_candles(&candles, "close", default_params);
1117 let output_default = decycler_with_kernel(&input_default, kernel)?;
1118 assert_eq!(output_default.values.len(), candles.close.len());
1119
1120 let params_hp_50 = DecyclerParams {
1121 hp_period: Some(50),
1122 k: None,
1123 };
1124 let input_hp_50 = DecyclerInput::from_candles(&candles, "hl2", params_hp_50);
1125 let output_hp_50 = decycler_with_kernel(&input_hp_50, kernel)?;
1126 assert_eq!(output_hp_50.values.len(), candles.close.len());
1127
1128 let params_custom = DecyclerParams {
1129 hp_period: Some(30),
1130 k: None,
1131 };
1132 let input_custom = DecyclerInput::from_candles(&candles, "hlc3", params_custom);
1133 let output_custom = decycler_with_kernel(&input_custom, kernel)?;
1134 assert_eq!(output_custom.values.len(), candles.close.len());
1135 Ok(())
1136 }
1137
1138 fn check_decycler_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1139 skip_if_unsupported!(kernel, test_name);
1140 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1141 let candles = read_candles_from_csv(file_path)?;
1142 let close_prices = candles
1143 .select_candle_field("close")
1144 .expect("Failed to extract close prices");
1145 let params = DecyclerParams {
1146 hp_period: Some(125),
1147 k: None,
1148 };
1149 let input = DecyclerInput::from_candles(&candles, "close", params);
1150 let decycler_result = decycler_with_kernel(&input, kernel)?;
1151 assert_eq!(decycler_result.values.len(), close_prices.len());
1152 let test_values = [
1153 60289.96384058519,
1154 60204.010366691065,
1155 60114.255563805666,
1156 60028.535266555904,
1157 59934.26876964316,
1158 ];
1159 assert!(decycler_result.values.len() >= test_values.len());
1160 let start_index = decycler_result.values.len() - test_values.len();
1161 let result_last_values = &decycler_result.values[start_index..];
1162 for (i, &value) in result_last_values.iter().enumerate() {
1163 let expected_value = test_values[i];
1164 assert!(
1165 (value - expected_value).abs() < 1e-6,
1166 "Decycler mismatch at index {}: expected {}, got {}",
1167 i,
1168 expected_value,
1169 value
1170 );
1171 }
1172 Ok(())
1173 }
1174
1175 fn check_decycler_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1176 skip_if_unsupported!(kernel, test_name);
1177 let input_data = [10.0, 20.0, 30.0];
1178 let params = DecyclerParams {
1179 hp_period: Some(0),
1180 k: None,
1181 };
1182 let input = DecyclerInput::from_slice(&input_data, params);
1183 let result = decycler_with_kernel(&input, kernel);
1184 assert!(result.is_err());
1185 Ok(())
1186 }
1187
1188 fn check_decycler_period_exceed_length(
1189 test_name: &str,
1190 kernel: Kernel,
1191 ) -> Result<(), Box<dyn Error>> {
1192 skip_if_unsupported!(kernel, test_name);
1193 let input_data = [10.0, 20.0, 30.0];
1194 let params = DecyclerParams {
1195 hp_period: Some(10),
1196 k: None,
1197 };
1198 let input = DecyclerInput::from_slice(&input_data, params);
1199 let result = decycler_with_kernel(&input, kernel);
1200 assert!(result.is_err());
1201 Ok(())
1202 }
1203
1204 fn check_decycler_very_small_dataset(
1205 test_name: &str,
1206 kernel: Kernel,
1207 ) -> Result<(), Box<dyn Error>> {
1208 skip_if_unsupported!(kernel, test_name);
1209 let input_data = [42.0];
1210 let params = DecyclerParams {
1211 hp_period: Some(2),
1212 k: None,
1213 };
1214 let input = DecyclerInput::from_slice(&input_data, params);
1215 let result = decycler_with_kernel(&input, kernel);
1216 assert!(result.is_err());
1217 Ok(())
1218 }
1219
1220 fn check_decycler_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1221 skip_if_unsupported!(kernel, test_name);
1222 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1223 let candles = read_candles_from_csv(file_path)?;
1224 let first_params = DecyclerParams {
1225 hp_period: Some(30),
1226 k: None,
1227 };
1228 let first_input = DecyclerInput::from_candles(&candles, "close", first_params);
1229 let first_result = decycler_with_kernel(&first_input, kernel)?;
1230 assert_eq!(first_result.values.len(), candles.close.len());
1231 let second_params = DecyclerParams {
1232 hp_period: Some(30),
1233 k: None,
1234 };
1235 let second_input = DecyclerInput::from_slice(&first_result.values, second_params);
1236 let second_result = decycler_with_kernel(&second_input, kernel)?;
1237 assert_eq!(second_result.values.len(), first_result.values.len());
1238 Ok(())
1239 }
1240
1241 fn check_decycler_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1242 skip_if_unsupported!(kernel, test_name);
1243 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1244 let candles = read_candles_from_csv(file_path)?;
1245 let close_prices = &candles.close;
1246 let period = 125;
1247 let params = DecyclerParams {
1248 hp_period: Some(period),
1249 k: None,
1250 };
1251 let input = DecyclerInput::from_candles(&candles, "close", params);
1252 let decycler_result = decycler_with_kernel(&input, kernel)?;
1253 assert_eq!(decycler_result.values.len(), close_prices.len());
1254 if decycler_result.values.len() > 240 {
1255 for i in 240..decycler_result.values.len() {
1256 assert!(
1257 !decycler_result.values[i].is_nan(),
1258 "Expected no NaN after index 240, found NaN at {}",
1259 i
1260 );
1261 }
1262 }
1263 Ok(())
1264 }
1265
1266 #[cfg(debug_assertions)]
1267 fn check_decycler_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1268 skip_if_unsupported!(kernel, test_name);
1269
1270 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1271 let candles = read_candles_from_csv(file_path)?;
1272
1273 let test_params = vec![
1274 DecyclerParams::default(),
1275 DecyclerParams {
1276 hp_period: Some(2),
1277 k: Some(0.1),
1278 },
1279 DecyclerParams {
1280 hp_period: Some(2),
1281 k: Some(0.707),
1282 },
1283 DecyclerParams {
1284 hp_period: Some(2),
1285 k: Some(2.0),
1286 },
1287 DecyclerParams {
1288 hp_period: Some(5),
1289 k: Some(0.5),
1290 },
1291 DecyclerParams {
1292 hp_period: Some(5),
1293 k: Some(1.0),
1294 },
1295 DecyclerParams {
1296 hp_period: Some(10),
1297 k: Some(0.707),
1298 },
1299 DecyclerParams {
1300 hp_period: Some(20),
1301 k: Some(0.3),
1302 },
1303 DecyclerParams {
1304 hp_period: Some(20),
1305 k: Some(0.707),
1306 },
1307 DecyclerParams {
1308 hp_period: Some(20),
1309 k: Some(1.5),
1310 },
1311 DecyclerParams {
1312 hp_period: Some(50),
1313 k: Some(0.707),
1314 },
1315 DecyclerParams {
1316 hp_period: Some(75),
1317 k: Some(0.9),
1318 },
1319 DecyclerParams {
1320 hp_period: Some(100),
1321 k: Some(0.707),
1322 },
1323 DecyclerParams {
1324 hp_period: Some(100),
1325 k: Some(2.5),
1326 },
1327 DecyclerParams {
1328 hp_period: Some(200),
1329 k: Some(0.707),
1330 },
1331 DecyclerParams {
1332 hp_period: Some(200),
1333 k: Some(5.0),
1334 },
1335 DecyclerParams {
1336 hp_period: Some(500),
1337 k: Some(0.707),
1338 },
1339 ];
1340
1341 for (param_idx, params) in test_params.iter().enumerate() {
1342 let input = DecyclerInput::from_candles(&candles, "close", params.clone());
1343 let output = decycler_with_kernel(&input, kernel)?;
1344
1345 for (i, &val) in output.values.iter().enumerate() {
1346 if val.is_nan() {
1347 continue;
1348 }
1349
1350 let bits = val.to_bits();
1351
1352 if bits == 0x11111111_11111111 {
1353 panic!(
1354 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1355 with params: hp_period={}, k={} (param set {})",
1356 test_name,
1357 val,
1358 bits,
1359 i,
1360 params.hp_period.unwrap_or(125),
1361 params.k.unwrap_or(0.707),
1362 param_idx
1363 );
1364 }
1365
1366 if bits == 0x22222222_22222222 {
1367 panic!(
1368 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1369 with params: hp_period={}, k={} (param set {})",
1370 test_name,
1371 val,
1372 bits,
1373 i,
1374 params.hp_period.unwrap_or(125),
1375 params.k.unwrap_or(0.707),
1376 param_idx
1377 );
1378 }
1379
1380 if bits == 0x33333333_33333333 {
1381 panic!(
1382 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1383 with params: hp_period={}, k={} (param set {})",
1384 test_name,
1385 val,
1386 bits,
1387 i,
1388 params.hp_period.unwrap_or(125),
1389 params.k.unwrap_or(0.707),
1390 param_idx
1391 );
1392 }
1393 }
1394 }
1395
1396 Ok(())
1397 }
1398
1399 #[cfg(not(debug_assertions))]
1400 fn check_decycler_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1401 Ok(())
1402 }
1403
1404 macro_rules! generate_all_decycler_tests {
1405 ($($test_fn:ident),*) => {
1406 paste::paste! {
1407 $(
1408 #[test]
1409 fn [<$test_fn _scalar_f64>]() {
1410 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1411 }
1412 )*
1413 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1414 $(
1415 #[test]
1416 fn [<$test_fn _avx2_f64>]() {
1417 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1418 }
1419 #[test]
1420 fn [<$test_fn _avx512_f64>]() {
1421 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1422 }
1423 )*
1424 }
1425 }
1426 }
1427 generate_all_decycler_tests!(
1428 check_decycler_partial_params,
1429 check_decycler_accuracy,
1430 check_decycler_zero_period,
1431 check_decycler_period_exceed_length,
1432 check_decycler_very_small_dataset,
1433 check_decycler_reinput,
1434 check_decycler_nan_handling,
1435 check_decycler_no_poison
1436 );
1437
1438 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1439 skip_if_unsupported!(kernel, test);
1440 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1441 let c = read_candles_from_csv(file)?;
1442 let output = DecyclerBatchBuilder::new()
1443 .kernel(kernel)
1444 .apply_candles(&c, "close")?;
1445 let def = DecyclerParams::default();
1446 let row = output.values_for(&def).expect("default row missing");
1447 assert_eq!(row.len(), c.close.len());
1448 let expected = [
1449 60289.96384058519,
1450 60204.010366691065,
1451 60114.255563805666,
1452 60028.535266555904,
1453 59934.26876964316,
1454 ];
1455 let start = row.len() - 5;
1456 for (i, &v) in row[start..].iter().enumerate() {
1457 assert!(
1458 (v - expected[i]).abs() < 1e-6,
1459 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
1460 );
1461 }
1462 Ok(())
1463 }
1464
/// Expands a batch checker into one `#[test]` per batch kernel:
/// - scalar and auto-detect, always;
/// - AVX2 and AVX-512, when `nightly-avx` is enabled on x86_64.
///
/// The generated tests return the checker's `Result`. Errors propagated with `?` fail
/// the test instead of being discarded; examples are a missing data file or a failed
/// batch run.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }
            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }
            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
1486 #[cfg(debug_assertions)]
1487 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1488 skip_if_unsupported!(kernel, test);
1489
1490 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1491 let c = read_candles_from_csv(file)?;
1492
1493 let test_configs = vec![
1494 (2, 10, 2, 0.1, 1.0, 0.3),
1495 (5, 25, 5, 0.707, 0.707, 0.0),
1496 (10, 10, 0, 0.5, 2.0, 0.5),
1497 (2, 5, 1, 0.3, 0.9, 0.2),
1498 (30, 60, 15, 0.707, 0.707, 0.0),
1499 (50, 100, 25, 0.5, 1.5, 0.5),
1500 (75, 125, 25, 0.707, 2.0, 0.5),
1501 (100, 200, 50, 0.5, 1.0, 0.25),
1502 (200, 500, 100, 0.707, 0.707, 0.0),
1503 ];
1504
1505 for (cfg_idx, &(hp_start, hp_end, hp_step, k_start, k_end, k_step)) in
1506 test_configs.iter().enumerate()
1507 {
1508 let output = DecyclerBatchBuilder::new()
1509 .kernel(kernel)
1510 .hp_period_range(hp_start, hp_end, hp_step)
1511 .k_range(k_start, k_end, k_step)
1512 .apply_candles(&c, "close")?;
1513
1514 for (idx, &val) in output.values.iter().enumerate() {
1515 if val.is_nan() {
1516 continue;
1517 }
1518
1519 let bits = val.to_bits();
1520 let row = idx / output.cols;
1521 let col = idx % output.cols;
1522 let combo = &output.combos[row];
1523
1524 if bits == 0x11111111_11111111 {
1525 panic!(
1526 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1527 at row {} col {} (flat index {}) with params: hp_period={}, k={}",
1528 test,
1529 cfg_idx,
1530 val,
1531 bits,
1532 row,
1533 col,
1534 idx,
1535 combo.hp_period.unwrap_or(125),
1536 combo.k.unwrap_or(0.707)
1537 );
1538 }
1539
1540 if bits == 0x22222222_22222222 {
1541 panic!(
1542 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1543 at row {} col {} (flat index {}) with params: hp_period={}, k={}",
1544 test,
1545 cfg_idx,
1546 val,
1547 bits,
1548 row,
1549 col,
1550 idx,
1551 combo.hp_period.unwrap_or(125),
1552 combo.k.unwrap_or(0.707)
1553 );
1554 }
1555
1556 if bits == 0x33333333_33333333 {
1557 panic!(
1558 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1559 at row {} col {} (flat index {}) with params: hp_period={}, k={}",
1560 test,
1561 cfg_idx,
1562 val,
1563 bits,
1564 row,
1565 col,
1566 idx,
1567 combo.hp_period.unwrap_or(125),
1568 combo.k.unwrap_or(0.707)
1569 );
1570 }
1571 }
1572 }
1573
1574 Ok(())
1575 }
1576
1577 #[cfg(not(debug_assertions))]
1578 fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
1579 Ok(())
1580 }
1581
1582 gen_batch_tests!(check_batch_default_row);
1583 gen_batch_tests!(check_batch_no_poison);
1584
/// Property test for the single-series decycler.
///
/// Random finite series are run through `kernel` and through `Kernel::Scalar`.
/// The test then checks:
/// * output length equals input length,
/// * NaN in the warmup prefix (`first + 2` values) and finite values afterwards,
/// * agreement between `kernel` and the scalar reference within 3 ULP,
/// * for constant input, every post-warmup output lies between 0 and the constant,
/// * no debug-allocation poison values (debug builds only).
///
/// `k` is drawn from `[0.1, min(5, period / 4))`. With the coefficient formula used by
/// `decycler_batch_inner_into`, `theta = 2*pi*k/period < pi/2` keeps alpha in `(0, 1)`.
/// That formula is `alpha = 1 + (sin(theta) - 1) / cos(theta)`. The recursion's double
/// pole `1 - alpha` then stays inside the unit circle. Larger angles can move the pole
/// outside it (e.g. period 2, k = 1.25 gives alpha ≈ 3.4). The output then diverges, and
/// the finiteness property cannot hold.
///
/// The recursion is seeded with the raw first two inputs. Its start-up transient grows
/// roughly like `1 / alpha`, so per-sample bounds such as `|out| <= 2*|in| + 1` do not
/// hold for slow filters and are not asserted. The same applies to fixed-ratio
/// "output span vs. input span" bounds.
#[cfg(feature = "proptest")]
#[allow(clippy::float_cmp)]
fn check_decycler_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    /// Distance between two doubles in ULPs, measured on an order-preserving integer
    /// mapping. Values of opposite sign near zero (and `-0.0` vs `0.0`) therefore
    /// compare as close, unlike a raw `to_bits` difference.
    fn ulp_distance(a: f64, b: f64) -> u128 {
        fn ordered(x: f64) -> i128 {
            let bits = x.to_bits() as i64;
            let key = if bits < 0 { i64::MIN - bits } else { bits };
            key as i128
        }
        (ordered(a) - ordered(b)).unsigned_abs()
    }

    let strat = (2usize..=100).prop_flat_map(|period| {
        // Keep theta = 2*pi*k/period below pi/2 (stable, low-pass regime).
        let k_max = (period as f64 / 4.0).min(5.0);
        (
            prop::collection::vec(
                (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
                period..400,
            ),
            Just(period),
            0.1f64..k_max,
        )
    });

    proptest::test_runner::TestRunner::default().run(&strat, |(data, hp_period, k)| {
        let params = DecyclerParams {
            hp_period: Some(hp_period),
            k: Some(k),
        };
        let input = DecyclerInput::from_slice(&data, params);

        let DecyclerOutput { values: out } = decycler_with_kernel(&input, kernel).unwrap();
        let DecyclerOutput { values: ref_out } =
            decycler_with_kernel(&input, Kernel::Scalar).unwrap();

        prop_assert_eq!(out.len(), data.len());
        prop_assert_eq!(ref_out.len(), data.len());

        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
        let warmup_period = first + 2;

        for (i, &v) in out.iter().enumerate() {
            if i < warmup_period {
                prop_assert!(
                    v.is_nan(),
                    "Expected NaN during warmup at index {}, got {}",
                    i,
                    v
                );
            } else {
                prop_assert!(
                    v.is_finite(),
                    "Expected finite value after warmup at index {}, got {}",
                    i,
                    v
                );
            }
        }

        for (i, (&y, &r)) in out.iter().zip(ref_out.iter()).enumerate() {
            if y.is_nan() && r.is_nan() {
                continue;
            }
            let ulp_diff = ulp_distance(y, r);
            prop_assert!(
                ulp_diff <= 3,
                "Kernel mismatch at index {}: {} ({} bits) vs {} ({} bits), ULP diff: {}",
                i,
                y,
                y.to_bits(),
                r,
                r.to_bits(),
                ulp_diff
            );
        }

        // Constant input makes every second difference zero. The seeded high-pass state
        // then decays from the constant towards 0 without changing sign. So
        // `out = level - hp` stays between 0 and the level (up to rounding).
        let level = data[first];
        if data.iter().all(|&x| x == level) {
            let tol = 1e-6 * level.abs().max(1.0);
            for (i, &v) in out.iter().enumerate().skip(warmup_period) {
                prop_assert!(
                    v * level.signum() >= -tol && v.abs() <= level.abs() + tol,
                    "Constant input {} should produce output between 0 and the input, got {} at index {}",
                    level,
                    v,
                    i
                );
            }
        }

        #[cfg(debug_assertions)]
        for (i, &val) in out.iter().enumerate() {
            if val.is_nan() {
                continue;
            }

            let bits = val.to_bits();
            prop_assert!(
                bits != 0x11111111_11111111
                    && bits != 0x22222222_22222222
                    && bits != 0x33333333_33333333,
                "Found poison value at index {}: {} (0x{:016X})",
                i,
                val,
                bits
            );
        }

        Ok(())
    })?;

    Ok(())
}

#[cfg(feature = "proptest")]
generate_all_decycler_tests!(check_decycler_property);
1778}
1779
/// Python entry point: computes the decycler over a 1-D float64 NumPy array.
///
/// `hp_period`/`k` are forwarded as-is (`None` leaves the library default), and
/// `kernel` is a kernel name validated by `validate_kernel`. Library errors are
/// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "decycler")]
#[pyo3(signature = (data, hp_period=None, k=None, kernel=None))]
pub fn decycler_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    hp_period: Option<usize>,
    k: Option<f64>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let values = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = DecyclerInput::from_slice(values, DecyclerParams { hp_period, k });

    // The numeric work runs with the GIL released.
    let computed = py.allow_threads(|| decycler_with_kernel(&input, kern));
    let output = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(output.values.into_pyarray(py))
}
1803
/// Python wrapper (exposed as `DecyclerStream`) around the incremental decycler.
#[cfg(feature = "python")]
#[pyclass(name = "DecyclerStream")]
pub struct DecyclerStreamPy {
    stream: DecyclerStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl DecyclerStreamPy {
    /// Builds a stream; invalid parameters are raised as `ValueError`.
    #[new]
    fn new(hp_period: Option<usize>, k: Option<f64>) -> PyResult<Self> {
        DecyclerStream::try_new(DecyclerParams { hp_period, k })
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one value; returns whatever the underlying stream yields for it.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1825
/// Python entry point for a parameter sweep.
///
/// The returned dict holds:
/// - `values`: a `(rows, cols)` float64 array with one row per (hp_period, k) combination;
/// - `hp_periods` and `ks`: the combination of each row;
/// - `rows` and `cols`.
///
/// Warmup cells (`first_valid + 2` per row) are NaN.
///
/// Errors from grid expansion, size overflow, kernel validation or the computation are
/// raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "decycler_batch")]
#[pyo3(signature = (data, hp_period_range, k_range, kernel=None))]
pub fn decycler_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    hp_period_range: (usize, usize, usize),
    k_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let sweep = DecyclerBatchRange {
        hp_period: hp_period_range,
        k: k_range,
    };

    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    // SAFETY: the array is freshly allocated and uniquely owned here. The warmup
    // prefix of each row is written below and the remaining cells by
    // `decycler_batch_inner_into`, before the array is handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let first = slice_in.iter().position(|x| !x.is_nan()).unwrap_or(0);
    let warmup = (first + 2).min(cols);
    if cols > 0 {
        for row in slice_out.chunks_exact_mut(cols) {
            row[..warmup].fill(f64::NAN);
        }
    }

    let kern = validate_kernel(kernel, true)?;

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                other => other,
            };
            // Batch kernels map to their per-row counterparts. Any other kernel is
            // forwarded unchanged instead of panicking while the GIL is released.
            let row_kernel = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                other => other,
            };
            decycler_batch_inner_into(slice_in, &sweep, row_kernel, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "hp_periods",
        combos
            .iter()
            .map(|p| p.hp_period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "ks",
        combos
            .iter()
            .map(|p| p.k.unwrap())
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1905
/// Python entry point: runs the decycler parameter sweep on a CUDA device over an
/// f32 input. The result is left in device memory and wrapped as `DeviceArrayF32Py`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "decycler_cuda_batch_dev")]
#[pyo3(signature = (data_f32, hp_period_range=(125, 125, 0), k_range=(0.707, 0.707, 0.0), device_id=0))]
pub fn decycler_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    hp_period_range: (usize, usize, usize),
    k_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let prices = data_f32.as_slice()?;
    let sweep = DecyclerBatchRange {
        hp_period: hp_period_range,
        k: k_range,
    };

    // Device setup and the launch both run with the GIL released.
    let inner = py.allow_threads(|| -> PyResult<_> {
        let engine =
            CudaDecycler::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let launched = engine.decycler_batch_dev(prices, &sweep);
        launched.map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DeviceArrayF32Py { inner })
}
1932
/// Python entry point: runs one (hp_period, k) decycler over many series on a CUDA
/// device.
///
/// `data_tm_f32` is time-major: shape `(series_len, num_series)`. Parameters are
/// validated up front (`hp_period >= 2`, `k` positive and finite); failures are
/// raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "decycler_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, hp_period, k, device_id=0))]
pub fn decycler_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    hp_period: usize,
    k: f64,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if hp_period < 2 {
        return Err(PyValueError::new_err("hp_period must be >= 2"));
    }
    // Rejects NaN as well: `NaN > 0.0` is false.
    if !(k.is_finite() && k > 0.0) {
        return Err(PyValueError::new_err("k must be positive and finite"));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let (series_len, num_series) = (shape[0], shape[1]);
    let params = DecyclerParams {
        hp_period: Some(hp_period),
        k: Some(k),
    };
    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda =
            CudaDecycler::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // `&params` (the original text had the HTML-mangled `¶ms`, which does not compile).
        cuda.decycler_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    Ok(DeviceArrayF32Py { inner })
}
1969
1970#[inline(always)]
1971fn decycler_batch_inner_into(
1972 data: &[f64],
1973 sweep: &DecyclerBatchRange,
1974 kern: Kernel,
1975 parallel: bool,
1976 out: &mut [f64],
1977) -> Result<Vec<DecyclerParams>, DecyclerError> {
1978 let combos = expand_grid(sweep)?;
1979
1980 let first = data
1981 .iter()
1982 .position(|x| !x.is_nan())
1983 .ok_or(DecyclerError::AllValuesNaN)?;
1984 let max_p = combos.iter().map(|c| c.hp_period.unwrap()).max().unwrap();
1985 let _guard = combos
1986 .len()
1987 .checked_mul(max_p)
1988 .ok_or_else(|| DecyclerError::InvalidInput("n_combos*max_period overflow".into()))?;
1989 if data.len() - first < max_p {
1990 return Err(DecyclerError::NotEnoughValidData {
1991 needed: max_p,
1992 valid: data.len() - first,
1993 });
1994 }
1995
1996 let rows = combos.len();
1997 let cols = data.len();
1998 let expected = rows
1999 .checked_mul(cols)
2000 .ok_or_else(|| DecyclerError::InvalidInput("rows*cols overflow".into()))?;
2001 if out.len() != expected {
2002 return Err(DecyclerError::OutputLengthMismatch {
2003 expected,
2004 got: out.len(),
2005 });
2006 }
2007
2008 let mut diff: Vec<f64> = vec![0.0; cols];
2009 for i in (first + 2)..cols {
2010 diff[i] = data[i] - 2.0 * data[i - 1] + data[i - 2];
2011 }
2012
2013 let out_mu = unsafe {
2014 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
2015 };
2016
2017 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
2018 let hp_period = combos[row].hp_period.unwrap();
2019 let k = combos[row].k.unwrap();
2020 let dst = std::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
2021
2022 let angle = (2.0 * std::f64::consts::PI * k) * (hp_period as f64).recip();
2023 let (sin_val, cos_val) = angle.sin_cos();
2024 const EPSILON: f64 = 1e-10;
2025 let cos_safe = if cos_val.abs() < EPSILON {
2026 EPSILON.copysign(cos_val)
2027 } else {
2028 cos_val
2029 };
2030 let alpha = 1.0 + ((sin_val - 1.0) / cos_safe);
2031 let one_minus_alpha_half = 1.0 - alpha / 2.0;
2032 let c = one_minus_alpha_half * one_minus_alpha_half;
2033 let one_minus_alpha = 1.0 - alpha;
2034 let one_minus_alpha_sq = one_minus_alpha * one_minus_alpha;
2035
2036 let mut hp_prev2 = data[first];
2037 let mut hp_prev1 = data[first + 1];
2038
2039 for i in (first + 2)..cols {
2040 let current = data[i];
2041 let s3 = hp_prev1.mul_add(2.0 * one_minus_alpha, c * diff[i]);
2042 let hp_val = hp_prev2.mul_add(-one_minus_alpha_sq, s3);
2043
2044 hp_prev2 = hp_prev1;
2045 hp_prev1 = hp_val;
2046
2047 dst[i] = current - hp_val;
2048 }
2049 };
2050
2051 if parallel {
2052 #[cfg(not(target_arch = "wasm32"))]
2053 out_mu
2054 .par_chunks_mut(cols)
2055 .enumerate()
2056 .for_each(|(r, s)| do_row(r, s));
2057
2058 #[cfg(target_arch = "wasm32")]
2059 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
2060 do_row(r, s);
2061 }
2062 } else {
2063 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
2064 do_row(r, s);
2065 }
2066 }
2067
2068 Ok(combos)
2069}
2070
/// wasm entry point: returns a freshly allocated decycler series for `data`, with
/// errors converted to JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn decycler_js(data: &[f64], hp_period: usize, k: f64) -> Result<Vec<f64>, JsValue> {
    let input = DecyclerInput::from_slice(
        data,
        DecyclerParams {
            hp_period: Some(hp_period),
            k: Some(k),
        },
    );
    let mut values = vec![0.0; data.len()];
    match decycler_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2084
/// wasm entry point: writes the decycler of `len` inputs at `in_ptr` into `len`
/// slots at `out_ptr`.
///
/// Both pointers must address `len` valid f64s. Overlapping buffers, including the
/// in-place case `in_ptr == out_ptr`, are handled through a scratch buffer. Null
/// pointers and library errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn decycler_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hp_period: usize,
    k: f64,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Any overlap, not only identical pointers, lets output writes clobber input
    // values that may still be read, so such calls go through a scratch buffer.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps =
        in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span);

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let params = DecyclerParams {
        hp_period: Some(hp_period),
        k: Some(k),
    };
    let input = DecyclerInput::from_slice(data, params);
    let kernel = detect_best_kernel();

    if overlaps {
        let mut tmp = vec![0.0; len];
        decycler_into_slice(&mut tmp, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64s; the
        // input view is no longer used once the result sits in `tmp`.
        unsafe { std::slice::from_raw_parts_mut(out_ptr, len) }.copy_from_slice(&tmp);
    } else {
        // SAFETY: `out_ptr` addresses `len` writable f64s disjoint from the input.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        decycler_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2122
/// Allocates an uninitialized buffer of `len` f64s and hands ownership to the JS
/// side. Release it with `decycler_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn decycler_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}

/// Frees a buffer obtained from `decycler_alloc(len)`; null pointers are ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn decycler_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `decycler_alloc(len)` and not be freed yet;
    // rebuilding the Vec with that capacity lets its destructor release it.
    unsafe { drop(Vec::from_raw_parts(ptr, len, len)) };
}
2141
/// Sweep configuration accepted from JS: `(start, end, step)` for each parameter.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DecyclerBatchConfig {
    pub hp_period_range: (usize, usize, usize),
    pub k_range: (f64, f64, f64),
}

/// Batch result returned to JS: the flattened `rows x cols` values plus the
/// parameter combination of each row.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DecyclerBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<DecyclerParams>,
    pub rows: usize,
    pub cols: usize,
}

/// wasm entry point (`decycler_batch`): runs a serial sweep described by a
/// `DecyclerBatchConfig` object and returns a serialized `DecyclerBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = decycler_batch)]
pub fn decycler_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let DecyclerBatchConfig {
        hp_period_range,
        k_range,
    } = serde_wasm_bindgen::from_value::<DecyclerBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = DecyclerBatchRange {
        hp_period: hp_period_range,
        k: k_range,
    };
    let batch = decycler_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = DecyclerBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2181
/// wasm entry point: writes a parameter sweep into a caller-provided buffer.
///
/// `in_ptr` must address `len` f64s. `out_ptr` must address `rows * len` f64s, where
/// `rows` is the number of (hp_period, k) combinations; that count is returned.
///
/// Behaviour:
/// - The warmup prefix of every row is set to NaN.
/// - Overlapping input/output buffers are handled by working from a copy of the input.
/// - Null pointers, overflow and library errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn decycler_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    hp_start: usize,
    hp_end: usize,
    hp_step: usize,
    k_start: f64,
    k_end: f64,
    k_step: f64,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = DecyclerBatchRange {
        hp_period: (hp_start, hp_end, hp_step),
        k: (k_start, k_end, k_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Output rows are written one after another. An input overlapping the output
    // would be overwritten before later rows read it, so copy it first in that case.
    let elem = std::mem::size_of::<f64>();
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_addr < out_addr.saturating_add(total.saturating_mul(elem))
        && out_addr < in_addr.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` addresses `len` initialized f64s.
    let input = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let owned;
    let data: &[f64] = if overlaps {
        owned = input.to_vec();
        &owned
    } else {
        input
    };

    // SAFETY: the caller guarantees `out_ptr` addresses `total` writable f64s; when
    // it overlaps the input, `data` points at the private copy instead.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    decycler_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Do not leave whatever the caller's buffer held in the warmup cells.
    if cols > 0 {
        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
        let warmup = (first + 2).min(cols);
        for row in out.chunks_exact_mut(cols) {
            row[..warmup].fill(f64::NAN);
        }
    }
    Ok(rows)
}