1#[cfg(all(feature = "python", feature = "cuda"))]
2mod cvi_python_cuda_handle {
3 use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
4 use cust::context::Context;
5 use cust::memory::DeviceBuffer;
6 use pyo3::exceptions::PyValueError;
7 use pyo3::prelude::*;
8 use pyo3::types::PyDict;
9 use std::ffi::c_void;
10 use std::sync::Arc;
11
12 #[pyclass(module = "ta_indicators.cuda", unsendable, name = "DeviceArrayF32Py")]
13 pub struct DeviceArrayF32Py {
14 pub(crate) buf: Option<DeviceBuffer<f32>>,
15 pub(crate) rows: usize,
16 pub(crate) cols: usize,
17 pub(crate) _ctx: Arc<Context>,
18 pub(crate) device_id: u32,
19 }
20
21 #[pymethods]
22 impl DeviceArrayF32Py {
23 #[getter]
24 fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
25 let d = PyDict::new(py);
26 d.set_item("shape", (self.rows, self.cols))?;
27 d.set_item("typestr", "<f4")?;
28 d.set_item(
29 "strides",
30 (
31 self.cols * std::mem::size_of::<f32>(),
32 std::mem::size_of::<f32>(),
33 ),
34 )?;
35 let ptr = self
36 .buf
37 .as_ref()
38 .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
39 .as_device_ptr()
40 .as_raw() as usize;
41 d.set_item("data", (ptr, false))?;
42 d.set_item("version", 3)?;
43 Ok(d)
44 }
45
46 fn __dlpack_device__(&self) -> (i32, i32) {
47 let mut device_ordinal: i32 = self.device_id as i32;
48 unsafe {
49 let attr = cust::sys::CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL;
50 let mut value = std::mem::MaybeUninit::<i32>::uninit();
51 let ptr = self
52 .buf
53 .as_ref()
54 .map(|b| b.as_device_ptr().as_raw())
55 .unwrap_or(0);
56 if ptr != 0 {
57 let rc = cust::sys::cuPointerGetAttribute(
58 value.as_mut_ptr() as *mut c_void,
59 attr,
60 ptr,
61 );
62 if rc == cust::sys::CUresult::CUDA_SUCCESS {
63 device_ordinal = value.assume_init();
64 }
65 }
66 }
67 (2, device_ordinal)
68 }
69
70 #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
71 fn __dlpack__<'py>(
72 &mut self,
73 py: Python<'py>,
74 stream: Option<pyo3::PyObject>,
75 max_version: Option<pyo3::PyObject>,
76 dl_device: Option<pyo3::PyObject>,
77 copy: Option<pyo3::PyObject>,
78 ) -> PyResult<PyObject> {
79 let (kdl, alloc_dev) = self.__dlpack_device__();
80 if let Some(dev_obj) = dl_device.as_ref() {
81 if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
82 if dev_ty != kdl || dev_id != alloc_dev {
83 let wants_copy = copy
84 .as_ref()
85 .and_then(|c| c.extract::<bool>(py).ok())
86 .unwrap_or(false);
87 if wants_copy {
88 return Err(PyValueError::new_err(
89 "device copy not implemented for __dlpack__",
90 ));
91 } else {
92 return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
93 }
94 }
95 }
96 }
97 let _ = stream;
98
99 let buf = self
100 .buf
101 .take()
102 .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;
103
104 let rows = self.rows;
105 let cols = self.cols;
106
107 let max_version_bound = max_version.map(|obj| obj.into_bound(py));
108
109 export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
110 }
111 }
112
113 pub use DeviceArrayF32Py as CviDeviceArrayF32Py;
114}
115
116#[cfg(all(feature = "python", feature = "cuda"))]
117use self::cvi_python_cuda_handle::CviDeviceArrayF32Py;
118use crate::utilities::data_loader::Candles;
119use crate::utilities::enums::Kernel;
120use crate::utilities::helpers::{
121 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
122 make_uninit_matrix,
123};
124#[cfg(feature = "python")]
125use crate::utilities::kernel_validation::validate_kernel;
126use aligned_vec::{AVec, CACHELINE_ALIGN};
127#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
128use core::arch::x86_64::*;
129#[cfg(feature = "python")]
130use numpy::{IntoPyArray, PyArray1};
131#[cfg(feature = "python")]
132use pyo3::exceptions::PyValueError;
133#[cfg(feature = "python")]
134use pyo3::prelude::*;
135#[cfg(feature = "python")]
136use pyo3::types::{PyDict, PyList};
137#[cfg(not(target_arch = "wasm32"))]
138use rayon::prelude::*;
139#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
140use serde::{Deserialize, Serialize};
141use std::mem::MaybeUninit;
142use thiserror::Error;
143#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
144use wasm_bindgen::prelude::*;
145
146#[derive(Debug, Clone)]
147pub enum CviData<'a> {
148 Candles(&'a Candles),
149 Slices { high: &'a [f64], low: &'a [f64] },
150}
151
152#[derive(Debug, Clone)]
153pub struct CviOutput {
154 pub values: Vec<f64>,
155}
156
157#[derive(Debug, Clone)]
158#[cfg_attr(
159 all(target_arch = "wasm32", feature = "wasm"),
160 derive(Serialize, Deserialize)
161)]
162pub struct CviParams {
163 pub period: Option<usize>,
164}
165
166impl Default for CviParams {
167 fn default() -> Self {
168 Self { period: Some(10) }
169 }
170}
171
172#[derive(Debug, Clone)]
173pub struct CviInput<'a> {
174 pub data: CviData<'a>,
175 pub params: CviParams,
176}
177
178impl<'a> CviInput<'a> {
179 #[inline]
180 pub fn from_candles(c: &'a Candles, p: CviParams) -> Self {
181 Self {
182 data: CviData::Candles(c),
183 params: p,
184 }
185 }
186 #[inline]
187 pub fn with_default_candles(c: &'a Candles) -> Self {
188 Self::from_candles(c, CviParams::default())
189 }
190 #[inline]
191 pub fn from_slices(high: &'a [f64], low: &'a [f64], p: CviParams) -> Self {
192 Self {
193 data: CviData::Slices { high, low },
194 params: p,
195 }
196 }
197
198 #[inline]
199 pub fn get_period(&self) -> usize {
200 self.params.period.unwrap_or(10)
201 }
202}
203
204#[derive(Debug, Copy, Clone)]
205pub struct CviBuilder {
206 period: Option<usize>,
207 kernel: Kernel,
208}
209
210impl Default for CviBuilder {
211 fn default() -> Self {
212 Self {
213 period: None,
214 kernel: Kernel::Auto,
215 }
216 }
217}
218
219impl CviBuilder {
220 #[inline]
221 pub fn new() -> Self {
222 Self::default()
223 }
224 #[inline]
225 pub fn period(mut self, n: usize) -> Self {
226 self.period = Some(n);
227 self
228 }
229 #[inline]
230 pub fn kernel(mut self, k: Kernel) -> Self {
231 self.kernel = k;
232 self
233 }
234
235 #[inline]
236 pub fn apply(self, candles: &Candles) -> Result<CviOutput, CviError> {
237 let params = CviParams {
238 period: self.period,
239 };
240 let input = CviInput::from_candles(candles, params);
241 cvi_with_kernel(&input, self.kernel)
242 }
243
244 #[inline]
245 pub fn apply_slice(self, high: &[f64], low: &[f64]) -> Result<CviOutput, CviError> {
246 let params = CviParams {
247 period: self.period,
248 };
249 let input = CviInput::from_slices(high, low, params);
250 cvi_with_kernel(&input, self.kernel)
251 }
252
253 #[inline]
254 pub fn into_stream(self, initial_high: f64, initial_low: f64) -> Result<CviStream, CviError> {
255 let params = CviParams {
256 period: self.period,
257 };
258 CviStream::try_new(params, initial_high, initial_low)
259 }
260}
261
262#[derive(Debug, Error)]
263pub enum CviError {
264 #[error("cvi: Empty data provided for CVI.")]
265 EmptyData,
266 #[error("cvi: Invalid period: period = {period}, data length = {data_len}")]
267 InvalidPeriod { period: usize, data_len: usize },
268 #[error("cvi: Not enough valid data: needed = {needed}, valid = {valid}")]
269 NotEnoughValidData { needed: usize, valid: usize },
270 #[error("cvi: All values are NaN.")]
271 AllValuesNaN,
272 #[error("cvi: Output length mismatch: expected={expected}, got={got}")]
273 OutputLengthMismatch { expected: usize, got: usize },
274 #[error("cvi: Invalid range: start={start}, end={end}, step={step}")]
275 InvalidRange {
276 start: usize,
277 end: usize,
278 step: usize,
279 },
280 #[error("cvi: Invalid kernel for batch path: {0:?}")]
281 InvalidKernelForBatch(Kernel),
282}
283
284#[inline]
285pub fn cvi(input: &CviInput) -> Result<CviOutput, CviError> {
286 cvi_with_kernel(input, Kernel::Auto)
287}
288
289pub fn cvi_with_kernel(input: &CviInput, kernel: Kernel) -> Result<CviOutput, CviError> {
290 let (high, low) = match &input.data {
291 CviData::Candles(c) => (&c.high[..], &c.low[..]),
292 CviData::Slices { high, low } => (*high, *low),
293 };
294
295 if high.is_empty() || low.is_empty() || high.len() != low.len() {
296 return Err(CviError::EmptyData);
297 }
298
299 let period = input.get_period();
300 if period == 0 || period > high.len() {
301 return Err(CviError::InvalidPeriod {
302 period,
303 data_len: high.len(),
304 });
305 }
306
307 let first_valid_idx = match (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan()) {
308 Some(idx) => idx,
309 None => return Err(CviError::AllValuesNaN),
310 };
311
312 let needed = 2 * period - 1;
313 if (high.len() - first_valid_idx) < needed {
314 return Err(CviError::NotEnoughValidData {
315 needed,
316 valid: high.len() - first_valid_idx,
317 });
318 }
319
320 let mut cvi_values = alloc_with_nan_prefix(high.len(), first_valid_idx + needed);
321
322 let chosen = match kernel {
323 Kernel::Auto => Kernel::Scalar,
324 other => other,
325 };
326
327 unsafe {
328 match chosen {
329 Kernel::Scalar | Kernel::ScalarBatch => {
330 cvi_scalar(high, low, period, first_valid_idx, &mut cvi_values)
331 }
332 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
333 Kernel::Avx2 | Kernel::Avx2Batch => {
334 cvi_avx2(high, low, period, first_valid_idx, &mut cvi_values)
335 }
336 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
337 Kernel::Avx512 | Kernel::Avx512Batch => {
338 cvi_avx512(high, low, period, first_valid_idx, &mut cvi_values)
339 }
340 _ => unreachable!(),
341 }
342 }
343
344 Ok(CviOutput { values: cvi_values })
345}
346
347#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
348#[inline]
349pub fn cvi_into(input: &CviInput, out: &mut [f64]) -> Result<(), CviError> {
350 cvi_into_slice(out, input, Kernel::Scalar)
351}
352
353#[inline]
354pub fn cvi_scalar(
355 high: &[f64],
356 low: &[f64],
357 period: usize,
358 first_valid_idx: usize,
359 out: &mut [f64],
360) {
361 let alpha = 2.0 / (period as f64 + 1.0);
362
363 let mut val =
364 unsafe { *high.get_unchecked(first_valid_idx) - *low.get_unchecked(first_valid_idx) };
365
366 let mut lag = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
367 unsafe { lag.set_len(period) };
368 unsafe { *lag.get_unchecked_mut(0) = val };
369
370 let mut head = 1usize;
371 let len = high.len();
372 let needed = 2 * period - 1;
373
374 let mut i = first_valid_idx + 1;
375 let end_warm = first_valid_idx + needed;
376 while i < end_warm {
377 let range = unsafe { *high.get_unchecked(i) - *low.get_unchecked(i) };
378 val += (range - val) * alpha;
379 unsafe { *lag.get_unchecked_mut(head) = val };
380 head += 1;
381 if head == period {
382 head = 0;
383 }
384 i += 1;
385 }
386
387 let mut j = end_warm;
388 while j < len {
389 let range = unsafe { *high.get_unchecked(j) - *low.get_unchecked(j) };
390 val += (range - val) * alpha;
391 let old = unsafe { *lag.get_unchecked(head) };
392 unsafe {
393 *out.get_unchecked_mut(j) = 100.0 * (val - old) / old;
394 *lag.get_unchecked_mut(head) = val;
395 }
396 head += 1;
397 if head == period {
398 head = 0;
399 }
400 j += 1;
401 }
402}
403
404#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
405#[inline]
406pub fn cvi_avx2(high: &[f64], low: &[f64], period: usize, first_valid_idx: usize, out: &mut [f64]) {
407 #[target_feature(enable = "avx2,fma")]
408 unsafe fn inner(
409 high: &[f64],
410 low: &[f64],
411 period: usize,
412 first_valid_idx: usize,
413 out: &mut [f64],
414 ) {
415 const DO_PREFETCH: bool = false;
416 let alpha = 2.0 / (period as f64 + 1.0);
417 let mut val = *high.get_unchecked(first_valid_idx) - *low.get_unchecked(first_valid_idx);
418
419 let mut lag = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
420 lag.set_len(period);
421 *lag.get_unchecked_mut(0) = val;
422
423 let mut head = 1usize;
424 let len = high.len();
425 let needed = 2 * period - 1;
426
427 let mut i = first_valid_idx + 1;
428 let end_warm = first_valid_idx + needed;
429 while i < end_warm {
430 if DO_PREFETCH && i + 64 < len {
431 _mm_prefetch(high.as_ptr().add(i + 64) as *const i8, _MM_HINT_T0);
432 _mm_prefetch(low.as_ptr().add(i + 64) as *const i8, _MM_HINT_T0);
433 }
434 let range = *high.get_unchecked(i) - *low.get_unchecked(i);
435 val += (range - val) * alpha;
436 *lag.get_unchecked_mut(head) = val;
437 head += 1;
438 if head == period {
439 head = 0;
440 }
441 i += 1;
442 }
443
444 let mut j = end_warm;
445 while j < len {
446 if DO_PREFETCH && j + 64 < len {
447 _mm_prefetch(high.as_ptr().add(j + 64) as *const i8, _MM_HINT_T0);
448 _mm_prefetch(low.as_ptr().add(j + 64) as *const i8, _MM_HINT_T0);
449 }
450 let range = *high.get_unchecked(j) - *low.get_unchecked(j);
451 val += (range - val) * alpha;
452 let old = *lag.get_unchecked(head);
453 *out.get_unchecked_mut(j) = 100.0 * (val - old) / old;
454 *lag.get_unchecked_mut(head) = val;
455 head += 1;
456 if head == period {
457 head = 0;
458 }
459 j += 1;
460 }
461 }
462
463 unsafe { inner(high, low, period, first_valid_idx, out) }
464}
465
466#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
467#[inline]
468pub fn cvi_avx512(
469 high: &[f64],
470 low: &[f64],
471 period: usize,
472 first_valid_idx: usize,
473 out: &mut [f64],
474) {
475 if period <= 32 {
476 unsafe { cvi_avx512_short(high, low, period, first_valid_idx, out) }
477 } else {
478 unsafe { cvi_avx512_long(high, low, period, first_valid_idx, out) }
479 }
480}
481
482#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
483#[inline]
484pub unsafe fn cvi_avx512_short(
485 high: &[f64],
486 low: &[f64],
487 period: usize,
488 first_valid_idx: usize,
489 out: &mut [f64],
490) {
491 cvi_avx512_core(high, low, period, first_valid_idx, out)
492}
493
494#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
495#[inline]
496pub unsafe fn cvi_avx512_long(
497 high: &[f64],
498 low: &[f64],
499 period: usize,
500 first_valid_idx: usize,
501 out: &mut [f64],
502) {
503 cvi_avx512_core(high, low, period, first_valid_idx, out)
504}
505
506#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
507#[target_feature(enable = "avx512f,fma")]
508unsafe fn cvi_avx512_core(
509 high: &[f64],
510 low: &[f64],
511 period: usize,
512 first_valid_idx: usize,
513 out: &mut [f64],
514) {
515 const DO_PREFETCH: bool = false;
516 let alpha = 2.0 / (period as f64 + 1.0);
517 let mut val = *high.get_unchecked(first_valid_idx) - *low.get_unchecked(first_valid_idx);
518
519 let mut lag = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
520 lag.set_len(period);
521 *lag.get_unchecked_mut(0) = val;
522
523 let mut head = 1usize;
524 let len = high.len();
525 let needed = 2 * period - 1;
526
527 let mut i = first_valid_idx + 1;
528 let end_warm = first_valid_idx + needed;
529 while i < end_warm {
530 if DO_PREFETCH && i + 96 < len {
531 _mm_prefetch(high.as_ptr().add(i + 96) as *const i8, _MM_HINT_T0);
532 _mm_prefetch(low.as_ptr().add(i + 96) as *const i8, _MM_HINT_T0);
533 }
534 let range = *high.get_unchecked(i) - *low.get_unchecked(i);
535 val += (range - val) * alpha;
536 *lag.get_unchecked_mut(head) = val;
537 head += 1;
538 if head == period {
539 head = 0;
540 }
541 i += 1;
542 }
543
544 let mut j = end_warm;
545 while j < len {
546 if DO_PREFETCH && j + 96 < len {
547 _mm_prefetch(high.as_ptr().add(j + 96) as *const i8, _MM_HINT_T0);
548 _mm_prefetch(low.as_ptr().add(j + 96) as *const i8, _MM_HINT_T0);
549 }
550 let range = *high.get_unchecked(j) - *low.get_unchecked(j);
551 val += (range - val) * alpha;
552 let old = *lag.get_unchecked(head);
553 *out.get_unchecked_mut(j) = 100.0 * (val - old) / old;
554 *lag.get_unchecked_mut(head) = val;
555 head += 1;
556 if head == period {
557 head = 0;
558 }
559 j += 1;
560 }
561}
562
563#[inline(always)]
564pub fn cvi_row_scalar(
565 high: &[f64],
566 low: &[f64],
567 period: usize,
568 first_valid_idx: usize,
569 out: &mut [f64],
570) {
571 cvi_scalar(high, low, period, first_valid_idx, out)
572}
573
574#[inline(always)]
575fn cvi_scalar_from_range(range: &[f64], period: usize, first_valid_idx: usize, out: &mut [f64]) {
576 let alpha = 2.0 / (period as f64 + 1.0);
577
578 let mut val = unsafe { *range.get_unchecked(first_valid_idx) };
579
580 let mut lag = AVec::<f64>::with_capacity(CACHELINE_ALIGN, period);
581 unsafe { lag.set_len(period) };
582 unsafe { *lag.get_unchecked_mut(0) = val };
583
584 let mut head = 1usize;
585 let len = range.len();
586 let needed = 2 * period - 1;
587
588 let mut i = first_valid_idx + 1;
589 let end_warm = first_valid_idx + needed;
590 while i < end_warm {
591 let r = unsafe { *range.get_unchecked(i) };
592 val += (r - val) * alpha;
593 unsafe { *lag.get_unchecked_mut(head) = val };
594 head += 1;
595 if head == period {
596 head = 0;
597 }
598 i += 1;
599 }
600
601 let mut j = end_warm;
602 while j < len {
603 let r = unsafe { *range.get_unchecked(j) };
604 val += (r - val) * alpha;
605 let old = unsafe { *lag.get_unchecked(head) };
606 unsafe {
607 *out.get_unchecked_mut(j) = 100.0 * (val - old) / old;
608 *lag.get_unchecked_mut(head) = val;
609 }
610 head += 1;
611 if head == period {
612 head = 0;
613 }
614 j += 1;
615 }
616}
617
618#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
619#[inline(always)]
620pub fn cvi_row_avx2(
621 high: &[f64],
622 low: &[f64],
623 period: usize,
624 first_valid_idx: usize,
625 out: &mut [f64],
626) {
627 cvi_avx2(high, low, period, first_valid_idx, out)
628}
629
630#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
631#[inline(always)]
632pub fn cvi_row_avx512(
633 high: &[f64],
634 low: &[f64],
635 period: usize,
636 first_valid_idx: usize,
637 out: &mut [f64],
638) {
639 cvi_avx512(high, low, period, first_valid_idx, out)
640}
641
642#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
643#[inline(always)]
644pub fn cvi_row_avx512_short(
645 high: &[f64],
646 low: &[f64],
647 period: usize,
648 first_valid_idx: usize,
649 out: &mut [f64],
650) {
651 unsafe { cvi_avx512_short(high, low, period, first_valid_idx, out) }
652}
653
654#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
655#[inline(always)]
656pub fn cvi_row_avx512_long(
657 high: &[f64],
658 low: &[f64],
659 period: usize,
660 first_valid_idx: usize,
661 out: &mut [f64],
662) {
663 unsafe { cvi_avx512_long(high, low, period, first_valid_idx, out) }
664}
665
666#[derive(Clone, Debug)]
667pub struct CviBatchRange {
668 pub period: (usize, usize, usize),
669}
670
671impl Default for CviBatchRange {
672 fn default() -> Self {
673 Self {
674 period: (10, 259, 1),
675 }
676 }
677}
678
679#[derive(Clone, Debug, Default)]
680pub struct CviBatchBuilder {
681 range: CviBatchRange,
682 kernel: Kernel,
683}
684
685impl CviBatchBuilder {
686 pub fn new() -> Self {
687 Self::default()
688 }
689 pub fn kernel(mut self, k: Kernel) -> Self {
690 self.kernel = k;
691 self
692 }
693 #[inline]
694 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
695 self.range.period = (start, end, step);
696 self
697 }
698 #[inline]
699 pub fn period_static(mut self, p: usize) -> Self {
700 self.range.period = (p, p, 0);
701 self
702 }
703 pub fn apply_slices(self, high: &[f64], low: &[f64]) -> Result<CviBatchOutput, CviError> {
704 cvi_batch_with_kernel(high, low, &self.range, self.kernel)
705 }
706 pub fn apply_candles(self, candles: &Candles) -> Result<CviBatchOutput, CviError> {
707 self.apply_slices(&candles.high, &candles.low)
708 }
709}
710
711pub fn cvi_batch_with_kernel(
712 high: &[f64],
713 low: &[f64],
714 sweep: &CviBatchRange,
715 k: Kernel,
716) -> Result<CviBatchOutput, CviError> {
717 let kernel = match k {
718 Kernel::Auto => {
719 let k = detect_best_batch_kernel();
720 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
721 {
722 if k == Kernel::Avx512Batch
723 && std::arch::is_x86_feature_detected!("avx2")
724 && std::arch::is_x86_feature_detected!("fma")
725 {
726 Kernel::Avx2Batch
727 } else {
728 k
729 }
730 }
731 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
732 {
733 k
734 }
735 }
736 other if other.is_batch() => other,
737 other => return Err(CviError::InvalidKernelForBatch(other)),
738 };
739
740 let simd = match kernel {
741 Kernel::Avx512Batch => Kernel::Avx512,
742 Kernel::Avx2Batch => Kernel::Avx2,
743 Kernel::ScalarBatch => Kernel::Scalar,
744 _ => unreachable!(),
745 };
746 cvi_batch_par_slice(high, low, sweep, simd)
747}
748
749#[derive(Clone, Debug)]
750pub struct CviBatchOutput {
751 pub values: Vec<f64>,
752 pub combos: Vec<CviParams>,
753 pub rows: usize,
754 pub cols: usize,
755}
756impl CviBatchOutput {
757 pub fn row_for_params(&self, p: &CviParams) -> Option<usize> {
758 self.combos
759 .iter()
760 .position(|c| c.period.unwrap_or(10) == p.period.unwrap_or(10))
761 }
762
763 pub fn values_for(&self, p: &CviParams) -> Option<&[f64]> {
764 self.row_for_params(p).map(|row| {
765 let start = row * self.cols;
766 &self.values[start..start + self.cols]
767 })
768 }
769}
770
771#[inline(always)]
772fn expand_grid(r: &CviBatchRange) -> Vec<CviParams> {
773 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
774 if step == 0 || start == end {
775 return vec![start];
776 }
777 if start <= end {
778 return (start..=end).step_by(step.max(1)).collect();
779 }
780
781 let mut v = Vec::new();
782 let s = step.max(1);
783 let mut cur = start;
784 loop {
785 v.push(cur);
786 if cur <= end {
787 break;
788 }
789 let next = cur.saturating_sub(s);
790 if next == cur {
791 break;
792 }
793 cur = next;
794 }
795 v.retain(|&x| x >= end);
796 v
797 }
798 let periods = axis_usize(r.period);
799
800 let mut out = Vec::with_capacity(periods.len());
801 for &p in &periods {
802 out.push(CviParams { period: Some(p) });
803 }
804 out
805}
806
807#[inline(always)]
808pub fn cvi_batch_slice(
809 high: &[f64],
810 low: &[f64],
811 sweep: &CviBatchRange,
812 kern: Kernel,
813) -> Result<CviBatchOutput, CviError> {
814 cvi_batch_inner(high, low, sweep, kern, false)
815}
816
817#[inline(always)]
818pub fn cvi_batch_par_slice(
819 high: &[f64],
820 low: &[f64],
821 sweep: &CviBatchRange,
822 kern: Kernel,
823) -> Result<CviBatchOutput, CviError> {
824 cvi_batch_inner(high, low, sweep, kern, true)
825}
826
827#[inline(always)]
828fn cvi_batch_inner(
829 high: &[f64],
830 low: &[f64],
831 sweep: &CviBatchRange,
832 kern: Kernel,
833 parallel: bool,
834) -> Result<CviBatchOutput, CviError> {
835 if high.is_empty() || low.is_empty() || high.len() != low.len() {
836 return Err(CviError::EmptyData);
837 }
838
839 let combos = expand_grid(sweep);
840 if combos.is_empty() {
841 let (s, e, st) = sweep.period;
842 return Err(CviError::InvalidRange {
843 start: s,
844 end: e,
845 step: st,
846 });
847 }
848
849 let first_valid_idx = (0..high.len())
850 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
851 .ok_or(CviError::AllValuesNaN)?;
852
853 if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
854 return Err(CviError::InvalidPeriod {
855 period: 0,
856 data_len: high.len(),
857 });
858 }
859
860 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
861 let needed = 2 * max_p - 1;
862 if high.len() - first_valid_idx < needed {
863 return Err(CviError::NotEnoughValidData {
864 needed,
865 valid: high.len() - first_valid_idx,
866 });
867 }
868
869 let rows = combos.len();
870 let cols = high.len();
871
872 rows.checked_mul(cols)
873 .ok_or_else(|| CviError::InvalidRange {
874 start: rows,
875 end: cols,
876 step: 0,
877 })?;
878
879 let mut buf_mu = make_uninit_matrix(rows, cols);
880 let warmup_periods: Vec<usize> = combos
881 .iter()
882 .map(|c| {
883 let period = c.period.unwrap();
884 first_valid_idx + (2 * period - 1)
885 })
886 .collect();
887 init_matrix_prefixes(&mut buf_mu, cols, &warmup_periods);
888
889 let mut buf_guard = core::mem::ManuallyDrop::new(buf_mu);
890 let values_slice: &mut [f64] = unsafe {
891 core::slice::from_raw_parts_mut(buf_guard.as_mut_ptr() as *mut f64, buf_guard.len())
892 };
893
894 let shared_range: Option<Vec<f64>> = match kern {
895 Kernel::Scalar | Kernel::Auto => {
896 let mut r = Vec::with_capacity(cols);
897 unsafe {
898 r.set_len(cols);
899 let mut k = 0usize;
900 while k < cols {
901 let v = *high.get_unchecked(k) - *low.get_unchecked(k);
902 *r.get_unchecked_mut(k) = v;
903 k += 1;
904 }
905 }
906 Some(r)
907 }
908 _ => None,
909 };
910
911 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
912 let period = combos[row].period.unwrap();
913 match kern {
914 Kernel::Scalar | Kernel::Auto => {
915 if let Some(range) = &shared_range {
916 cvi_scalar_from_range(range, period, first_valid_idx, out_row)
917 } else {
918 cvi_row_scalar(high, low, period, first_valid_idx, out_row)
919 }
920 }
921 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
922 Kernel::Avx2 => cvi_row_avx2(high, low, period, first_valid_idx, out_row),
923 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
924 Kernel::Avx512 => cvi_row_avx512(high, low, period, first_valid_idx, out_row),
925 _ => cvi_row_scalar(high, low, period, first_valid_idx, out_row),
926 }
927 };
928
929 if parallel {
930 #[cfg(not(target_arch = "wasm32"))]
931 {
932 values_slice
933 .par_chunks_mut(cols)
934 .enumerate()
935 .for_each(|(row, slice)| do_row(row, slice));
936 }
937
938 #[cfg(target_arch = "wasm32")]
939 {
940 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
941 do_row(row, slice);
942 }
943 }
944 } else {
945 for (row, slice) in values_slice.chunks_mut(cols).enumerate() {
946 do_row(row, slice);
947 }
948 }
949
950 let values = unsafe {
951 Vec::from_raw_parts(
952 buf_guard.as_mut_ptr() as *mut f64,
953 buf_guard.len(),
954 buf_guard.capacity(),
955 )
956 };
957 core::mem::forget(buf_guard);
958
959 Ok(CviBatchOutput {
960 values,
961 combos,
962 rows,
963 cols,
964 })
965}
966
967#[inline(always)]
968fn cvi_batch_inner_into(
969 high: &[f64],
970 low: &[f64],
971 sweep: &CviBatchRange,
972 kern: Kernel,
973 parallel: bool,
974 out: &mut [f64],
975) -> Result<Vec<CviParams>, CviError> {
976 if high.is_empty() || low.is_empty() || high.len() != low.len() {
977 return Err(CviError::EmptyData);
978 }
979
980 let combos = expand_grid(sweep);
981 if combos.is_empty() {
982 let (s, e, st) = sweep.period;
983 return Err(CviError::InvalidRange {
984 start: s,
985 end: e,
986 step: st,
987 });
988 }
989
990 let first = (0..high.len())
991 .find(|&i| !high[i].is_nan() && !low[i].is_nan())
992 .ok_or(CviError::AllValuesNaN)?;
993
994 if combos.iter().any(|c| c.period.unwrap_or(0) == 0) {
995 return Err(CviError::InvalidPeriod {
996 period: 0,
997 data_len: high.len(),
998 });
999 }
1000
1001 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1002 let needed = 2 * max_p - 1;
1003 if high.len() - first < needed {
1004 return Err(CviError::NotEnoughValidData {
1005 needed,
1006 valid: high.len() - first,
1007 });
1008 }
1009
1010 let rows = combos.len();
1011 let cols = high.len();
1012 let expected = rows
1013 .checked_mul(cols)
1014 .ok_or_else(|| CviError::InvalidRange {
1015 start: rows,
1016 end: cols,
1017 step: 0,
1018 })?;
1019 if out.len() != expected {
1020 return Err(CviError::OutputLengthMismatch {
1021 expected,
1022 got: out.len(),
1023 });
1024 }
1025
1026 let out_mu = unsafe {
1027 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1028 };
1029 let warms: Vec<usize> = combos
1030 .iter()
1031 .map(|p| first + (2 * p.period.unwrap() - 1))
1032 .collect();
1033 init_matrix_prefixes(out_mu, cols, &warms);
1034
1035 let shared_range: Option<Vec<f64>> = match kern {
1036 Kernel::Scalar | Kernel::ScalarBatch | Kernel::Auto => {
1037 let mut r = Vec::with_capacity(cols);
1038 unsafe {
1039 r.set_len(cols);
1040 let mut i = 0usize;
1041 while i < cols {
1042 let v = *high.get_unchecked(i) - *low.get_unchecked(i);
1043 *r.get_unchecked_mut(i) = v;
1044 i += 1;
1045 }
1046 }
1047 Some(r)
1048 }
1049 _ => None,
1050 };
1051
1052 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| {
1053 let p = combos[row].period.unwrap();
1054 let dst = unsafe {
1055 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1056 };
1057 match kern {
1058 Kernel::Scalar | Kernel::ScalarBatch | Kernel::Auto => {
1059 if let Some(range) = &shared_range {
1060 cvi_scalar_from_range(range, p, first, dst)
1061 } else {
1062 cvi_row_scalar(high, low, p, first, dst)
1063 }
1064 }
1065 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1066 Kernel::Avx2 | Kernel::Avx2Batch => cvi_row_avx2(high, low, p, first, dst),
1067 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1068 Kernel::Avx512 | Kernel::Avx512Batch => cvi_row_avx512(high, low, p, first, dst),
1069 _ => cvi_row_scalar(high, low, p, first, dst),
1070 }
1071 };
1072
1073 #[cfg(not(target_arch = "wasm32"))]
1074 {
1075 if parallel {
1076 out_mu
1077 .par_chunks_mut(cols)
1078 .enumerate()
1079 .for_each(|(r, s)| do_row(r, s));
1080 } else {
1081 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1082 do_row(r, s);
1083 }
1084 }
1085 }
1086 #[cfg(target_arch = "wasm32")]
1087 {
1088 for (r, s) in out_mu.chunks_mut(cols).enumerate() {
1089 do_row(r, s);
1090 }
1091 }
1092
1093 Ok(combos)
1094}
1095
1096#[derive(Debug, Clone)]
1097pub struct CviStream {
1098 period: usize,
1099 alpha: f64,
1100 lag_buffer: Vec<f64>,
1101 head: usize,
1102 warmup_remaining: usize,
1103 state_val: f64,
1104}
1105
1106impl CviStream {
1107 #[inline]
1108 pub fn try_new(
1109 params: CviParams,
1110 initial_high: f64,
1111 initial_low: f64,
1112 ) -> Result<Self, CviError> {
1113 let period = params.period.unwrap_or(10);
1114 if period == 0 {
1115 return Err(CviError::InvalidPeriod {
1116 period,
1117 data_len: 0,
1118 });
1119 }
1120
1121 let alpha = 2.0 / (period as f64 + 1.0);
1122
1123 let ema0 = initial_high - initial_low;
1124
1125 let mut lag_buffer = vec![0.0; period.max(1)];
1126 lag_buffer[0] = ema0;
1127
1128 let head = if period > 1 { 1 } else { 0 };
1129
1130 let warmup_remaining = period.saturating_mul(2).saturating_sub(2);
1131
1132 Ok(Self {
1133 period,
1134 alpha,
1135 lag_buffer,
1136 head,
1137 warmup_remaining,
1138 state_val: ema0,
1139 })
1140 }
1141
1142 #[inline(always)]
1143 pub fn update(&mut self, high: f64, low: f64) -> Option<f64> {
1144 let range = high - low;
1145 self.update_range(range)
1146 }
1147
1148 #[inline(always)]
1149 pub fn update_range(&mut self, range: f64) -> Option<f64> {
1150 self.state_val = (range - self.state_val).mul_add(self.alpha, self.state_val);
1151
1152 if self.warmup_remaining != 0 {
1153 self.lag_buffer[self.head] = self.state_val;
1154 let next = self.head + 1;
1155 self.head = if next == self.period { 0 } else { next };
1156 self.warmup_remaining -= 1;
1157 return None;
1158 }
1159
1160 let old = self.lag_buffer[self.head];
1161 let out = 100.0 * (self.state_val - old) / old;
1162
1163 self.lag_buffer[self.head] = self.state_val;
1164 let next = self.head + 1;
1165 self.head = if next == self.period { 0 } else { next };
1166
1167 Some(out)
1168 }
1169}
1170
1171#[cfg(feature = "python")]
1172#[pyfunction(name = "cvi")]
1173#[pyo3(signature = (high, low, period, kernel=None))]
1174pub fn cvi_py<'py>(
1175 py: Python<'py>,
1176 high: numpy::PyReadonlyArray1<'py, f64>,
1177 low: numpy::PyReadonlyArray1<'py, f64>,
1178 period: usize,
1179 kernel: Option<&str>,
1180) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
1181 use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1182
1183 let high_slice = high.as_slice()?;
1184 let low_slice = low.as_slice()?;
1185 let kern = validate_kernel(kernel, false)?;
1186 let params = CviParams {
1187 period: Some(period),
1188 };
1189 let cvi_in = CviInput::from_slices(high_slice, low_slice, params);
1190
1191 let result_vec: Vec<f64> = py
1192 .allow_threads(|| cvi_with_kernel(&cvi_in, kern).map(|o| o.values))
1193 .map_err(|e| PyValueError::new_err(e.to_string()))?;
1194
1195 Ok(result_vec.into_pyarray(py))
1196}
1197
1198#[cfg(feature = "python")]
1199#[pyclass(name = "CviStream")]
1200pub struct CviStreamPy {
1201 period: usize,
1202 alpha: f64,
1203 lag_buffer: Vec<f64>,
1204 head: usize,
1205 warmup_remaining: usize,
1206 state_val: f64,
1207}
1208
1209#[cfg(feature = "python")]
1210#[pymethods]
1211impl CviStreamPy {
1212 #[new]
1213 fn new(period: usize, initial_high: f64, initial_low: f64) -> PyResult<Self> {
1214 if period == 0 {
1215 return Err(PyValueError::new_err(
1216 "cvi: Invalid period: period = 0, data length = 0",
1217 ));
1218 }
1219 let alpha = 2.0 / (period as f64 + 1.0);
1220 let ema0 = initial_high - initial_low;
1221 let mut lag = vec![0.0; period];
1222 lag[0] = ema0;
1223 let head = if period > 1 { 1 } else { 0 };
1224
1225 let warmup_remaining = period.saturating_sub(1);
1226 Ok(CviStreamPy {
1227 period,
1228 alpha,
1229 lag_buffer: lag,
1230 head,
1231 warmup_remaining,
1232 state_val: ema0,
1233 })
1234 }
1235
1236 fn update(&mut self, high: f64, low: f64) -> Option<f64> {
1237 let range = high - low;
1238 self.state_val = (range - self.state_val).mul_add(self.alpha, self.state_val);
1239
1240 if self.warmup_remaining != 0 {
1241 self.lag_buffer[self.head] = self.state_val;
1242 let next = self.head + 1;
1243 self.head = if next == self.period { 0 } else { next };
1244 self.warmup_remaining -= 1;
1245 return None;
1246 }
1247
1248 let old = self.lag_buffer[self.head];
1249 let out = 100.0 * (self.state_val - old) / old;
1250 self.lag_buffer[self.head] = self.state_val;
1251 let next = self.head + 1;
1252 self.head = if next == self.period { 0 } else { next };
1253 Some(out)
1254 }
1255}
1256
1257#[cfg(feature = "python")]
1258#[pyfunction(name = "cvi_batch")]
1259#[pyo3(signature = (high, low, period_range, kernel=None))]
1260pub fn cvi_batch_py<'py>(
1261 py: Python<'py>,
1262 high: numpy::PyReadonlyArray1<'py, f64>,
1263 low: numpy::PyReadonlyArray1<'py, f64>,
1264 period_range: (usize, usize, usize),
1265 kernel: Option<&str>,
1266) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
1267 use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
1268 use pyo3::types::PyDict;
1269
1270 let high_slice = high.as_slice()?;
1271 let low_slice = low.as_slice()?;
1272
1273 let sweep = CviBatchRange {
1274 period: period_range,
1275 };
1276 let combos = expand_grid(&sweep);
1277 let rows = combos.len();
1278 let cols = high_slice.len();
1279
1280 let total = rows
1281 .checked_mul(cols)
1282 .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
1283 let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
1284 let slice_out = unsafe { out_arr.as_slice_mut()? };
1285
1286 let kern = validate_kernel(kernel, true)?;
1287
1288 let combos = py
1289 .allow_threads(|| {
1290 let kernel = match kern {
1291 Kernel::Auto => detect_best_batch_kernel(),
1292 k => k,
1293 };
1294 let simd = match kernel {
1295 Kernel::Avx512Batch => Kernel::Avx512,
1296 Kernel::Avx2Batch => Kernel::Avx2,
1297 Kernel::ScalarBatch => Kernel::Scalar,
1298 _ => unreachable!(),
1299 };
1300 cvi_batch_inner_into(high_slice, low_slice, &sweep, simd, true, slice_out)
1301 })
1302 .map_err(|e| PyValueError::new_err(e.to_string()))?;
1303
1304 let dict = PyDict::new(py);
1305 dict.set_item("values", out_arr.reshape((rows, cols))?)?;
1306 dict.set_item(
1307 "periods",
1308 combos
1309 .iter()
1310 .map(|p| p.period.unwrap_or(10) as u64)
1311 .collect::<Vec<_>>()
1312 .into_pyarray(py),
1313 )?;
1314
1315 Ok(dict.into())
1316}
1317
1318#[cfg(all(feature = "python", feature = "cuda"))]
1319#[pyfunction(name = "cvi_cuda_batch_dev")]
1320#[pyo3(signature = (high, low, period_range, device_id=0))]
1321pub fn cvi_cuda_batch_dev_py(
1322 py: Python<'_>,
1323 high: numpy::PyReadonlyArray1<'_, f32>,
1324 low: numpy::PyReadonlyArray1<'_, f32>,
1325 period_range: (usize, usize, usize),
1326 device_id: usize,
1327) -> PyResult<CviDeviceArrayF32Py> {
1328 use crate::cuda::cuda_available;
1329 use crate::cuda::CudaCvi;
1330
1331 if !cuda_available() {
1332 return Err(PyValueError::new_err("CUDA not available"));
1333 }
1334
1335 let high_slice = high.as_slice()?;
1336 let low_slice = low.as_slice()?;
1337 if high_slice.len() != low_slice.len() {
1338 return Err(PyValueError::new_err("mismatched input lengths"));
1339 }
1340
1341 let sweep = CviBatchRange {
1342 period: period_range,
1343 };
1344 let (inner, ctx, dev_id) = py.allow_threads(|| {
1345 let cuda = CudaCvi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1346 let ctx = cuda.ctx();
1347 let dev_id = cuda.device_id();
1348 let arr = cuda
1349 .cvi_batch_dev(high_slice, low_slice, &sweep)
1350 .map_err(|e| PyValueError::new_err(e.to_string()))?;
1351 Ok::<_, PyErr>((arr, ctx, dev_id))
1352 })?;
1353
1354 Ok(CviDeviceArrayF32Py {
1355 buf: Some(inner.buf),
1356 rows: inner.rows,
1357 cols: inner.cols,
1358 _ctx: ctx,
1359 device_id: dev_id,
1360 })
1361}
1362
1363#[cfg(all(feature = "python", feature = "cuda"))]
1364#[pyfunction(name = "cvi_cuda_many_series_one_param_dev")]
1365#[pyo3(signature = (high_tm, low_tm, cols, rows, period, device_id=0))]
1366pub fn cvi_cuda_many_series_one_param_dev_py(
1367 py: Python<'_>,
1368 high_tm: numpy::PyReadonlyArray1<'_, f32>,
1369 low_tm: numpy::PyReadonlyArray1<'_, f32>,
1370 cols: usize,
1371 rows: usize,
1372 period: usize,
1373 device_id: usize,
1374) -> PyResult<CviDeviceArrayF32Py> {
1375 use crate::cuda::cuda_available;
1376 use crate::cuda::CudaCvi;
1377
1378 if !cuda_available() {
1379 return Err(PyValueError::new_err("CUDA not available"));
1380 }
1381
1382 let high_slice = high_tm.as_slice()?;
1383 let low_slice = low_tm.as_slice()?;
1384 let expected = cols
1385 .checked_mul(rows)
1386 .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
1387 if high_slice.len() != expected || low_slice.len() != expected {
1388 return Err(PyValueError::new_err("time-major input length mismatch"));
1389 }
1390
1391 let (inner, ctx, dev_id) = py.allow_threads(|| {
1392 let cuda = CudaCvi::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
1393 let ctx = cuda.ctx();
1394 let dev_id = cuda.device_id();
1395 let arr = cuda
1396 .cvi_many_series_one_param_time_major_dev(high_slice, low_slice, cols, rows, period)
1397 .map_err(|e| PyValueError::new_err(e.to_string()))?;
1398 Ok::<_, PyErr>((arr, ctx, dev_id))
1399 })?;
1400
1401 Ok(CviDeviceArrayF32Py {
1402 buf: Some(inner.buf),
1403 rows: inner.rows,
1404 cols: inner.cols,
1405 _ctx: ctx,
1406 device_id: dev_id,
1407 })
1408}
1409
1410#[cfg(test)]
1411mod tests {
1412 use super::*;
1413 use crate::skip_if_unsupported;
1414 use crate::utilities::data_loader::read_candles_from_csv;
1415
1416 fn check_cvi_partial_params(
1417 test_name: &str,
1418 kernel: Kernel,
1419 ) -> Result<(), Box<dyn std::error::Error>> {
1420 skip_if_unsupported!(kernel, test_name);
1421 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1422 let candles = read_candles_from_csv(file_path)?;
1423 let default_params = CviParams { period: None };
1424 let input_default = CviInput::from_candles(&candles, default_params);
1425 let output_default = cvi_with_kernel(&input_default, kernel)?;
1426 assert_eq!(output_default.values.len(), candles.close.len());
1427 Ok(())
1428 }
1429
1430 fn check_cvi_accuracy(
1431 test_name: &str,
1432 kernel: Kernel,
1433 ) -> Result<(), Box<dyn std::error::Error>> {
1434 skip_if_unsupported!(kernel, test_name);
1435 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1436 let candles = read_candles_from_csv(file_path)?;
1437 let params = CviParams { period: Some(5) };
1438 let input = CviInput::from_candles(&candles, params);
1439 let cvi_result = cvi_with_kernel(&input, kernel)?;
1440
1441 let expected_last_five_cvi = [
1442 -52.96320026271643,
1443 -64.39616778235792,
1444 -59.4830094380472,
1445 -52.4690724045071,
1446 -11.858704179539174,
1447 ];
1448 assert!(cvi_result.values.len() >= 5);
1449 let start_index = cvi_result.values.len() - 5;
1450 let result_last_five = &cvi_result.values[start_index..];
1451 for (i, &val) in result_last_five.iter().enumerate() {
1452 let expected = expected_last_five_cvi[i];
1453 assert!(
1454 (val - expected).abs() < 1e-6,
1455 "[{}] CVI mismatch at index {}: expected {}, got {}",
1456 test_name,
1457 i,
1458 expected,
1459 val
1460 );
1461 }
1462 Ok(())
1463 }
1464
1465 fn check_cvi_input_with_default_candles(
1466 test_name: &str,
1467 kernel: Kernel,
1468 ) -> Result<(), Box<dyn std::error::Error>> {
1469 skip_if_unsupported!(kernel, test_name);
1470 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1471 let candles = read_candles_from_csv(file_path)?;
1472 let input = CviInput::with_default_candles(&candles);
1473 match input.data {
1474 CviData::Candles(_) => {}
1475 _ => panic!("Expected CviData::Candles variant"),
1476 }
1477 Ok(())
1478 }
1479
1480 fn check_cvi_with_zero_period(
1481 test_name: &str,
1482 kernel: Kernel,
1483 ) -> Result<(), Box<dyn std::error::Error>> {
1484 skip_if_unsupported!(kernel, test_name);
1485 let high = [10.0, 20.0, 30.0];
1486 let low = [5.0, 15.0, 25.0];
1487 let params = CviParams { period: Some(0) };
1488 let input = CviInput::from_slices(&high, &low, params);
1489
1490 let result = cvi_with_kernel(&input, kernel);
1491 assert!(
1492 result.is_err(),
1493 "[{}] Expected an error for zero period",
1494 test_name
1495 );
1496 Ok(())
1497 }
1498
1499 fn check_cvi_with_period_exceeding_data_length(
1500 test_name: &str,
1501 kernel: Kernel,
1502 ) -> Result<(), Box<dyn std::error::Error>> {
1503 skip_if_unsupported!(kernel, test_name);
1504 let high = [10.0, 20.0, 30.0];
1505 let low = [5.0, 15.0, 25.0];
1506 let params = CviParams { period: Some(10) };
1507 let input = CviInput::from_slices(&high, &low, params);
1508
1509 let result = cvi_with_kernel(&input, kernel);
1510 assert!(
1511 result.is_err(),
1512 "[{}] Expected an error for period > data.len()",
1513 test_name
1514 );
1515 Ok(())
1516 }
1517
1518 fn check_cvi_very_small_data_set(
1519 test_name: &str,
1520 kernel: Kernel,
1521 ) -> Result<(), Box<dyn std::error::Error>> {
1522 skip_if_unsupported!(kernel, test_name);
1523 let high = [5.0];
1524 let low = [2.0];
1525 let params = CviParams { period: Some(10) };
1526 let input = CviInput::from_slices(&high, &low, params);
1527
1528 let result = cvi_with_kernel(&input, kernel);
1529 assert!(
1530 result.is_err(),
1531 "[{}] Expected error for data smaller than period",
1532 test_name
1533 );
1534 Ok(())
1535 }
1536
1537 fn check_cvi_with_nan_data(
1538 test_name: &str,
1539 kernel: Kernel,
1540 ) -> Result<(), Box<dyn std::error::Error>> {
1541 skip_if_unsupported!(kernel, test_name);
1542 let high = [f64::NAN, 20.0, 30.0];
1543 let low = [5.0, 15.0, f64::NAN];
1544 let input = CviInput::from_slices(&high, &low, CviParams { period: Some(2) });
1545
1546 let result = cvi_with_kernel(&input, kernel);
1547 assert!(
1548 result.is_err(),
1549 "[{}] Expected an error due to trailing NaN",
1550 test_name
1551 );
1552 Ok(())
1553 }
1554
1555 fn check_cvi_slice_reinput(
1556 test_name: &str,
1557 kernel: Kernel,
1558 ) -> Result<(), Box<dyn std::error::Error>> {
1559 skip_if_unsupported!(kernel, test_name);
1560 let high = [
1561 10.0, 12.0, 12.5, 12.2, 13.0, 14.0, 15.0, 16.0, 16.5, 17.0, 17.5, 18.0,
1562 ];
1563 let low = [
1564 9.0, 10.0, 11.5, 11.0, 12.0, 13.5, 14.0, 14.5, 15.5, 16.0, 16.5, 17.0,
1565 ];
1566 let first_input = CviInput::from_slices(&high, &low, CviParams { period: Some(3) });
1567 let first_result = cvi_with_kernel(&first_input, kernel)?;
1568 let second_input =
1569 CviInput::from_slices(&first_result.values, &low, CviParams { period: Some(3) });
1570 let second_result = cvi_with_kernel(&second_input, kernel)?;
1571 assert_eq!(second_result.values.len(), low.len());
1572 Ok(())
1573 }
1574
1575 #[cfg(debug_assertions)]
1576 fn check_cvi_no_poison(
1577 test_name: &str,
1578 kernel: Kernel,
1579 ) -> Result<(), Box<dyn std::error::Error>> {
1580 skip_if_unsupported!(kernel, test_name);
1581
1582 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1583 let candles = read_candles_from_csv(file_path)?;
1584
1585 let test_params = vec![
1586 CviParams::default(),
1587 CviParams { period: Some(2) },
1588 CviParams { period: Some(5) },
1589 CviParams { period: Some(10) },
1590 CviParams { period: Some(14) },
1591 CviParams { period: Some(20) },
1592 CviParams { period: Some(50) },
1593 CviParams { period: Some(100) },
1594 CviParams { period: Some(200) },
1595 ];
1596
1597 for (param_idx, params) in test_params.iter().enumerate() {
1598 let input = CviInput::from_candles(&candles, params.clone());
1599 let output = cvi_with_kernel(&input, kernel)?;
1600
1601 for (i, &val) in output.values.iter().enumerate() {
1602 if val.is_nan() {
1603 continue;
1604 }
1605
1606 let bits = val.to_bits();
1607
1608 if bits == 0x11111111_11111111 {
1609 panic!(
1610 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} \
1611 with params: period={} (param set {})",
1612 test_name,
1613 val,
1614 bits,
1615 i,
1616 params.period.unwrap_or(10),
1617 param_idx
1618 );
1619 }
1620
1621 if bits == 0x22222222_22222222 {
1622 panic!(
1623 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} \
1624 with params: period={} (param set {})",
1625 test_name,
1626 val,
1627 bits,
1628 i,
1629 params.period.unwrap_or(10),
1630 param_idx
1631 );
1632 }
1633
1634 if bits == 0x33333333_33333333 {
1635 panic!(
1636 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} \
1637 with params: period={} (param set {})",
1638 test_name,
1639 val,
1640 bits,
1641 i,
1642 params.period.unwrap_or(10),
1643 param_idx
1644 );
1645 }
1646 }
1647 }
1648
1649 Ok(())
1650 }
1651
1652 #[cfg(not(debug_assertions))]
1653 fn check_cvi_no_poison(
1654 _test_name: &str,
1655 _kernel: Kernel,
1656 ) -> Result<(), Box<dyn std::error::Error>> {
1657 Ok(())
1658 }
1659
1660 macro_rules! generate_all_cvi_tests {
1661 ($($test_fn:ident),*) => {
1662 paste::paste! {
1663 $(
1664 #[test]
1665 fn [<$test_fn _scalar_f64>]() {
1666 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
1667 }
1668 )*
1669 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1670 $(
1671 #[test]
1672 fn [<$test_fn _avx2_f64>]() {
1673 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
1674 }
1675 #[test]
1676 fn [<$test_fn _avx512_f64>]() {
1677 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
1678 }
1679 )*
1680 }
1681 }
1682 }
1683
1684 generate_all_cvi_tests!(
1685 check_cvi_partial_params,
1686 check_cvi_accuracy,
1687 check_cvi_input_with_default_candles,
1688 check_cvi_with_zero_period,
1689 check_cvi_with_period_exceeding_data_length,
1690 check_cvi_very_small_data_set,
1691 check_cvi_with_nan_data,
1692 check_cvi_slice_reinput,
1693 check_cvi_no_poison
1694 );
1695
1696 #[cfg(test)]
1697 generate_all_cvi_tests!(check_cvi_property);
1698
1699 fn check_batch_default_row(
1700 test: &str,
1701 kernel: Kernel,
1702 ) -> Result<(), Box<dyn std::error::Error>> {
1703 skip_if_unsupported!(kernel, test);
1704 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1705 let c = read_candles_from_csv(file)?;
1706 let output = CviBatchBuilder::new().kernel(kernel).apply_candles(&c)?;
1707 let def = CviParams::default();
1708 let row = output.values_for(&def).expect("default row missing");
1709 assert_eq!(row.len(), c.close.len());
1710 Ok(())
1711 }
1712
1713 macro_rules! gen_batch_tests {
1714 ($fn_name:ident) => {
1715 paste::paste! {
1716 #[test] fn [<$fn_name _scalar>]() {
1717 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
1718 }
1719 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1720 #[test] fn [<$fn_name _avx2>]() {
1721 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
1722 }
1723 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1724 #[test] fn [<$fn_name _avx512>]() {
1725 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
1726 }
1727 #[test] fn [<$fn_name _auto_detect>]() {
1728 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
1729 }
1730 }
1731 };
1732 }
1733 gen_batch_tests!(check_batch_default_row);
1734 gen_batch_tests!(check_batch_no_poison);
1735
1736 #[test]
1737 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1738 fn test_cvi_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1739 let len = 256usize;
1740 let mut high = vec![0.0f64; len];
1741 let mut low = vec![0.0f64; len];
1742
1743 high[0] = f64::NAN;
1744 low[0] = f64::NAN;
1745 high[1] = f64::NAN;
1746 low[1] = f64::NAN;
1747 high[2] = f64::NAN;
1748 low[2] = f64::NAN;
1749
1750 for i in 3..len {
1751 let i_f = i as f64;
1752 let base = 100.0 + 0.1 * i_f;
1753
1754 high[i] = base + (i % 7) as f64 * 0.03;
1755 let spread = 1.0 + (i % 5) as f64 * 0.2;
1756 low[i] = high[i] - spread.max(0.001);
1757 }
1758
1759 let input = CviInput::from_slices(&high, &low, CviParams::default());
1760
1761 let baseline = cvi(&input)?.values;
1762
1763 let mut into_out = vec![0.0f64; len];
1764 cvi_into(&input, &mut into_out)?;
1765
1766 assert_eq!(baseline.len(), into_out.len());
1767
1768 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1769 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1770 }
1771
1772 for (i, (&a, &b)) in baseline.iter().zip(into_out.iter()).enumerate() {
1773 assert!(
1774 eq_or_both_nan(a, b),
1775 "mismatch at {}: baseline={} into={}",
1776 i,
1777 a,
1778 b
1779 );
1780 }
1781
1782 Ok(())
1783 }
1784
1785 #[cfg(debug_assertions)]
1786 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
1787 skip_if_unsupported!(kernel, test);
1788
1789 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1790 let c = read_candles_from_csv(file)?;
1791
1792 let test_configs = vec![
1793 (2, 10, 2),
1794 (5, 25, 5),
1795 (10, 50, 10),
1796 (2, 5, 1),
1797 (30, 60, 15),
1798 (14, 21, 7),
1799 (50, 100, 25),
1800 ];
1801
1802 for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
1803 let output = CviBatchBuilder::new()
1804 .kernel(kernel)
1805 .period_range(p_start, p_end, p_step)
1806 .apply_candles(&c)?;
1807
1808 for (idx, &val) in output.values.iter().enumerate() {
1809 if val.is_nan() {
1810 continue;
1811 }
1812
1813 let bits = val.to_bits();
1814 let row = idx / output.cols;
1815 let col = idx % output.cols;
1816 let combo = &output.combos[row];
1817
1818 if bits == 0x11111111_11111111 {
1819 panic!(
1820 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
1821 at row {} col {} (flat index {}) with params: period={}",
1822 test,
1823 cfg_idx,
1824 val,
1825 bits,
1826 row,
1827 col,
1828 idx,
1829 combo.period.unwrap_or(10)
1830 );
1831 }
1832
1833 if bits == 0x22222222_22222222 {
1834 panic!(
1835 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
1836 at row {} col {} (flat index {}) with params: period={}",
1837 test,
1838 cfg_idx,
1839 val,
1840 bits,
1841 row,
1842 col,
1843 idx,
1844 combo.period.unwrap_or(10)
1845 );
1846 }
1847
1848 if bits == 0x33333333_33333333 {
1849 panic!(
1850 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
1851 at row {} col {} (flat index {}) with params: period={}",
1852 test,
1853 cfg_idx,
1854 val,
1855 bits,
1856 row,
1857 col,
1858 idx,
1859 combo.period.unwrap_or(10)
1860 );
1861 }
1862 }
1863 }
1864
1865 Ok(())
1866 }
1867
1868 #[cfg(not(debug_assertions))]
1869 fn check_batch_no_poison(
1870 _test: &str,
1871 _kernel: Kernel,
1872 ) -> Result<(), Box<dyn std::error::Error>> {
1873 Ok(())
1874 }
1875
1876 #[cfg(test)]
1877 #[allow(clippy::float_cmp)]
1878 fn check_cvi_property(
1879 test_name: &str,
1880 kernel: Kernel,
1881 ) -> Result<(), Box<dyn std::error::Error>> {
1882 use proptest::prelude::*;
1883 skip_if_unsupported!(kernel, test_name);
1884
1885 let strat = (2usize..=50)
1886 .prop_flat_map(|period| {
1887 (
1888 100.0f64..5000.0f64,
1889 (2 * period + 20)..400,
1890 0.001f64..0.1f64,
1891 -0.01f64..0.01f64,
1892 Just(period),
1893 )
1894 })
1895 .prop_map(|(base_price, data_len, volatility, trend, period)| {
1896 let mut high_data = Vec::with_capacity(data_len);
1897 let mut low_data = Vec::with_capacity(data_len);
1898 let mut current_price = base_price;
1899
1900 for i in 0..data_len {
1901 current_price *= 1.0 + trend;
1902 current_price = current_price.max(10.0);
1903
1904 let volatility_factor = volatility * (1.0 + (i as f64 * 0.1).sin() * 0.5);
1905 let range = current_price * volatility_factor;
1906
1907 let high = current_price + range * 0.5;
1908 let low = (current_price - range * 0.5).max(1.0);
1909
1910 high_data.push(high);
1911 low_data.push(low);
1912 }
1913
1914 (high_data, low_data, period)
1915 });
1916
1917 proptest::test_runner::TestRunner::default()
1918 .run(&strat, |(high_data, low_data, period)| {
1919 let params = CviParams {
1920 period: Some(period),
1921 };
1922 let input = CviInput::from_slices(&high_data, &low_data, params);
1923
1924 let out = cvi_with_kernel(&input, kernel)?;
1925 let ref_out = cvi_with_kernel(&input, Kernel::Scalar)?;
1926
1927 let expected_first_valid = 2 * period - 1;
1928 let first_valid_idx = out.values.iter().position(|&v| !v.is_nan());
1929
1930 if let Some(idx) = first_valid_idx {
1931 prop_assert_eq!(
1932 idx,
1933 expected_first_valid,
1934 "[{}] First valid index mismatch: expected {}, got {}",
1935 test_name,
1936 expected_first_valid,
1937 idx
1938 );
1939 }
1940
1941 for (i, &val) in out.values.iter().enumerate() {
1942 if !val.is_nan() {
1943 prop_assert!(
1944 val.is_finite(),
1945 "[{}] CVI value {} at index {} is not finite",
1946 test_name,
1947 val,
1948 i
1949 );
1950
1951 prop_assert!(
1952 val > -100.0 && val < 1000.0,
1953 "[{}] CVI value {} at index {} outside reasonable bounds [-100%, 1000%]",
1954 test_name, val, i
1955 );
1956 }
1957 }
1958
1959 prop_assert_eq!(
1960 out.values.len(),
1961 high_data.len(),
1962 "[{}] Output length mismatch",
1963 test_name
1964 );
1965
1966 prop_assert_eq!(
1967 out.values.len(),
1968 ref_out.values.len(),
1969 "[{}] Output length mismatch between kernels",
1970 test_name
1971 );
1972
1973 for (i, (&y, &r)) in out.values.iter().zip(ref_out.values.iter()).enumerate() {
1974 if !y.is_finite() || !r.is_finite() {
1975 prop_assert_eq!(
1976 y.to_bits(),
1977 r.to_bits(),
1978 "[{}] NaN/Inf mismatch at index {}: {} vs {}",
1979 test_name,
1980 i,
1981 y,
1982 r
1983 );
1984 continue;
1985 }
1986
1987 let ulp_diff = y.to_bits().abs_diff(r.to_bits());
1988 prop_assert!(
1989 (y - r).abs() <= 1e-9 || ulp_diff <= 4,
1990 "[{}] Value mismatch at index {}: {} vs {} (ULP={})",
1991 test_name,
1992 i,
1993 y,
1994 r,
1995 ulp_diff
1996 );
1997 }
1998
1999 if period == 2 {
2000 let warmup_count = out.values.iter().take_while(|&&v| v.is_nan()).count();
2001 prop_assert_eq!(
2002 warmup_count,
2003 3,
2004 "[{}] Period=2 should have warmup count of 3, got {}",
2005 test_name,
2006 warmup_count
2007 );
2008 }
2009
2010 let ranges: Vec<f64> = high_data
2011 .iter()
2012 .zip(low_data.iter())
2013 .map(|(&h, &l)| h - l)
2014 .collect();
2015
2016 let valid_cvi: Vec<f64> = out
2017 .values
2018 .iter()
2019 .filter(|&&v| !v.is_nan())
2020 .cloned()
2021 .collect();
2022
2023 if valid_cvi.len() > 10 {
2024 let min_cvi = valid_cvi.iter().cloned().fold(f64::INFINITY, f64::min);
2025 let max_cvi = valid_cvi.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
2026
2027 let range_variance = ranges
2028 .windows(period)
2029 .map(|w| {
2030 let mean = w.iter().sum::<f64>() / w.len() as f64;
2031 w.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / w.len() as f64
2032 })
2033 .fold(0.0f64, f64::max);
2034
2035 if range_variance > 1e-10 {
2036 prop_assert!(
2037 (max_cvi - min_cvi).abs() > 1.0,
2038 "[{}] CVI should show meaningful variation (>1%) when volatility changes, got range: {}",
2039 test_name, (max_cvi - min_cvi).abs()
2040 );
2041 }
2042 }
2043
2044 let min_range = ranges.iter().cloned().fold(f64::INFINITY, f64::min);
2045 if min_range < 1e-10 {
2046 for &val in &out.values {
2047 if !val.is_nan() {
2048 prop_assert!(
2049 val.is_finite(),
2050 "[{}] CVI should remain finite even with tiny ranges",
2051 test_name
2052 );
2053 }
2054 }
2055 }
2056
2057 if ranges.len() > period * 3 {
2058 let mean_range = ranges.iter().sum::<f64>() / ranges.len() as f64;
2059 let all_similar = ranges
2060 .iter()
2061 .all(|&r| (r - mean_range).abs() < mean_range * 0.01);
2062
2063 if all_similar && mean_range > 1e-10 {
2064 let last_quarter_start = valid_cvi.len() * 3 / 4;
2065 if last_quarter_start < valid_cvi.len() {
2066 let last_quarter: Vec<f64> = valid_cvi[last_quarter_start..].to_vec();
2067 if !last_quarter.is_empty() {
2068 let avg_abs_cvi = last_quarter.iter().map(|v| v.abs()).sum::<f64>()
2069 / last_quarter.len() as f64;
2070
2071 prop_assert!(
2072 avg_abs_cvi < 5.0,
2073 "[{}] CVI should converge near 0 for constant volatility, but average |CVI| = {}",
2074 test_name, avg_abs_cvi
2075 );
2076 }
2077 }
2078 }
2079 }
2080
2081 if ranges.len() > period * 2 {
2082 let mut max_volatility_change = 1.0;
2083 let mut max_change_idx = 0;
2084
2085 for i in period..ranges.len() - period {
2086 let prev_avg = ranges[i - period..i].iter().sum::<f64>() / period as f64;
2087 let next_avg = ranges[i..i + period].iter().sum::<f64>() / period as f64;
2088
2089 if prev_avg > 1e-10 && next_avg > 1e-10 {
2090 let ratio = (next_avg / prev_avg).max(prev_avg / next_avg);
2091 if ratio > max_volatility_change {
2092 max_volatility_change = ratio;
2093 max_change_idx = i;
2094 }
2095 }
2096 }
2097
2098 if max_volatility_change > 5.0 {
2099 let spike_start =
2100 (max_change_idx + expected_first_valid).saturating_sub(period);
2101 let spike_end = (max_change_idx + expected_first_valid + period * 2)
2102 .min(out.values.len());
2103
2104 if spike_end > spike_start {
2105 let spike_region: Vec<f64> = out.values[spike_start..spike_end]
2106 .iter()
2107 .filter(|&&v| !v.is_nan())
2108 .cloned()
2109 .collect();
2110
2111 if !spike_region.is_empty() {
2112 let max_abs_cvi =
2113 spike_region.iter().map(|v| v.abs()).fold(0.0f64, f64::max);
2114
2115 prop_assert!(
2116 max_abs_cvi > 10.0,
2117 "[{}] CVI should spike (>10%) for {}x volatility change, but max |CVI| = {}",
2118 test_name, max_volatility_change, max_abs_cvi
2119 );
2120 }
2121 }
2122 }
2123 }
2124
2125 Ok(())
2126 })
2127 .unwrap();
2128
2129 Ok(())
2130 }
2131}
2132
2133#[inline]
2134fn find_first_valid_idx(high: &[f64], low: &[f64]) -> Option<usize> {
2135 (0..high.len()).find(|&i| !high[i].is_nan() && !low[i].is_nan())
2136}
2137
2138#[inline(always)]
2139pub fn cvi_into_slice(
2140 output: &mut [f64],
2141 input: &CviInput,
2142 kernel: Kernel,
2143) -> Result<(), CviError> {
2144 let (high, low) = match &input.data {
2145 CviData::Candles(c) => (&c.high[..], &c.low[..]),
2146 CviData::Slices { high, low } => (*high, *low),
2147 };
2148
2149 let period = input.params.period.unwrap_or(10);
2150
2151 if high.is_empty() || low.is_empty() {
2152 return Err(CviError::EmptyData);
2153 }
2154 if period == 0 || period > high.len() {
2155 return Err(CviError::InvalidPeriod {
2156 period,
2157 data_len: high.len(),
2158 });
2159 }
2160 if high.len() != low.len() {
2161 return Err(CviError::EmptyData);
2162 }
2163 if output.len() != high.len() {
2164 return Err(CviError::OutputLengthMismatch {
2165 expected: high.len(),
2166 got: output.len(),
2167 });
2168 }
2169
2170 let first_valid_idx = match find_first_valid_idx(high, low) {
2171 Some(idx) => idx,
2172 None => return Err(CviError::AllValuesNaN),
2173 };
2174
2175 let warmup = period - 1;
2176 let min_data_points = warmup + period;
2177
2178 if high.len() - first_valid_idx < min_data_points {
2179 return Err(CviError::NotEnoughValidData {
2180 needed: min_data_points,
2181 valid: high.len() - first_valid_idx,
2182 });
2183 }
2184
2185 let out_start = first_valid_idx + 2 * period - 1;
2186 for i in 0..out_start {
2187 output[i] = f64::NAN;
2188 }
2189
2190 match kernel {
2191 Kernel::Scalar => cvi_scalar(high, low, period, first_valid_idx, output),
2192 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2193 Kernel::Avx2 => unsafe { cvi_avx2(high, low, period, first_valid_idx, output) },
2194 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2195 Kernel::Avx512 => unsafe { cvi_avx512(high, low, period, first_valid_idx, output) },
2196 Kernel::Auto => cvi_scalar(high, low, period, first_valid_idx, output),
2197 _ => return Err(CviError::InvalidKernelForBatch(kernel)),
2198 }
2199
2200 Ok(())
2201}
2202
2203#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2204#[wasm_bindgen]
2205pub fn cvi_js(high: &[f64], low: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
2206 let params = CviParams {
2207 period: Some(period),
2208 };
2209 let input = CviInput::from_slices(high, low, params);
2210
2211 let output = cvi(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;
2212
2213 Ok(output.values)
2214}
2215
2216#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2217#[wasm_bindgen]
2218pub fn cvi_alloc(len: usize) -> *mut f64 {
2219 let mut vec = Vec::<f64>::with_capacity(len);
2220 let ptr = vec.as_mut_ptr();
2221 std::mem::forget(vec);
2222 ptr
2223}
2224
2225#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2226#[wasm_bindgen]
2227pub fn cvi_free(ptr: *mut f64, len: usize) {
2228 unsafe {
2229 let _ = Vec::from_raw_parts(ptr, len, len);
2230 }
2231}
2232
2233#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2234#[wasm_bindgen]
2235pub fn cvi_into(
2236 high_ptr: *const f64,
2237 low_ptr: *const f64,
2238 out_ptr: *mut f64,
2239 len: usize,
2240 period: usize,
2241) -> Result<(), JsValue> {
2242 if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
2243 return Err(JsValue::from_str("null pointer passed to cvi_into"));
2244 }
2245
2246 unsafe {
2247 let high = std::slice::from_raw_parts(high_ptr, len);
2248 let low = std::slice::from_raw_parts(low_ptr, len);
2249
2250 if period == 0 || period > len {
2251 return Err(JsValue::from_str("Invalid period"));
2252 }
2253
2254 let params = CviParams {
2255 period: Some(period),
2256 };
2257 let input = CviInput::from_slices(high, low, params);
2258
2259 let aliased = high_ptr == out_ptr || low_ptr == out_ptr;
2260
2261 if aliased {
2262 let result = cvi(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;
2263 let out = std::slice::from_raw_parts_mut(out_ptr, len);
2264 out.copy_from_slice(&result.values);
2265 } else {
2266 let out = std::slice::from_raw_parts_mut(out_ptr, len);
2267 cvi_into_slice(out, &input, Kernel::Auto)
2268 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2269 }
2270
2271 Ok(())
2272 }
2273}
2274
2275#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2276#[derive(Serialize, Deserialize)]
2277pub struct CviBatchConfig {
2278 pub period_range: (usize, usize, usize),
2279}
2280
2281#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2282#[derive(Serialize, Deserialize)]
2283pub struct CviBatchJsOutput {
2284 pub values: Vec<f64>,
2285 pub combos: Vec<CviParams>,
2286 pub rows: usize,
2287 pub cols: usize,
2288}
2289
2290#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2291#[wasm_bindgen(js_name = cvi_batch)]
2292pub fn cvi_batch_unified_js(
2293 high: &[f64],
2294 low: &[f64],
2295 config: JsValue,
2296) -> Result<JsValue, JsValue> {
2297 let config: CviBatchConfig = serde_wasm_bindgen::from_value(config)
2298 .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
2299
2300 let sweep = CviBatchRange {
2301 period: config.period_range,
2302 };
2303
2304 let kernel = detect_best_kernel();
2305 let output = cvi_batch_inner(high, low, &sweep, kernel, false)
2306 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2307
2308 let js_output = CviBatchJsOutput {
2309 values: output.values,
2310 combos: output.combos,
2311 rows: output.rows,
2312 cols: output.cols,
2313 };
2314
2315 serde_wasm_bindgen::to_value(&js_output)
2316 .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
2317}
2318
2319#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
2320#[wasm_bindgen]
2321pub fn cvi_batch_into(
2322 high_ptr: *const f64,
2323 low_ptr: *const f64,
2324 out_ptr: *mut f64,
2325 len: usize,
2326 period_start: usize,
2327 period_end: usize,
2328 period_step: usize,
2329) -> Result<usize, JsValue> {
2330 if high_ptr.is_null() || low_ptr.is_null() || out_ptr.is_null() {
2331 return Err(JsValue::from_str("null pointer passed to cvi_batch_into"));
2332 }
2333
2334 unsafe {
2335 let high = std::slice::from_raw_parts(high_ptr, len);
2336 let low = std::slice::from_raw_parts(low_ptr, len);
2337
2338 let sweep = CviBatchRange {
2339 period: (period_start, period_end, period_step),
2340 };
2341
2342 let combos = expand_grid(&sweep);
2343 let rows = combos.len();
2344 let cols = len;
2345
2346 let total = rows
2347 .checked_mul(cols)
2348 .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;
2349 let out = std::slice::from_raw_parts_mut(out_ptr, total);
2350
2351 cvi_batch_inner_into(high, low, &sweep, Kernel::Auto, false, out)
2352 .map_err(|e| JsValue::from_str(&e.to_string()))?;
2353
2354 Ok(rows)
2355 }
2356}