1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::cuda_available;
3#[cfg(all(feature = "python", feature = "cuda"))]
4use crate::cuda::moving_averages::CudaOtt;
5#[cfg(all(feature = "python", feature = "cuda"))]
6use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
7#[cfg(feature = "python")]
8use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
9#[cfg(feature = "python")]
10use pyo3::exceptions::PyValueError;
11#[cfg(feature = "python")]
12use pyo3::prelude::*;
13#[cfg(feature = "python")]
14use pyo3::types::{PyDict, PyList};
15
16#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
17use serde::{Deserialize, Serialize};
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use wasm_bindgen::prelude::*;
20
21use crate::utilities::data_loader::{source_type, Candles};
22use crate::utilities::enums::Kernel;
23use crate::utilities::helpers::{
24 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
25 make_uninit_matrix,
26};
27#[cfg(feature = "python")]
28use crate::utilities::kernel_validation::validate_kernel;
29use aligned_vec::{AVec, CACHELINE_ALIGN};
30
31use crate::indicators::moving_averages::{
32 ema::{ema_with_kernel, EmaInput, EmaParams},
33 linreg::{linreg_with_kernel, LinRegInput, LinRegParams},
34 sma::{sma_with_kernel, SmaInput, SmaParams},
35 wma::{wma_with_kernel, WmaInput, WmaParams},
36 zlema::{zlema_with_kernel, ZlemaInput, ZlemaParams},
37};
38use crate::indicators::tsf::{tsf_with_kernel, TsfInput, TsfParams};
39
40#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
41use core::arch::x86_64::*;
42
43#[cfg(not(target_arch = "wasm32"))]
44use rayon::prelude::*;
45
46use std::collections::HashMap;
47use std::convert::AsRef;
48use std::error::Error;
49use std::mem::MaybeUninit;
50use thiserror::Error;
51
52impl<'a> AsRef<[f64]> for OttInput<'a> {
53 #[inline(always)]
54 fn as_ref(&self) -> &[f64] {
55 match &self.data {
56 OttData::Slice(slice) => slice,
57 OttData::Candles { candles, source } => source_type(candles, source),
58 }
59 }
60}
61
/// Where the OTT input series comes from.
#[derive(Debug, Clone)]
pub enum OttData<'a> {
    /// A candle set plus the name of the field to read (resolved through
    /// `source_type` when the input is viewed as a slice).
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used directly.
    Slice(&'a [f64]),
}
70
/// Result of a single-series OTT computation.
#[derive(Debug, Clone)]
pub struct OttOutput {
    /// One value per input element (the output vector is allocated with the
    /// input length by `ott_with_kernel`).
    pub values: Vec<f64>,
}
75
/// Tunable OTT parameters. A `None` field falls back to the same value that
/// `Default` provides (period 2, percent 1.4, moving average "VAR").
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct OttParams {
    /// Moving-average period.
    pub period: Option<usize>,
    /// Stop offset expressed in percent of the moving average.
    pub percent: Option<f64>,
    /// Name of the moving average to smooth with (e.g. "VAR", "SMA", "EMA").
    pub ma_type: Option<String>,
}
86
87impl Default for OttParams {
88 fn default() -> Self {
89 Self {
90 period: Some(2),
91 percent: Some(1.4),
92 ma_type: Some("VAR".to_string()),
93 }
94 }
95}
96
/// A data source paired with the parameters to run OTT with.
#[derive(Debug, Clone)]
pub struct OttInput<'a> {
    /// The series to process (candles + field name, or a raw slice).
    pub data: OttData<'a>,
    /// Parameters; unset fields are resolved by the `get_*` accessors.
    pub params: OttParams,
}
102
103impl<'a> OttInput<'a> {
104 #[inline]
105 pub fn from_candles(c: &'a Candles, s: &'a str, p: OttParams) -> Self {
106 Self {
107 data: OttData::Candles {
108 candles: c,
109 source: s,
110 },
111 params: p,
112 }
113 }
114
115 #[inline]
116 pub fn from_slice(sl: &'a [f64], p: OttParams) -> Self {
117 Self {
118 data: OttData::Slice(sl),
119 params: p,
120 }
121 }
122
123 #[inline]
124 pub fn with_default_candles(c: &'a Candles) -> Self {
125 Self::from_candles(c, "close", OttParams::default())
126 }
127
128 #[inline]
129 pub fn get_period(&self) -> usize {
130 self.params.period.unwrap_or(2)
131 }
132
133 #[inline]
134 pub fn get_percent(&self) -> f64 {
135 self.params.percent.unwrap_or(1.4)
136 }
137
138 #[inline]
139 pub fn get_ma_type(&self) -> &str {
140 match &self.params.ma_type {
141 Some(s) => s.as_str(),
142 None => "VAR",
143 }
144 }
145}
146
/// Fluent builder for a single OTT run or an [`OttStream`].
///
/// Fields left as `None` are forwarded as `None` in the `OttParams` the
/// builder produces.
#[derive(Clone, Debug)]
pub struct OttBuilder {
    // Moving-average period override.
    period: Option<usize>,
    // Percent offset override.
    percent: Option<f64>,
    // Moving-average name override.
    ma_type: Option<String>,
    // Kernel passed to `ott_with_kernel` by the `apply*` methods.
    kernel: Kernel,
}
154
155impl Default for OttBuilder {
156 fn default() -> Self {
157 Self {
158 period: None,
159 percent: None,
160 ma_type: None,
161 kernel: Kernel::Auto,
162 }
163 }
164}
165
166impl OttBuilder {
167 #[inline(always)]
168 pub fn new() -> Self {
169 Self::default()
170 }
171
172 #[inline(always)]
173 pub fn period(mut self, val: usize) -> Self {
174 self.period = Some(val);
175 self
176 }
177
178 #[inline(always)]
179 pub fn percent(mut self, val: f64) -> Self {
180 self.percent = Some(val);
181 self
182 }
183
184 #[inline(always)]
185 pub fn ma_type(mut self, val: String) -> Self {
186 self.ma_type = Some(val);
187 self
188 }
189
190 #[inline(always)]
191 pub fn kernel(mut self, k: Kernel) -> Self {
192 self.kernel = k;
193 self
194 }
195
196 #[inline(always)]
197 pub fn apply(self, c: &Candles) -> Result<OttOutput, OttError> {
198 let p = OttParams {
199 period: self.period,
200 percent: self.percent,
201 ma_type: self.ma_type,
202 };
203 let i = OttInput::from_candles(c, "close", p);
204 ott_with_kernel(&i, self.kernel)
205 }
206
207 #[inline(always)]
208 pub fn apply_slice(self, d: &[f64]) -> Result<OttOutput, OttError> {
209 let p = OttParams {
210 period: self.period,
211 percent: self.percent,
212 ma_type: self.ma_type,
213 };
214 let i = OttInput::from_slice(d, p);
215 ott_with_kernel(&i, self.kernel)
216 }
217
218 #[inline(always)]
219 pub fn apply_candles(self, c: &Candles, source: &str) -> Result<OttOutput, OttError> {
220 let p = OttParams {
221 period: self.period,
222 percent: self.percent,
223 ma_type: self.ma_type,
224 };
225 let i = OttInput::from_candles(c, source, p);
226 ott_with_kernel(&i, self.kernel)
227 }
228
229 #[inline(always)]
230 pub fn into_stream(self) -> Result<OttStream, OttError> {
231 let p = OttParams {
232 period: self.period,
233 percent: self.percent,
234 ma_type: self.ma_type,
235 };
236 OttStream::try_new(p)
237 }
238}
239
/// Errors reported by the OTT entry points in this module.
#[derive(Debug, Error)]
pub enum OttError {
    /// The input series has length zero.
    #[error("ott: Input data slice is empty.")]
    EmptyInputData,
    /// No non-NaN value exists in the input series.
    #[error("ott: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or larger than the input length.
    #[error("ott: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` elements remain from the first non-NaN value onward.
    #[error("ott: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// The percent is negative, NaN or infinite.
    #[error("ott: Invalid percent: {percent}")]
    InvalidPercent { percent: f64 },
    /// The moving-average name is not one of the supported ones.
    #[error("ott: Invalid moving average type: {ma_type}")]
    InvalidMaType { ma_type: String },
    /// A delegated moving-average routine failed; `reason` is its error text.
    #[error("ott: Moving average calculation failed: {reason}")]
    MaCalculationFailed { reason: String },
    /// A caller-provided output buffer does not match the input length.
    #[error("ott: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch parameter range could not be expanded (fields are pre-formatted
    /// so both integer and float ranges can be reported).
    #[error("ott: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// NOTE(review): presumably raised by the batch entry point when given a
    /// non-batch kernel; the use site is outside the visible code.
    #[error("ott: Invalid kernel for batch operation. Expected batch kernel, got: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// NOTE(review): appears to overlap with `InvalidKernelForBatch`; confirm
    /// whether both variants are still needed.
    #[error("ott: Invalid kernel for batch operation")]
    InvalidBatchKernel,
}
270
271#[inline]
272pub fn ott(input: &OttInput) -> Result<OttOutput, OttError> {
273 ott_with_kernel(input, Kernel::Auto)
274}
275
276pub fn ott_with_kernel(input: &OttInput, kernel: Kernel) -> Result<OttOutput, OttError> {
277 let (data, period, percent, ma_type, first, chosen) = ott_prepare(input, kernel)?;
278
279 if false
280 && chosen == Kernel::Scalar
281 && period == 2
282 && percent == 1.4
283 && ma_type.to_uppercase() == "VAR"
284 {
285 let mut out = vec![f64::NAN; data.len()];
286 unsafe {
287 ott_scalar_classic(data, period, percent, first, &mut out)?;
288 }
289 return Ok(OttOutput { values: out });
290 }
291
292 let ma_values = calculate_moving_average(data, period, ma_type, chosen)?;
293
294 let ma_first = ma_values
295 .iter()
296 .position(|&x| !x.is_nan())
297 .unwrap_or(data.len());
298
299 let mut out = alloc_with_nan_prefix(data.len(), ma_first);
300
301 ott_compute_into(
302 data, &ma_values, percent, ma_first, period, chosen, &mut out,
303 );
304
305 Ok(OttOutput { values: out })
306}
307
308#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
309#[inline]
310pub fn ott_into(input: &OttInput, out: &mut [f64]) -> Result<(), OttError> {
311 ott_into_slice(out, input, Kernel::Auto)
312}
313
314#[inline]
315pub fn ott_into_slice(dst: &mut [f64], input: &OttInput, kern: Kernel) -> Result<(), OttError> {
316 let (data, period, percent, ma_type, first, chosen) = ott_prepare(input, kern)?;
317
318 if dst.len() != data.len() {
319 return Err(OttError::OutputLengthMismatch {
320 expected: data.len(),
321 got: dst.len(),
322 });
323 }
324
325 let ma_values = calculate_moving_average(data, period, ma_type, chosen)?;
326
327 let ma_first = ma_values
328 .iter()
329 .position(|&x| !x.is_nan())
330 .unwrap_or(data.len());
331
332 ott_compute_into(data, &ma_values, percent, ma_first, period, chosen, dst);
333
334 for v in &mut dst[..ma_first] {
335 *v = f64::NAN;
336 }
337
338 Ok(())
339}
340
341#[inline(always)]
342fn ott_prepare<'a>(
343 input: &'a OttInput,
344 kernel: Kernel,
345) -> Result<(&'a [f64], usize, f64, &'a str, usize, Kernel), OttError> {
346 let data: &[f64] = input.as_ref();
347 if data.is_empty() {
348 return Err(OttError::EmptyInputData);
349 }
350
351 let first = data
352 .iter()
353 .position(|x| !x.is_nan())
354 .ok_or(OttError::AllValuesNaN)?;
355
356 let period = input.get_period();
357 let percent = input.get_percent();
358 let ma_type = input.get_ma_type();
359
360 if period == 0 || period > data.len() {
361 return Err(OttError::InvalidPeriod {
362 period,
363 data_len: data.len(),
364 });
365 }
366
367 if data.len() - first < period {
368 return Err(OttError::NotEnoughValidData {
369 needed: period,
370 valid: data.len() - first,
371 });
372 }
373
374 if percent < 0.0 || percent.is_nan() || percent.is_infinite() {
375 return Err(OttError::InvalidPercent { percent });
376 }
377
378 let chosen = match kernel {
379 Kernel::Auto => Kernel::Scalar,
380 k => k,
381 };
382
383 Ok((data, period, percent, ma_type, first, chosen))
384}
385
386fn calculate_moving_average(
387 data: &[f64],
388 period: usize,
389 ma_type: &str,
390 kernel: Kernel,
391) -> Result<Vec<f64>, OttError> {
392 match ma_type.to_uppercase().as_str() {
393 "SMA" => {
394 let params = SmaParams {
395 period: Some(period),
396 };
397 let input = SmaInput::from_slice(data, params);
398 sma_with_kernel(&input, kernel)
399 .map(|o| o.values)
400 .map_err(|e| OttError::MaCalculationFailed {
401 reason: e.to_string(),
402 })
403 }
404 "EMA" => {
405 let params = EmaParams {
406 period: Some(period),
407 };
408 let input = EmaInput::from_slice(data, params);
409 ema_with_kernel(&input, kernel)
410 .map(|o| o.values)
411 .map_err(|e| OttError::MaCalculationFailed {
412 reason: e.to_string(),
413 })
414 }
415 "WMA" => {
416 let params = WmaParams {
417 period: Some(period),
418 };
419 let input = WmaInput::from_slice(data, params);
420 wma_with_kernel(&input, kernel)
421 .map(|o| o.values)
422 .map_err(|e| OttError::MaCalculationFailed {
423 reason: e.to_string(),
424 })
425 }
426 "TMA" => calculate_tma(data, period, kernel),
427 "VAR" => calculate_var_ma(data, period),
428 "WWMA" => calculate_wwma(data, period),
429 "ZLEMA" => {
430 let params = ZlemaParams {
431 period: Some(period),
432 };
433 let input = ZlemaInput::from_slice(data, params);
434 zlema_with_kernel(&input, kernel)
435 .map(|o| o.values)
436 .map_err(|e| OttError::MaCalculationFailed {
437 reason: e.to_string(),
438 })
439 }
440 "TSF" => {
441 let params = TsfParams {
442 period: Some(period),
443 };
444 let input = TsfInput::from_slice(data, params);
445 tsf_with_kernel(&input, kernel)
446 .map(|o| o.values)
447 .map_err(|e| OttError::MaCalculationFailed {
448 reason: e.to_string(),
449 })
450 }
451 _ => Err(OttError::InvalidMaType {
452 ma_type: ma_type.to_string(),
453 }),
454 }
455}
456
457fn calculate_tma(data: &[f64], period: usize, kernel: Kernel) -> Result<Vec<f64>, OttError> {
458 let half_period = (period + 1) / 2;
459 let floor_half = period / 2 + 1;
460
461 let params1 = SmaParams {
462 period: Some(half_period),
463 };
464 let input1 = SmaInput::from_slice(data, params1);
465 let sma1 = sma_with_kernel(&input1, kernel).map_err(|e| OttError::MaCalculationFailed {
466 reason: e.to_string(),
467 })?;
468
469 let params2 = SmaParams {
470 period: Some(floor_half),
471 };
472 let input2 = SmaInput::from_slice(&sma1.values, params2);
473 let sma2 = sma_with_kernel(&input2, kernel).map_err(|e| OttError::MaCalculationFailed {
474 reason: e.to_string(),
475 })?;
476
477 Ok(sma2.values)
478}
479
480fn calculate_wwma(data: &[f64], period: usize) -> Result<Vec<f64>, OttError> {
481 let first = data
482 .iter()
483 .position(|x| !x.is_nan())
484 .ok_or(OttError::AllValuesNaN)?;
485
486 if data.len() - first < period {
487 return Err(OttError::NotEnoughValidData {
488 needed: period,
489 valid: data.len() - first,
490 });
491 }
492
493 let mut out = alloc_with_nan_prefix(data.len(), first);
494 let alpha = 1.0 / period as f64;
495
496 let mut wwma = alpha * data[first];
497 out[first] = wwma;
498
499 for i in (first + 1)..data.len() {
500 let xi = data[i];
501 if xi.is_nan() {
502 continue;
503 }
504 wwma = alpha * xi + (1.0 - alpha) * wwma;
505 out[i] = wwma;
506 }
507 Ok(out)
508}
509
510fn calculate_var_ma(data: &[f64], period: usize) -> Result<Vec<f64>, OttError> {
511 let first = data
512 .iter()
513 .position(|x| !x.is_nan())
514 .ok_or(OttError::AllValuesNaN)?;
515
516 let mut out = alloc_with_nan_prefix(data.len(), first);
517 let valpha = 2.0 / (period as f64 + 1.0);
518
519 let mut ring_u = [0.0f64; 9];
520 let mut ring_d = [0.0f64; 9];
521 let mut u_sum = 0.0;
522 let mut d_sum = 0.0;
523 let mut idx = 0usize;
524
525 let mut var = 0.0;
526 out[first] = var;
527
528 let start = first + 1;
529 let pre_end = (first + 8).min(data.len().saturating_sub(1));
530 for i in start..=pre_end {
531 let a = data[i - 1];
532 let b = data[i];
533 if a.is_nan() || b.is_nan() {
534 continue;
535 }
536 let up = (b - a).max(0.0);
537 let down = (a - b).max(0.0);
538 ring_u[idx] = up;
539 u_sum += up;
540 ring_d[idx] = down;
541 d_sum += down;
542 idx = (idx + 1) % 9;
543 out[i] = var;
544 }
545
546 if data.len() - first <= 9 {
547 return Ok(out);
548 }
549
550 for i in (first + 9)..data.len() {
551 let a = data[i - 1];
552 let b = data[i];
553 if a.is_nan() || b.is_nan() {
554 continue;
555 }
556
557 let old_u = ring_u[idx];
558 let old_d = ring_d[idx];
559 let up = (b - a).max(0.0);
560 let down = (a - b).max(0.0);
561
562 u_sum += up - old_u;
563 d_sum += down - old_d;
564
565 ring_u[idx] = up;
566 ring_d[idx] = down;
567 idx = (idx + 1) % 9;
568
569 let denom = u_sum + d_sum;
570 let vcmo = if denom != 0.0 {
571 (u_sum - d_sum) / denom
572 } else {
573 0.0
574 };
575
576 var = valpha * vcmo.abs() * b + (1.0 - valpha * vcmo.abs()) * var;
577 out[i] = var;
578 }
579
580 Ok(out)
581}
582
583#[inline(always)]
584fn ott_compute_into(
585 data: &[f64],
586 ma_values: &[f64],
587 percent: f64,
588 first: usize,
589 period: usize,
590 kernel: Kernel,
591 out: &mut [f64],
592) {
593 unsafe {
594 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
595 {
596 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
597 ott_simd128(data, ma_values, percent, first, period, out);
598 return;
599 }
600 }
601
602 match kernel {
603 Kernel::Scalar | Kernel::ScalarBatch => {
604 ott_scalar(data, ma_values, percent, first, period, out)
605 }
606 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
607 Kernel::Avx2 | Kernel::Avx2Batch => {
608 ott_avx2(data, ma_values, percent, first, period, out)
609 }
610 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
611 Kernel::Avx512 | Kernel::Avx512Batch => {
612 ott_avx512(data, ma_values, percent, first, period, out)
613 }
614 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
615 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
616 ott_scalar(data, ma_values, percent, first, period, out)
617 }
618 _ => unreachable!(),
619 }
620 }
621}
622
623#[inline(always)]
624pub unsafe fn ott_scalar_classic(
625 data: &[f64],
626 period: usize,
627 percent: f64,
628 first: usize,
629 out: &mut [f64],
630) -> Result<(), OttError> {
631 let len = data.len();
632
633 let valpha = 2.0 / (period as f64 + 1.0);
634
635 let mut ring_u = [0.0f64; 9];
636 let mut ring_d = [0.0f64; 9];
637 let mut u_sum = 0.0;
638 let mut d_sum = 0.0;
639 let mut idx = 0usize;
640
641 let mut var = 0.0;
642 let mut var_ma = vec![f64::NAN; len];
643 var_ma[first] = var;
644
645 let start = first + 1;
646 let pre_end = (first + 8).min(len.saturating_sub(1));
647 for i in start..=pre_end {
648 let a = data[i - 1];
649 let b = data[i];
650 if !a.is_nan() && !b.is_nan() {
651 let up = (b - a).max(0.0);
652 let down = (a - b).max(0.0);
653 ring_u[idx] = up;
654 u_sum += up;
655 ring_d[idx] = down;
656 d_sum += down;
657 idx = (idx + 1) % 9;
658 }
659 var_ma[i] = var;
660 }
661
662 for i in (first + 9)..len {
663 let a = data[i - 1];
664 let b = data[i];
665 if !a.is_nan() && !b.is_nan() {
666 let old_u = ring_u[idx];
667 let old_d = ring_d[idx];
668 let up = (b - a).max(0.0);
669 let down = (a - b).max(0.0);
670
671 u_sum += up - old_u;
672 d_sum += down - old_d;
673
674 ring_u[idx] = up;
675 ring_d[idx] = down;
676 idx = (idx + 1) % 9;
677
678 let denom = u_sum + d_sum;
679 let vcmo = if denom != 0.0 {
680 (u_sum - d_sum) / denom
681 } else {
682 0.0
683 };
684
685 var = valpha * vcmo.abs() * b + (1.0 - valpha * vcmo.abs()) * var;
686 var_ma[i] = var;
687 } else if i > first {
688 var_ma[i] = var_ma[i - 1];
689 }
690 }
691
692 let fark = percent * 0.01;
693 let ma_first = first;
694
695 for i in 0..ma_first {
696 out[i] = f64::NAN;
697 }
698
699 let mut dir = 1i32;
700 let mut long_stop = f64::NAN;
701 let mut short_stop = f64::NAN;
702
703 for i in ma_first..len {
704 let mavg = var_ma[i];
705
706 if mavg.is_nan() {
707 continue;
708 }
709
710 let offset = mavg * fark;
711
712 let long_stop_prev = if long_stop.is_nan() {
713 mavg - offset
714 } else {
715 long_stop
716 };
717 long_stop = if mavg > long_stop_prev {
718 (mavg - offset).max(long_stop_prev)
719 } else {
720 mavg - offset
721 };
722
723 let short_stop_prev = if short_stop.is_nan() {
724 mavg + offset
725 } else {
726 short_stop
727 };
728 short_stop = if mavg < short_stop_prev {
729 (mavg + offset).min(short_stop_prev)
730 } else {
731 mavg + offset
732 };
733
734 let prev_dir = dir;
735 if mavg > short_stop_prev {
736 dir = 1;
737 } else if mavg <= long_stop_prev {
738 dir = -1;
739 }
740
741 out[i] = if dir == -1 { short_stop } else { long_stop };
742 }
743
744 Ok(())
745}
746
/// Scalar OTT kernel: turns a moving-average series into the OTT line.
///
/// Starting at the first non-NaN entry at or after `first_val`, it maintains a
/// long stop (`ma * (1 - percent/100)`, ratcheting up) and a short stop
/// (`ma * (1 + percent/100)`, ratcheting down), flips direction when the
/// average crosses the previous stop on the active side, and writes the active
/// stop scaled by `1 + percent/200` when the average is above it, otherwise by
/// `1 - percent/200`. Slots whose moving average is NaN are left untouched.
///
/// `_data` and `_period` are unused; they keep the signature aligned with the
/// other kernels.
#[inline]
pub fn ott_scalar(
    _data: &[f64],
    ma_values: &[f64],
    percent: f64,
    first_val: usize,
    _period: usize,
    out: &mut [f64],
) {
    let total = ma_values.len();
    if first_val >= total {
        return;
    }

    let fark = percent * 0.01;
    let scale_minus = 1.0 - (percent * 0.005);
    let scale_plus = scale_minus + fark;

    // Seed from the first usable moving-average value.
    let seed_idx = match ma_values[first_val..].iter().position(|v| !v.is_nan()) {
        Some(offset) => first_val + offset,
        None => return,
    };
    let seed = ma_values[seed_idx];

    let mut long_stop = seed.mul_add(-fark, seed);
    let mut short_stop = seed.mul_add(fark, seed);
    let mut trend_up = true;

    let seed_scale = if seed > long_stop {
        scale_plus
    } else {
        scale_minus
    };
    out[seed_idx] = long_stop * seed_scale;

    for i in (seed_idx + 1)..total {
        let mavg = ma_values[i];
        if mavg.is_nan() {
            continue;
        }

        let cand_long = mavg.mul_add(-fark, mavg);
        let cand_short = mavg.mul_add(fark, mavg);
        let prev_long = long_stop;
        let prev_short = short_stop;

        // Keep the previous stop only while the average is on its far side
        // and the new candidate would loosen it.
        long_stop = if mavg > prev_long && !(cand_long > prev_long) {
            prev_long
        } else {
            cand_long
        };
        short_stop = if mavg < prev_short && !(cand_short < prev_short) {
            prev_short
        } else {
            cand_short
        };

        if trend_up {
            if mavg < prev_long {
                trend_up = false;
            }
        } else if mavg > prev_short {
            trend_up = true;
        }

        let active = if trend_up { long_stop } else { short_stop };
        let scale = if mavg > active {
            scale_plus
        } else {
            scale_minus
        };
        out[i] = active * scale;
    }
}
830
/// wasm32 `simd128` entry point. It currently forwards to `ott_scalar`
/// unchanged, so results are identical to the scalar kernel.
///
/// # Safety
/// The body only calls the safe `ott_scalar`; there are no additional
/// requirements visible in this function.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
unsafe fn ott_simd128(
    data: &[f64],
    ma_values: &[f64],
    percent: f64,
    first_val: usize,
    period: usize,
    out: &mut [f64],
) {
    ott_scalar(data, ma_values, percent, first_val, period, out);
}
843
/// AVX2 entry point. It forwards to `ott_scalar`; the only difference is that
/// the body is compiled with the `avx2` and `fma` target features enabled.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and FMA, as required
/// by the `#[target_feature]` attribute.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
unsafe fn ott_avx2(
    data: &[f64],
    ma_values: &[f64],
    percent: f64,
    first_val: usize,
    period: usize,
    out: &mut [f64],
) {
    ott_scalar(data, ma_values, percent, first_val, period, out);
}
856
/// AVX-512 entry point. It forwards to `ott_scalar`; the only difference is
/// that the body is compiled with the `avx512f` and `fma` target features
/// enabled.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX-512F and FMA, as
/// required by the `#[target_feature]` attribute.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,fma")]
unsafe fn ott_avx512(
    data: &[f64],
    ma_values: &[f64],
    percent: f64,
    first_val: usize,
    period: usize,
    out: &mut [f64],
) {
    ott_scalar(data, ma_values, percent, first_val, period, out);
}
869
/// Incremental OTT calculator: feed one value at a time through `update`.
///
/// State for every supported moving average lives side by side; only the part
/// matching `ma_type` is advanced on each update.
#[derive(Debug, Clone)]
pub struct OttStream {
    // Resolved parameters.
    period: usize,
    percent: f64,
    ma_type: String,

    // Ring buffer of the most recent inputs (initialised to NaN), the next
    // write position, and how many slots have been filled (capped at the
    // buffer length).
    buf: Vec<f64>,
    pos: usize,
    count: usize,

    // Trailing-stop state; the stops are NaN until the first emitted value.
    long_stop: f64,
    short_stop: f64,
    dir: i32,

    // Precomputed factors: percent/100, 1 + percent/200, 1 - percent/200.
    fark: f64,
    scale_plus: f64,
    scale_minus: f64,

    // SMA: running sum of the window.
    sma_sum: f64,

    // EMA: alpha = 2 / (period + 1) and the previous value.
    ema_alpha: f64,
    ema_state: Option<f64>,

    // WWMA: alpha = 1 / period and the previous value.
    ww_alpha: f64,
    wwma_state: Option<f64>,

    // WMA: plain sum, weighted sum, and 2 / (n * (n + 1)) (1.0 for period 1).
    wma_simple_sum: f64,
    wma_weighted_sum: f64,
    wma_inv_norm: f64,

    // ZLEMA: alpha (same as EMA), previous value, and lag = (period - 1) / 2.
    zlema_alpha: f64,
    zlema_state: Option<f64>,
    zlema_lag: usize,

    // VAR: base alpha (same as EMA), current value, nine-slot gain/loss rings
    // with their write index and running sums, and how many changes have been
    // recorded so far (capped at 9).
    var_alpha_base: f64,
    var_state: f64,
    var_u_ring: [f64; 9],
    var_d_ring: [f64; 9],
    var_idx: usize,
    var_u_sum: f64,
    var_d_sum: f64,
    var_seen_diffs: usize,
}
913
914impl OttStream {
915 pub fn try_new(params: OttParams) -> Result<Self, OttError> {
916 let period = params.period.unwrap_or(2);
917 let percent = params.percent.unwrap_or(1.4);
918 let ma_type = params.ma_type.unwrap_or_else(|| "VAR".to_string());
919
920 if period == 0 {
921 return Err(OttError::InvalidPeriod {
922 period,
923 data_len: 0,
924 });
925 }
926 if percent < 0.0 || !percent.is_finite() {
927 return Err(OttError::InvalidPercent { percent });
928 }
929
930 let need = if ma_type.eq_ignore_ascii_case("VAR") {
931 period.max(10)
932 } else {
933 period.max(1)
934 };
935
936 let fark = percent * 0.01;
937 let scale_minus = 1.0 - (percent * 0.005);
938 let scale_plus = 1.0 + (percent * 0.005);
939
940 let ema_alpha = 2.0 / (period as f64 + 1.0);
941 let ww_alpha = 1.0 / period as f64;
942 let zlema_alpha = ema_alpha;
943 let zlema_lag = ((period.saturating_sub(1)) as f64 / 2.0).floor() as usize;
944
945 let n = period as f64;
946 let wma_inv_norm = if period > 1 {
947 2.0 / (n * (n + 1.0))
948 } else {
949 1.0
950 };
951
952 Ok(Self {
953 period,
954 percent,
955 ma_type,
956
957 buf: vec![f64::NAN; need],
958 pos: 0,
959 count: 0,
960
961 long_stop: f64::NAN,
962 short_stop: f64::NAN,
963 dir: 1,
964
965 fark,
966 scale_plus,
967 scale_minus,
968
969 sma_sum: 0.0,
970
971 ema_alpha,
972 ema_state: None,
973
974 ww_alpha,
975 wwma_state: None,
976
977 wma_simple_sum: 0.0,
978 wma_weighted_sum: 0.0,
979 wma_inv_norm,
980
981 zlema_alpha,
982 zlema_state: None,
983 zlema_lag,
984
985 var_alpha_base: ema_alpha,
986 var_state: 0.0,
987 var_u_ring: [0.0; 9],
988 var_d_ring: [0.0; 9],
989 var_idx: 0,
990 var_u_sum: 0.0,
991 var_d_sum: 0.0,
992 var_seen_diffs: 0,
993 })
994 }
995
996 #[inline]
997 pub fn update(&mut self, x: f64) -> Option<f64> {
998 let cap = self.buf.len();
999
1000 let old = self.buf[self.pos];
1001 self.buf[self.pos] = x;
1002 self.pos = (self.pos + 1) % cap;
1003 if self.count < cap {
1004 self.count += 1;
1005 }
1006
1007 let ma = self.calculate_ma(x, old);
1008
1009 if !ma.is_finite() || self.count < cap {
1010 return None;
1011 }
1012
1013 let offset = ma * self.fark;
1014
1015 let lprev = if self.long_stop.is_nan() {
1016 ma - offset
1017 } else {
1018 self.long_stop
1019 };
1020 let sprev = if self.short_stop.is_nan() {
1021 ma + offset
1022 } else {
1023 self.short_stop
1024 };
1025
1026 let cand_long = ma - offset;
1027 self.long_stop = if ma > lprev {
1028 if cand_long > lprev {
1029 cand_long
1030 } else {
1031 lprev
1032 }
1033 } else {
1034 cand_long
1035 };
1036
1037 let cand_short = ma + offset;
1038 self.short_stop = if ma < sprev {
1039 if cand_short < sprev {
1040 cand_short
1041 } else {
1042 sprev
1043 }
1044 } else {
1045 cand_short
1046 };
1047
1048 if self.dir == -1 && ma > sprev {
1049 self.dir = 1;
1050 } else if self.dir == 1 && ma < lprev {
1051 self.dir = -1;
1052 }
1053
1054 let mt = if self.dir == 1 {
1055 self.long_stop
1056 } else {
1057 self.short_stop
1058 };
1059 let scaled = if ma > mt {
1060 mt * self.scale_plus
1061 } else {
1062 mt * self.scale_minus
1063 };
1064
1065 Some(scaled)
1066 }
1067
1068 #[inline]
1069 fn calculate_ma(&mut self, x: f64, old: f64) -> f64 {
1070 match self.ma_type.as_str() {
1071 "VAR" => self.update_var(x),
1072
1073 "WWMA" => self.update_wwma(x),
1074
1075 "EMA" => self.update_ema(x),
1076
1077 "SMA" => self.update_sma(x, old),
1078
1079 "WMA" => self.update_wma(x, old),
1080
1081 "ZLEMA" => self.update_zlema(x),
1082
1083 _ => self.update_sma(x, old),
1084 }
1085 }
1086
1087 #[inline]
1088 fn update_sma(&mut self, x: f64, old: f64) -> f64 {
1089 if old.is_finite() {
1090 self.sma_sum += x - old;
1091 } else {
1092 self.sma_sum += x;
1093 }
1094 if self.count < self.period {
1095 f64::NAN
1096 } else {
1097 self.sma_sum / self.period as f64
1098 }
1099 }
1100
1101 #[inline]
1102 fn update_ema(&mut self, x: f64) -> f64 {
1103 let ema = match self.ema_state {
1104 Some(prev) => self.ema_alpha.mul_add(x - prev, prev),
1105 None => x,
1106 };
1107 self.ema_state = Some(ema);
1108 ema
1109 }
1110
1111 #[inline]
1112 fn update_wwma(&mut self, x: f64) -> f64 {
1113 let ww = match self.wwma_state {
1114 Some(prev) => self.ww_alpha.mul_add(x - prev, prev),
1115 None => self.ww_alpha * x,
1116 };
1117 self.wwma_state = Some(ww);
1118 ww
1119 }
1120
1121 #[inline]
1122 fn update_wma(&mut self, x: f64, old: f64) -> f64 {
1123 if self.count <= self.period {
1124 let w = self.count as f64;
1125 self.wma_simple_sum += x;
1126 self.wma_weighted_sum += w * x;
1127
1128 if self.count < self.period {
1129 return f64::NAN;
1130 }
1131 return self.wma_weighted_sum * self.wma_inv_norm;
1132 }
1133
1134 let s_prev = self.wma_simple_sum;
1135 self.wma_weighted_sum += self.period as f64 * x - s_prev;
1136
1137 let x_out = old;
1138 self.wma_simple_sum += x - x_out;
1139
1140 self.wma_weighted_sum * self.wma_inv_norm
1141 }
1142
1143 #[inline]
1144 fn update_zlema(&mut self, x: f64) -> f64 {
1145 let cap = self.buf.len();
1146 let lag_idx = (self.pos + cap - 1 - self.zlema_lag % cap) % cap;
1147 let lagged = self.buf[lag_idx];
1148
1149 let de_lagged = if lagged.is_finite() {
1150 x + (x - lagged)
1151 } else {
1152 x
1153 };
1154 let z = match self.zlema_state {
1155 Some(prev) => self.zlema_alpha.mul_add(de_lagged - prev, prev),
1156 None => de_lagged,
1157 };
1158 self.zlema_state = Some(z);
1159 z
1160 }
1161
1162 #[inline]
1163 fn update_var(&mut self, x: f64) -> f64 {
1164 let cap = self.buf.len();
1165 if self.count == 0 {
1166 return self.var_state;
1167 }
1168 let prev_idx = (self.pos + cap - 2) % cap;
1169 let prev = self.buf[prev_idx];
1170 if !x.is_finite() || !prev.is_finite() {
1171 return self.var_state;
1172 }
1173
1174 let up = (x - prev).max(0.0);
1175 let dn = (prev - x).max(0.0);
1176
1177 let old_u = self.var_u_ring[self.var_idx];
1178 let old_d = self.var_d_ring[self.var_idx];
1179 self.var_u_ring[self.var_idx] = up;
1180 self.var_d_ring[self.var_idx] = dn;
1181
1182 if self.var_seen_diffs < 9 {
1183 self.var_seen_diffs += 1;
1184 }
1185
1186 self.var_u_sum += up - old_u;
1187 self.var_d_sum += dn - old_d;
1188 self.var_idx = (self.var_idx + 1) % 9;
1189
1190 if self.count < 10 || self.var_seen_diffs < 9 {
1191 return self.var_state;
1192 }
1193
1194 let denom = self.var_u_sum + self.var_d_sum;
1195 let cmo_abs = if denom != 0.0 {
1196 (self.var_u_sum - self.var_d_sum).abs() / denom
1197 } else {
1198 0.0
1199 };
1200
1201 let alpha = cmo_abs * self.var_alpha_base;
1202
1203 self.var_state = alpha.mul_add(x - self.var_state, self.var_state);
1204 self.var_state
1205 }
1206}
1207
/// Parameter sweep description for batch runs.
#[derive(Clone, Debug)]
pub struct OttBatchRange {
    /// Period axis as `(start, end, step)`.
    pub period: (usize, usize, usize),
    /// Percent axis as `(start, end, step)`.
    pub percent: (f64, f64, f64),
    /// Moving-average names to sweep over.
    pub ma_types: Vec<String>,
}
1214
1215impl Default for OttBatchRange {
1216 fn default() -> Self {
1217 Self {
1218 period: (2, 251, 1),
1219 percent: (1.4, 1.4, 0.0),
1220 ma_types: vec!["VAR".to_string()],
1221 }
1222 }
1223}
1224
/// Result of a batch run: one output row per parameter combination.
#[derive(Clone, Debug)]
pub struct OttBatchOutput {
    /// All rows stored back to back; row `r` occupies
    /// `values[r * cols..(r + 1) * cols]`.
    pub values: Vec<f64>,
    /// The parameter combination for each row, in row order.
    pub combos: Vec<OttParams>,
    /// Number of rows.
    pub rows: usize,
    /// Length of each row.
    pub cols: usize,
}
1232
1233impl OttBatchOutput {
1234 pub fn row_for_params(&self, p: &OttParams) -> Option<usize> {
1235 let tp = p.period.unwrap_or(2);
1236 let tq = p.percent.unwrap_or(1.4);
1237 let tt = p.ma_type.as_deref().unwrap_or("VAR");
1238 self.combos.iter().position(|c| {
1239 c.period.unwrap_or(2) == tp
1240 && (c.percent.unwrap_or(1.4) - tq).abs() < 1e-12
1241 && c.ma_type.as_deref().unwrap_or("VAR") == tt
1242 })
1243 }
1244
1245 pub fn values_for(&self, p: &OttParams) -> Option<&[f64]> {
1246 self.row_for_params(p).map(|row| {
1247 let start = row * self.cols;
1248 &self.values[start..start + self.cols]
1249 })
1250 }
1251}
1252
/// Fluent builder for batch (parameter-sweep) OTT runs.
#[derive(Clone, Debug, Default)]
pub struct OttBatchBuilder {
    // Parameter ranges to expand into combinations.
    range: OttBatchRange,
    // Kernel handed to `ott_batch_with_kernel`.
    kernel: Kernel,
}
1258
1259impl OttBatchBuilder {
1260 pub fn new() -> Self {
1261 Self::default()
1262 }
1263
1264 pub fn kernel(mut self, k: Kernel) -> Self {
1265 self.kernel = k;
1266 self
1267 }
1268
1269 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<OttBatchOutput, OttError> {
1270 OttBatchBuilder::new().kernel(k).apply_slice(data)
1271 }
1272
1273 pub fn with_default_candles(c: &Candles) -> Result<OttBatchOutput, OttError> {
1274 OttBatchBuilder::new()
1275 .kernel(Kernel::Auto)
1276 .apply_candles(c, "close")
1277 }
1278
1279 #[inline]
1280 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
1281 self.range.period = (start, end, step);
1282 self
1283 }
1284
1285 #[inline]
1286 pub fn period_static(mut self, val: usize) -> Self {
1287 self.range.period = (val, val, 0);
1288 self
1289 }
1290
1291 #[inline]
1292 pub fn percent_range(mut self, start: f64, end: f64, step: f64) -> Self {
1293 self.range.percent = (start, end, step);
1294 self
1295 }
1296
1297 #[inline]
1298 pub fn percent_static(mut self, val: f64) -> Self {
1299 self.range.percent = (val, val, 0.0);
1300 self
1301 }
1302
1303 #[inline]
1304 pub fn ma_types(mut self, types: Vec<String>) -> Self {
1305 self.range.ma_types = types;
1306 self
1307 }
1308
1309 pub fn apply_slice(self, data: &[f64]) -> Result<OttBatchOutput, OttError> {
1310 ott_batch_with_kernel(data, &self.range, self.kernel)
1311 }
1312
1313 pub fn apply_batch(self, data: &[f64]) -> Result<OttBatchOutput, OttError> {
1314 ott_batch_with_kernel(data, &self.range, self.kernel)
1315 }
1316
1317 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<OttBatchOutput, OttError> {
1318 let slice = source_type(c, src);
1319 self.apply_slice(slice)
1320 }
1321}
1322
1323#[inline(always)]
1324fn expand_grid_ott(r: &OttBatchRange) -> Result<Vec<OttParams>, OttError> {
1325 fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, OttError> {
1326 if step == 0 || start == end {
1327 return Ok(vec![start]);
1328 }
1329 if start < end {
1330 return Ok((start..=end).step_by(step.max(1)).collect());
1331 }
1332 let mut v = Vec::new();
1333 let mut x = start as isize;
1334 let end_i = end as isize;
1335 let st = (step as isize).max(1);
1336 while x >= end_i {
1337 v.push(x as usize);
1338 x -= st;
1339 }
1340 if v.is_empty() {
1341 return Err(OttError::InvalidRange {
1342 start: start.to_string(),
1343 end: end.to_string(),
1344 step: step.to_string(),
1345 });
1346 }
1347 Ok(v)
1348 }
1349 fn axis_f64((start, end, step): (f64, f64, f64)) -> Result<Vec<f64>, OttError> {
1350 if step.abs() < 1e-12 || (start - end).abs() < 1e-12 {
1351 return Ok(vec![start]);
1352 }
1353 if start < end {
1354 let mut v = Vec::new();
1355 let mut x = start;
1356 let st = step.abs();
1357 while x <= end + 1e-12 {
1358 v.push(x);
1359 x += st;
1360 }
1361 if v.is_empty() {
1362 return Err(OttError::InvalidRange {
1363 start: start.to_string(),
1364 end: end.to_string(),
1365 step: step.to_string(),
1366 });
1367 }
1368 return Ok(v);
1369 }
1370 let mut v = Vec::new();
1371 let mut x = start;
1372 let st = step.abs();
1373 while x + 1e-12 >= end {
1374 v.push(x);
1375 x -= st;
1376 }
1377 if v.is_empty() {
1378 return Err(OttError::InvalidRange {
1379 start: start.to_string(),
1380 end: end.to_string(),
1381 step: step.to_string(),
1382 });
1383 }
1384 Ok(v)
1385 }
1386
1387 let periods = axis_usize(r.period)?;
1388 let percents = axis_f64(r.percent)?;
1389 let types = if r.ma_types.is_empty() {
1390 vec!["VAR".to_string()]
1391 } else {
1392 r.ma_types.clone()
1393 };
1394 let cap = periods
1395 .len()
1396 .checked_mul(percents.len())
1397 .and_then(|x| x.checked_mul(types.len()))
1398 .ok_or_else(|| OttError::InvalidRange {
1399 start: "cap".into(),
1400 end: "overflow".into(),
1401 step: "mul".into(),
1402 })?;
1403
1404 let mut out = Vec::with_capacity(cap);
1405 for &p in &periods {
1406 for &pct in &percents {
1407 for mt in &types {
1408 out.push(OttParams {
1409 period: Some(p),
1410 percent: Some(pct),
1411 ma_type: Some(mt.clone()),
1412 });
1413 }
1414 }
1415 }
1416 if out.is_empty() {
1417 return Err(OttError::InvalidRange {
1418 start: r.period.0.to_string(),
1419 end: r.period.1.to_string(),
1420 step: r.period.2.to_string(),
1421 });
1422 }
1423 Ok(out)
1424}
1425
1426#[inline(always)]
1427fn ott_batch_inner_into(
1428 data: &[f64],
1429 sweep: &OttBatchRange,
1430 kern: Kernel,
1431 parallel: bool,
1432 out: &mut [f64],
1433) -> Result<Vec<OttParams>, OttError> {
1434 let combos = expand_grid_ott(sweep)?;
1435
1436 let cols = data.len();
1437 if cols == 0 {
1438 return Err(OttError::EmptyInputData);
1439 }
1440
1441 let first = data
1442 .iter()
1443 .position(|x| !x.is_nan())
1444 .ok_or(OttError::AllValuesNaN)?;
1445 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
1446 if max_p == 0 || max_p > cols {
1447 return Err(OttError::InvalidPeriod {
1448 period: max_p,
1449 data_len: cols,
1450 });
1451 }
1452 if cols - first < max_p {
1453 return Err(OttError::NotEnoughValidData {
1454 needed: max_p,
1455 valid: cols - first,
1456 });
1457 }
1458
1459 let row_kern = match kern {
1460 Kernel::Auto => match detect_best_batch_kernel() {
1461 Kernel::Avx512Batch => Kernel::Avx512,
1462 Kernel::Avx2Batch => Kernel::Avx2,
1463 Kernel::ScalarBatch => Kernel::Scalar,
1464 _ => Kernel::Scalar,
1465 },
1466 Kernel::Avx512Batch => Kernel::Avx512,
1467 Kernel::Avx2Batch => Kernel::Avx2,
1468 Kernel::ScalarBatch => Kernel::Scalar,
1469 k => k,
1470 };
1471
1472 let mut ma_cache: HashMap<(usize, String), (Vec<f64>, usize)> = HashMap::new();
1473 for prm in &combos {
1474 let p = prm.period.unwrap();
1475 if p == 0 || p > cols {
1476 return Err(OttError::InvalidPeriod {
1477 period: p,
1478 data_len: cols,
1479 });
1480 }
1481 let pct = prm.percent.unwrap();
1482 if pct < 0.0 || !pct.is_finite() {
1483 return Err(OttError::InvalidPercent { percent: pct });
1484 }
1485 let mt = prm.ma_type.as_deref().unwrap().to_uppercase();
1486 if !ma_cache.contains_key(&(p, mt.clone())) {
1487 let ma = calculate_moving_average(data, p, &mt, row_kern).map_err(|e| {
1488 OttError::MaCalculationFailed {
1489 reason: e.to_string(),
1490 }
1491 })?;
1492 let ma_first = ma.iter().position(|&x| !x.is_nan()).unwrap_or(cols);
1493 ma_cache.insert((p, mt), (ma, ma_first));
1494 }
1495 }
1496
1497 let out_mu: &mut [MaybeUninit<f64>] = unsafe {
1498 core::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
1499 };
1500
1501 let do_row = |r: usize, dst_mu: &mut [MaybeUninit<f64>]| -> Result<(), OttError> {
1502 let prm = &combos[r];
1503 let p = prm.period.unwrap();
1504 let pct = prm.percent.unwrap();
1505 let mt = prm.ma_type.as_deref().unwrap();
1506
1507 let key = (p, mt.to_uppercase());
1508 let (ma, ma_first) = ma_cache.get(&key).expect("missing MA cache entry");
1509
1510 let row: &mut [f64] = unsafe {
1511 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len())
1512 };
1513
1514 for v in &mut row[..(*ma_first).min(cols)] {
1515 *v = f64::NAN;
1516 }
1517
1518 ott_compute_into(data, ma, pct, *ma_first, p, row_kern, row);
1519 Ok(())
1520 };
1521
1522 if parallel {
1523 #[cfg(not(target_arch = "wasm32"))]
1524 {
1525 use rayon::prelude::*;
1526 out_mu
1527 .par_chunks_mut(cols)
1528 .enumerate()
1529 .try_for_each(|(r, sl)| do_row(r, sl))?;
1530 }
1531 #[cfg(target_arch = "wasm32")]
1532 {
1533 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1534 do_row(r, sl)?;
1535 }
1536 }
1537 } else {
1538 for (r, sl) in out_mu.chunks_mut(cols).enumerate() {
1539 do_row(r, sl)?;
1540 }
1541 }
1542
1543 Ok(combos)
1544}
1545
1546#[inline(always)]
1547pub fn ott_batch_slice(
1548 data: &[f64],
1549 sweep: &OttBatchRange,
1550 kern: Kernel,
1551) -> Result<OttBatchOutput, OttError> {
1552 ott_batch_inner(data, sweep, kern, false)
1553}
1554
1555#[inline(always)]
1556pub fn ott_batch_par_slice(
1557 data: &[f64],
1558 sweep: &OttBatchRange,
1559 kern: Kernel,
1560) -> Result<OttBatchOutput, OttError> {
1561 ott_batch_inner(data, sweep, kern, true)
1562}
1563
1564#[inline(always)]
1565fn ott_batch_inner(
1566 data: &[f64],
1567 sweep: &OttBatchRange,
1568 kern: Kernel,
1569 parallel: bool,
1570) -> Result<OttBatchOutput, OttError> {
1571 let combos = expand_grid_ott(sweep)?;
1572 let cols = data.len();
1573 if cols == 0 {
1574 return Err(OttError::EmptyInputData);
1575 }
1576
1577 let rows = combos.len();
1578 rows.checked_mul(cols)
1579 .ok_or_else(|| OttError::InvalidRange {
1580 start: sweep.period.0.to_string(),
1581 end: sweep.period.1.to_string(),
1582 step: sweep.period.2.to_string(),
1583 })?;
1584 let mut buf_mu = make_uninit_matrix(rows, cols);
1585
1586 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
1587 let out: &mut [f64] =
1588 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
1589
1590 let combos = ott_batch_inner_into(data, sweep, kern, parallel, out)?;
1591
1592 let values = unsafe {
1593 Vec::from_raw_parts(
1594 guard.as_mut_ptr() as *mut f64,
1595 guard.len(),
1596 guard.capacity(),
1597 )
1598 };
1599
1600 Ok(OttBatchOutput {
1601 values,
1602 combos,
1603 rows,
1604 cols,
1605 })
1606}
1607
1608pub fn ott_batch_with_kernel(
1609 data: &[f64],
1610 sweep: &OttBatchRange,
1611 k: Kernel,
1612) -> Result<OttBatchOutput, OttError> {
1613 let kernel = match k {
1614 Kernel::Auto => detect_best_batch_kernel(),
1615 other if other.is_batch() => other,
1616 _ => return Err(OttError::InvalidKernelForBatch(k)),
1617 };
1618
1619 ott_batch_par_slice(data, sweep, kernel)
1620}
1621
/// Python entry point `ott`: computes OTT for one parameter set.
///
/// Reads `data` as a contiguous slice, validates the optional kernel name, runs
/// `ott_with_kernel` with the GIL released and returns the values as a NumPy
/// array. Errors from the computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ott")]
#[pyo3(signature = (data, period=2, percent=1.4, ma_type="VAR", kernel=None))]
pub fn ott_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    percent: f64,
    ma_type: &str,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let input = OttInput::from_slice(
        prices,
        OttParams {
            period: Some(period),
            percent: Some(percent),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let computed = py.allow_threads(|| ott_with_kernel(&input, kern).map(|o| o.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1648
/// Python class `OttStream`: a wrapper that owns one `OttStream`.
#[cfg(feature = "python")]
#[pyclass(name = "OttStream")]
pub struct OttStreamPy {
    stream: OttStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl OttStreamPy {
    /// Builds the underlying stream from the three OTT parameters.
    ///
    /// Any error from `OttStream::try_new` is raised as `ValueError`.
    #[new]
    fn new(period: usize, percent: f64, ma_type: &str) -> PyResult<Self> {
        OttStream::try_new(OttParams {
            period: Some(period),
            percent: Some(percent),
            ma_type: Some(ma_type.to_string()),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value to the stream and returns whatever `OttStream::update`
    /// returns for it.
    fn update(&mut self, value: f64) -> Option<f64> {
        let next = self.stream.update(value);
        next
    }
}
1674
/// Python entry point `ott_batch`: runs an OTT parameter sweep.
///
/// The output matrix is allocated as a flat NumPy array, filled in place by
/// `ott_batch_inner_into` with the GIL released, and reshaped to
/// `(rows, cols)`. The returned dict holds `values`, `periods`, `percents` and
/// `ma_types`, the last three with one entry per row. Failures are raised as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "ott_batch")]
#[pyo3(signature = (data, period_range, percent_range, ma_types, kernel=None))]
pub fn ott_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    percent_range: (f64, f64, f64),
    ma_types: Vec<String>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{PyArray1, PyArrayMethods};

    let prices = data.as_slice()?;
    let sweep = OttBatchRange {
        period: period_range,
        percent: percent_range,
        ma_types,
    };
    let kern = validate_kernel(kernel, true)?;

    // The grid is expanded up front only to size the output array.
    let rows = expand_grid_ott(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = prices.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows * cols overflow"))?;

    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let computed = py.allow_threads(|| {
        let batch_kernel = if matches!(kern, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kern
        };
        let row_kernel = match batch_kernel {
            Kernel::Avx512Batch => Kernel::Avx512,
            Kernel::Avx2Batch => Kernel::Avx2,
            Kernel::ScalarBatch => Kernel::Scalar,
            _ => unreachable!(),
        };
        ott_batch_inner_into(prices, &sweep, row_kernel, true, slice_out)
    });
    let combos = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    dict.set_item("periods", periods.into_pyarray(py))?;

    let percents: Vec<f64> = combos.iter().map(|c| c.percent.unwrap()).collect();
    dict.set_item("percents", percents.into_pyarray(py))?;

    let types = PyList::new(py, combos.iter().map(|c| c.ma_type.as_deref().unwrap()))?;
    dict.set_item("ma_types", types)?;

    Ok(dict)
}
1745
/// Python entry point `ott_cuda_batch_dev`: runs an OTT sweep through `CudaOtt`
/// and returns a device-array handle.
///
/// Every combination is validated on the host first (period within
/// `1..=len`, percent finite and non-negative). The CUDA call itself runs with
/// the GIL released. Failures, including an unavailable CUDA runtime, are raised
/// as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ott_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, percent_range, ma_types, device_id=0))]
pub fn ott_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    percent_range: (f64, f64, f64),
    ma_types: Vec<String>,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let sweep = OttBatchRange {
        period: period_range,
        percent: percent_range,
        ma_types,
    };

    let combos = expand_grid_ott(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let cols = prices.len();
    for combo in &combos {
        let period = combo.period.unwrap();
        if !(1..=cols).contains(&period) {
            let err = OttError::InvalidPeriod {
                period,
                data_len: cols,
            };
            return Err(PyValueError::new_err(err.to_string()));
        }
        let percent = combo.percent.unwrap();
        if !(percent.is_finite() && percent >= 0.0) {
            let err = OttError::InvalidPercent { percent };
            return Err(PyValueError::new_err(err.to_string()));
        }
    }

    let inner = py.allow_threads(|| {
        let cuda = CudaOtt::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        cuda.ott_batch_dev(prices, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(make_device_array_py(device_id, inner)?)
}
1796
/// Python entry point `ott_cuda_many_series_one_param_dev`: applies one OTT
/// parameter set to a 2-D `f32` array through `CudaOtt` and returns a
/// device-array handle.
///
/// `rows` and `cols` are taken from the array's first and second dimension and
/// forwarded to `ott_many_series_one_param_time_major_dev` as `(cols, rows)`.
/// NOTE(review): the names suggest rows are time steps and columns are series —
/// confirm against the CUDA wrapper.
///
/// The CUDA call runs with the GIL released. Failures, including an unavailable
/// CUDA runtime, are raised as `ValueError`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "ott_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, percent, ma_type="VAR", device_id=0))]
pub fn ott_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    percent: f64,
    ma_type: &str,
    device_id: usize,
) -> PyResult<DeviceArrayF32Py> {
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let flat = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = OttParams {
        period: Some(period),
        percent: Some(percent),
        ma_type: Some(ma_type.to_string()),
    };
    let inner = py.allow_threads(|| {
        let cuda = CudaOtt::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameters are passed by reference (`&params`).
        cuda.ott_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;
    let handle = make_device_array_py(device_id, inner)?;
    Ok(handle)
}
1828
/// wasm entry point: computes OTT for one parameter set and returns a new vector
/// with one value per input element.
///
/// The output starts out filled with NaN and is written by `ott_into_slice`.
/// Errors are converted to a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ott_js(
    data: &[f64],
    period: usize,
    percent: f64,
    ma_type: &str,
) -> Result<Vec<f64>, JsValue> {
    let input = OttInput::from_slice(
        data,
        OttParams {
            period: Some(period),
            percent: Some(percent),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::Scalar
    } else {
        detect_best_kernel()
    };

    let mut out = vec![f64::NAN; data.len()];
    match ott_into_slice(&mut out, &input, kernel) {
        Ok(_) => Ok(out),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1854
/// wasm helper: reserves room for `len` `f64` values and returns the raw pointer.
///
/// The allocation is deliberately not freed here; release it by passing the
/// same pointer and `len` to `ott_free`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ott_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the buffer alive after this function returns.
    let mut buf = core::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buf.as_mut_ptr()
}
1863
/// wasm helper: releases a buffer obtained from `ott_alloc`.
///
/// A null `ptr` is ignored. Otherwise `ptr` and `len` must be exactly the
/// pointer returned by `ott_alloc` and the length passed to it, and the buffer
/// must not be used or freed again afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ott_free(ptr: *mut f64, len: usize) {
    // `Vec::from_raw_parts` requires a non-null pointer, so a null argument
    // must never reach it.
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from a `Vec<f64>` with capacity
    // `len`. The vector is rebuilt with length 0 so no element is assumed to be
    // initialised; only the allocation is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1871
/// wasm entry point: computes OTT for `len` values at `in_ptr` and writes `len`
/// values to `out_ptr`.
///
/// When both pointers are equal the result is computed into a temporary buffer
/// and copied back afterwards. Null pointers and computation errors are
/// reported as a string `JsValue`.
///
/// The caller must pass pointers that are valid for `len` `f64` values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ott_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
    percent: f64,
    ma_type: &str,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ott_into"));
    }

    // SAFETY: `in_ptr` is non-null; validity for `len` reads is the caller's
    // responsibility.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let input = OttInput::from_slice(
        data,
        OttParams {
            period: Some(period),
            percent: Some(percent),
            ma_type: Some(ma_type.to_string()),
        },
    );

    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::Scalar
    } else {
        detect_best_kernel()
    };

    if in_ptr == out_ptr {
        // In-place call: compute into scratch space first, then copy back once
        // the input is no longer being read.
        let mut scratch = vec![0.0; len];
        ott_into_slice(&mut scratch, &input, kernel)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is non-null; validity for `len` writes is the
        // caller's responsibility.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: as above.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        ott_into_slice(out, &input, kernel).map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(())
}
1916
/// wasm object that stores one OTT parameter set and the kernel to use, so the
/// same configuration can be applied to several buffers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[deprecated(
    since = "1.0.0",
    note = "For reuse, prefer fast/unsafe API with persistent buffers"
)]
pub struct OttContext {
    period: usize,
    percent: f64,
    ma_type: String,
    kernel: Kernel,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
#[allow(deprecated)]
impl OttContext {
    /// Validates the parameters and stores them together with the kernel.
    ///
    /// Fails with a string `JsValue` when `period` is zero or `percent` is
    /// negative or not finite.
    #[wasm_bindgen(constructor)]
    pub fn new(period: usize, percent: f64, ma_type: &str) -> Result<OttContext, JsValue> {
        if period == 0 {
            return Err(JsValue::from_str("Invalid period: 0"));
        }
        if !percent.is_finite() || percent < 0.0 {
            return Err(JsValue::from_str("Invalid percent"));
        }
        Ok(OttContext {
            period,
            percent,
            ma_type: ma_type.to_string(),
            kernel: if cfg!(target_arch = "wasm32") {
                Kernel::Scalar
            } else {
                detect_best_kernel()
            },
        })
    }

    /// Computes OTT for `len` values at `in_ptr` and writes `len` values to
    /// `out_ptr`, using the stored parameters.
    ///
    /// When both pointers are equal the result is computed into a temporary
    /// buffer and copied back, as `ott_into` does, so that a shared and a
    /// mutable slice never cover the same memory at once. Null pointers and
    /// computation errors are reported as a string `JsValue`.
    ///
    /// The caller must pass pointers that are valid for `len` `f64` values.
    pub fn update_into(
        &self,
        in_ptr: *const f64,
        out_ptr: *mut f64,
        len: usize,
    ) -> Result<(), JsValue> {
        if in_ptr.is_null() || out_ptr.is_null() {
            return Err(JsValue::from_str("null pointer"));
        }
        let params = OttParams {
            period: Some(self.period),
            percent: Some(self.percent),
            ma_type: Some(self.ma_type.clone()),
        };
        // SAFETY: both pointers are non-null; validity for `len` elements is the
        // caller's responsibility. In the aliasing branch the mutable slice is
        // created only after the last read through `input`.
        unsafe {
            let data = std::slice::from_raw_parts(in_ptr, len);
            let input = OttInput::from_slice(data, params);
            if in_ptr == out_ptr {
                let mut temp = vec![0.0; len];
                ott_into_slice(&mut temp, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))?;
                let out = std::slice::from_raw_parts_mut(out_ptr, len);
                out.copy_from_slice(&temp);
                Ok(())
            } else {
                let out = std::slice::from_raw_parts_mut(out_ptr, len);
                ott_into_slice(out, &input, self.kernel)
                    .map_err(|e| JsValue::from_str(&e.to_string()))
            }
        }
    }

    /// Returns `period - 1`, saturating at zero.
    pub fn get_warmup_period(&self) -> usize {
        self.period.saturating_sub(1)
    }
}
1980
/// Sweep description accepted by the wasm `ott_batch` function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct OttBatchConfig {
    /// `(start, end, step)` for the period axis.
    pub period_range: (usize, usize, usize),
    /// `(start, end, step)` for the percent axis.
    pub percent_range: (f64, f64, f64),
    /// Moving-average type names to sweep over.
    pub ma_types: Vec<String>,
}

/// Result object returned by the wasm `ott_batch` function.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct OttBatchJsOutput {
    /// Flat output values, copied from `OttBatchOutput::values`.
    pub values: Vec<f64>,
    /// The parameter combination behind each row.
    pub combos: Vec<OttParams>,
    /// Number of rows, copied from `OttBatchOutput::rows`.
    pub rows: usize,
    /// Number of columns, copied from `OttBatchOutput::cols`.
    pub cols: usize,
}

/// wasm entry point `ott_batch`: deserialises an `OttBatchConfig`, runs the
/// sweep through `ott_batch_with_kernel` and serialises an `OttBatchJsOutput`.
///
/// Every failure is reported as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = ott_batch)]
pub fn ott_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let OttBatchConfig {
        period_range,
        percent_range,
        ma_types,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;
    let sweep = OttBatchRange {
        period: period_range,
        percent: percent_range,
        ma_types,
    };

    let kernel = if cfg!(target_arch = "wasm32") {
        Kernel::ScalarBatch
    } else {
        detect_best_batch_kernel()
    };

    let batch = match ott_batch_with_kernel(data, &sweep, kernel) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = OttBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2025
/// wasm entry point: runs an OTT parameter sweep over `len` values at `in_ptr`
/// and writes one row of `len` values per combination to `out_ptr`.
///
/// Returns the number of rows written. The computation is delegated to
/// `ott_batch_inner_into`, so this path shares its validation (empty or all-NaN
/// input, period and percent checks), its upper-casing of MA type names and its
/// per-`(period, MA type)` moving-average cache.
///
/// Every failure is reported as a string `JsValue`.
///
/// The caller must pass an `in_ptr` valid for `len` values and an `out_ptr`
/// valid for `rows * len` values, where `rows` is the number of combinations the
/// sweep expands to. The two regions must not overlap.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ott_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    p_start: usize,
    p_end: usize,
    p_step: usize,
    q_start: f64,
    q_end: f64,
    q_step: f64,
    ma_types: JsValue,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to ott_batch_into"));
    }
    let types: Vec<String> = serde_wasm_bindgen::from_value(ma_types)
        .map_err(|e| JsValue::from_str(&format!("Invalid ma_types: {}", e)))?;

    let sweep = OttBatchRange {
        period: (p_start, p_end, p_step),
        percent: (q_start, q_end, q_step),
        ma_types: types,
    };

    // Expanded here only to size the output; `expand_grid_ott` never returns an
    // empty list on success.
    let rows = expand_grid_ott(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows * cols overflow"))?;

    let row_kern = match detect_best_batch_kernel() {
        Kernel::Avx512Batch => Kernel::Avx512,
        Kernel::Avx2Batch => Kernel::Avx2,
        _ => Kernel::Scalar,
    };

    // SAFETY: both pointers are non-null; validity for `len` reads and `total`
    // writes, and the absence of overlap, are the caller's responsibility.
    let (data, out) = unsafe {
        (
            std::slice::from_raw_parts(in_ptr, len),
            std::slice::from_raw_parts_mut(out_ptr, total),
        )
    };

    ott_batch_inner_into(data, &sweep, row_kern, false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
2107
2108#[cfg(test)]
2109mod tests {
2110 use super::*;
2111 use crate::skip_if_unsupported;
2112 use crate::utilities::data_loader::read_candles_from_csv;
2113 #[cfg(feature = "proptest")]
2114 use proptest::prelude::*;
2115 use std::error::Error;
2116
2117 fn check_ott_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2118 skip_if_unsupported!(kernel, test_name);
2119 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2120 let candles = read_candles_from_csv(file_path)?;
2121
2122 let input = OttInput::from_candles(&candles, "close", OttParams::default());
2123 let result = ott_with_kernel(&input, kernel)?;
2124
2125 let expected_last_five = [
2126 59719.89457348,
2127 59719.89457348,
2128 59719.89457348,
2129 59719.89457348,
2130 59649.80599569,
2131 ];
2132
2133 let start = result.values.len().saturating_sub(5);
2134 for (i, &val) in result.values[start..].iter().enumerate() {
2135 let diff = (val - expected_last_five[i]).abs();
2136 assert!(
2137 diff < 1e-6,
2138 "[{}] OTT {:?} mismatch at idx {}: got {}, expected {}",
2139 test_name,
2140 kernel,
2141 i,
2142 val,
2143 expected_last_five[i]
2144 );
2145 }
2146 Ok(())
2147 }
2148
2149 fn check_ott_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2150 skip_if_unsupported!(kernel, test_name);
2151 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2152 let candles = read_candles_from_csv(file_path)?;
2153
2154 let default_params = OttParams {
2155 period: None,
2156 percent: None,
2157 ma_type: None,
2158 };
2159 let input = OttInput::from_candles(&candles, "close", default_params);
2160 let output = ott_with_kernel(&input, kernel)?;
2161 assert_eq!(output.values.len(), candles.close.len());
2162
2163 Ok(())
2164 }
2165
2166 fn check_ott_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2167 skip_if_unsupported!(kernel, test_name);
2168 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2169 let candles = read_candles_from_csv(file_path)?;
2170
2171 let input = OttInput::with_default_candles(&candles);
2172 match input.data {
2173 OttData::Candles { source, .. } => assert_eq!(source, "close"),
2174 _ => panic!("Expected OttData::Candles"),
2175 }
2176 let output = ott_with_kernel(&input, kernel)?;
2177 assert_eq!(output.values.len(), candles.close.len());
2178
2179 Ok(())
2180 }
2181
2182 fn check_ott_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2183 skip_if_unsupported!(kernel, test_name);
2184 let input_data = [10.0, 20.0, 30.0];
2185 let params = OttParams {
2186 period: Some(0),
2187 percent: None,
2188 ma_type: None,
2189 };
2190 let input = OttInput::from_slice(&input_data, params);
2191 let res = ott_with_kernel(&input, kernel);
2192 assert!(
2193 res.is_err(),
2194 "[{}] OTT should fail with zero period",
2195 test_name
2196 );
2197 Ok(())
2198 }
2199
2200 fn check_ott_period_exceeds_length(
2201 test_name: &str,
2202 kernel: Kernel,
2203 ) -> Result<(), Box<dyn Error>> {
2204 skip_if_unsupported!(kernel, test_name);
2205 let data_small = [10.0, 20.0, 30.0];
2206 let params = OttParams {
2207 period: Some(10),
2208 percent: None,
2209 ma_type: None,
2210 };
2211 let input = OttInput::from_slice(&data_small, params);
2212 let res = ott_with_kernel(&input, kernel);
2213 assert!(
2214 res.is_err(),
2215 "[{}] OTT should fail with period exceeding length",
2216 test_name
2217 );
2218 Ok(())
2219 }
2220
2221 fn check_ott_very_small_dataset(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2222 skip_if_unsupported!(kernel, test_name);
2223 let single_point = [42.0];
2224 let params = OttParams::default();
2225 let input = OttInput::from_slice(&single_point, params);
2226 let res = ott_with_kernel(&input, kernel);
2227 assert!(
2228 res.is_err(),
2229 "[{}] OTT should fail with insufficient data",
2230 test_name
2231 );
2232 Ok(())
2233 }
2234
2235 fn check_ott_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2236 skip_if_unsupported!(kernel, test_name);
2237 let empty: [f64; 0] = [];
2238 let params = OttParams::default();
2239 let input = OttInput::from_slice(&empty, params);
2240 let res = ott_with_kernel(&input, kernel);
2241 assert!(
2242 res.is_err(),
2243 "[{}] OTT should fail with empty input",
2244 test_name
2245 );
2246 Ok(())
2247 }
2248
2249 fn check_ott_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2250 skip_if_unsupported!(kernel, test_name);
2251 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
2252 let params = OttParams::default();
2253 let input = OttInput::from_slice(&nan_data, params);
2254 let res = ott_with_kernel(&input, kernel);
2255 assert!(
2256 res.is_err(),
2257 "[{}] OTT should fail with all NaN values",
2258 test_name
2259 );
2260 Ok(())
2261 }
2262
#[cfg(feature = "proptest")]
proptest! {
    // Arbitrary data and periods in 1..100: the result is ignored, the call
    // just has to return.
    #[test]
    fn test_ott_no_panic(data: Vec<f64>, period in 1usize..100) {
        let input = OttInput::from_slice(
            &data,
            OttParams {
                period: Some(period),
                percent: Some(1.4),
                ma_type: Some("VAR".to_string()),
            },
        );
        let _ = ott(&input);
    }

    // For a ramp 0, 1, 2, ... of `size` values, a successful run must return
    // exactly `size` values.
    #[test]
    fn test_ott_length_preservation(size in 10usize..100) {
        let data: Vec<f64> = (0..size).map(|i| i as f64).collect();
        let input = OttInput::from_slice(&data, OttParams::default());

        if let Ok(output) = ott(&input) {
            prop_assert_eq!(output.values.len(), size);
        }
    }
}
2287
2288 fn check_ott_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2289 skip_if_unsupported!(kernel, test_name);
2290 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2291 let candles = read_candles_from_csv(file_path)?;
2292
2293 let input = OttInput::from_candles(&candles, "close", OttParams::default());
2294 let first_result = ott_with_kernel(&input, kernel)?;
2295
2296 let input2 = OttInput::from_slice(&first_result.values, OttParams::default());
2297 let second_result = ott_with_kernel(&input2, kernel)?;
2298
2299 assert_eq!(
2300 second_result.values.len(),
2301 first_result.values.len(),
2302 "[{}] OTT reinput length mismatch",
2303 test_name
2304 );
2305 Ok(())
2306 }
2307
2308 fn check_ott_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2309 skip_if_unsupported!(kernel, test_name);
2310 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2311 let candles = read_candles_from_csv(file_path)?;
2312
2313 let input = OttInput::from_candles(&candles, "close", OttParams::default());
2314 let result = ott_with_kernel(&input, kernel)?;
2315
2316 assert_eq!(
2317 result.values.len(),
2318 candles.close.len(),
2319 "[{}] OTT length mismatch",
2320 test_name
2321 );
2322
2323 let first_valid = result
2324 .values
2325 .iter()
2326 .position(|x| !x.is_nan())
2327 .unwrap_or(result.values.len());
2328
2329 if result.values.len() > first_valid + 100 {
2330 for i in (first_valid + 100)..result.values.len() {
2331 if candles.close[i].is_nan() {
2332 continue;
2333 }
2334 assert!(
2335 !result.values[i].is_nan(),
2336 "[{}] Unexpected NaN at index {} after warmup",
2337 test_name,
2338 i
2339 );
2340 }
2341 }
2342
2343 assert!(
2344 first_valid <= candles.close.len(),
2345 "[{}] First valid index out of range",
2346 test_name
2347 );
2348 Ok(())
2349 }
2350
2351 fn check_ott_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2352 skip_if_unsupported!(kernel, test_name);
2353 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2354 let candles = read_candles_from_csv(file_path)?;
2355 let close = &candles.close;
2356
2357 let input = OttInput::from_candles(&candles, "close", OttParams::default());
2358 let batch_result = ott_with_kernel(&input, kernel)?;
2359
2360 let mut stream = OttStream::try_new(OttParams::default())?;
2361 let mut stream_values = Vec::new();
2362
2363 for &price in close {
2364 let result = stream.update(price);
2365 stream_values.push(result.unwrap_or(f64::NAN));
2366 }
2367
2368 assert_eq!(
2369 batch_result.values.len(),
2370 stream_values.len(),
2371 "[{}] OTT streaming length mismatch",
2372 test_name
2373 );
2374
2375 let warmup = OttParams::default().period.unwrap_or(2);
2376 if batch_result.values.len() > warmup + 10 {
2377 for i in (warmup + 10)..batch_result.values.len() {
2378 if batch_result.values[i].is_nan() || stream_values[i].is_nan() {
2379 continue;
2380 }
2381
2382 let diff = (batch_result.values[i] - stream_values[i]).abs();
2383 let tolerance = batch_result.values[i].abs() * 0.05;
2384 assert!(
2385 diff <= tolerance.max(1.0),
2386 "[{}] OTT streaming mismatch at index {}: batch={}, stream={}, diff={}",
2387 test_name,
2388 i,
2389 batch_result.values[i],
2390 stream_values[i],
2391 diff
2392 );
2393 }
2394 }
2395 Ok(())
2396 }
2397
2398 #[cfg(debug_assertions)]
2399 fn check_ott_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2400 skip_if_unsupported!(kernel, test_name);
2401 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2402 let c = read_candles_from_csv(file)?;
2403 let out = OttBuilder::new().kernel(kernel).apply(&c)?;
2404 for &v in &out.values {
2405 if v.is_nan() {
2406 continue;
2407 }
2408 let b = v.to_bits();
2409 assert_ne!(
2410 b, 0x11111111_11111111,
2411 "alloc_with_nan_prefix poison leaked"
2412 );
2413 assert_ne!(b, 0x22222222_22222222, "init_matrix_prefixes poison leaked");
2414 assert_ne!(b, 0x33333333_33333333, "make_uninit_matrix poison leaked");
2415 }
2416 Ok(())
2417 }
2418
2419 #[cfg(not(debug_assertions))]
2420 fn check_ott_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
2421 Ok(())
2422 }
2423
// Generates one `#[test]` per check function and kernel flavour.
//
// For every `check_*` identifier this emits
//   * `<name>_scalar_f64` on every target,
//   * `<name>_avx2_f64` and `<name>_avx512_f64` with the `nightly-avx` feature
//     on x86_64,
//   * `<name>_simd128_f64` (run with `Kernel::Scalar`) on wasm32 with simd128.
//
// Each `#[cfg]` sits on the individual function inside the repetition. An
// attribute written in front of a `$( … )*` group attaches only to the first
// item the group expands to, which would leave all other functions ungated.
//
// The generated tests return the check's `Result`, so an `Err` from a check
// (for example an unreadable data file) fails the test.
macro_rules! generate_all_ott_tests {
    ($($test_fn:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }

                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
                #[test]
                fn [<$test_fn _simd128_f64>]() -> Result<(), Box<dyn Error>> {
                    $test_fn(stringify!([<$test_fn _simd128_f64>]), Kernel::Scalar)
                }
            )*
        }
    }
}
2454
2455 fn check_ott_invalid_percent(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2456 skip_if_unsupported!(kernel, test_name);
2457 let data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
2458 for bad in [-1.0, f64::NAN, f64::INFINITY] {
2459 let params = OttParams {
2460 period: Some(2),
2461 percent: Some(bad),
2462 ma_type: Some("VAR".to_string()),
2463 };
2464 let input = OttInput::from_slice(&data, params);
2465 let res = ott_with_kernel(&input, kernel);
2466 assert!(matches!(res, Err(OttError::InvalidPercent { .. })));
2467 }
2468 Ok(())
2469 }
2470
    // Instantiate the per-kernel `#[test]` wrappers for every single-series OTT
    // check via `generate_all_ott_tests!`. Each identifier below must name a
    // check function taking `(&str, Kernel)`.
    generate_all_ott_tests!(
        check_ott_partial_params,
        check_ott_accuracy,
        check_ott_default_candles,
        check_ott_zero_period,
        check_ott_period_exceeds_length,
        check_ott_very_small_dataset,
        check_ott_empty_input,
        check_ott_all_nan,
        check_ott_reinput,
        check_ott_nan_handling,
        check_ott_streaming,
        check_ott_no_poison,
        check_ott_invalid_percent
    );
2486
2487 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2488 skip_if_unsupported!(kernel, test);
2489 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2490 let c = read_candles_from_csv(file)?;
2491
2492 let out = OttBatchBuilder::new()
2493 .kernel(kernel)
2494 .apply_candles(&c, "close")?;
2495 let def = OttParams::default();
2496 let row = out.values_for(&def).expect("default row missing");
2497 assert_eq!(row.len(), c.close.len());
2498
2499 let single_kernel = match kernel {
2500 Kernel::ScalarBatch => Kernel::Scalar,
2501 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2502 Kernel::Avx2Batch => Kernel::Avx2,
2503 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2504 Kernel::Avx512Batch => Kernel::Avx512,
2505 Kernel::Auto => Kernel::Scalar,
2506 _ => Kernel::Scalar,
2507 };
2508 let single =
2509 ott_with_kernel(&OttInput::from_slice(&c.close, def.clone()), single_kernel)?.values;
2510
2511 assert_eq!(single.len(), row.len());
2512 for i in 0..row.len() {
2513 if row[i].is_nan() || single[i].is_nan() {
2514 continue;
2515 }
2516 assert!(
2517 (row[i] - single[i]).abs() <= 1e-9,
2518 "[{test}] mismatch at {i}"
2519 );
2520 }
2521 Ok(())
2522 }
2523
2524 fn check_batch_sweep(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2525 skip_if_unsupported!(kernel, test);
2526 let data = vec![1.0; 100];
2527
2528 let out = OttBatchBuilder::new()
2529 .kernel(kernel)
2530 .period_range(10, 20, 10)
2531 .percent_range(1.0, 2.0, 1.0)
2532 .ma_types(vec!["VAR".to_string(), "WWMA".to_string()])
2533 .apply_slice(&data)?;
2534
2535 assert_eq!(
2536 out.rows, 8,
2537 "[{}] Expected 8 rows for parameter sweep",
2538 test
2539 );
2540 assert_eq!(out.cols, 100, "[{}] Column count mismatch", test);
2541
2542 assert_eq!(out.combos.len(), 8, "[{}] Combos count mismatch", test);
2543 Ok(())
2544 }
2545
2546 #[cfg(debug_assertions)]
2547 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
2548 skip_if_unsupported!(kernel, test);
2549 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2550 let c = read_candles_from_csv(file)?;
2551
2552 let out = OttBatchBuilder::new()
2553 .kernel(kernel)
2554 .period_range(40, 60, 20)
2555 .percent_range(1.0, 2.0, 1.0)
2556 .ma_types(vec!["VAR".to_string()])
2557 .apply_slice(&c.close)?;
2558
2559 for &v in &out.values {
2560 if v.is_nan() {
2561 continue;
2562 }
2563 let b = v.to_bits();
2564 assert_ne!(
2565 b, 0x11111111_11111111,
2566 "[{}] alloc_with_nan_prefix poison",
2567 test
2568 );
2569 assert_ne!(
2570 b, 0x22222222_22222222,
2571 "[{}] init_matrix_prefixes poison",
2572 test
2573 );
2574 assert_ne!(
2575 b, 0x33333333_33333333,
2576 "[{}] make_uninit_matrix poison",
2577 test
2578 );
2579 }
2580 Ok(())
2581 }
2582
    /// Counterpart of the batch poison check for builds compiled without
    /// `debug_assertions`: performs no work and always returns `Ok(())`, so the
    /// name `check_batch_no_poison` exists in both build configurations.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2587
2588 macro_rules! gen_batch_tests {
2589 ($fn_name:ident) => {
2590 paste::paste! {
2591 #[test] fn [<$fn_name _scalar>]() {
2592 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2593 }
2594 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2595 #[test] fn [<$fn_name _avx2>]() {
2596 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2597 }
2598 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2599 #[test] fn [<$fn_name _avx512>]() {
2600 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2601 }
2602 #[test] fn [<$fn_name _auto_detect>]() {
2603 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2604 }
2605 }
2606 };
2607 }
2608
    // Instantiate the per-kernel batch tests (scalar, auto-detect, and the
    // AVX2/AVX-512 variants when `nightly-avx` is enabled on x86_64) for each
    // batch check.
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_no_poison);
2612
2613 fn check_batch_helpers_and_row_lookup(
2614 _test: &str,
2615 kernel: Kernel,
2616 ) -> Result<(), Box<dyn Error>> {
2617 let data = vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0];
2618
2619 let out1 = OttBatchBuilder::with_default_slice(&data, kernel)?;
2620 assert_eq!(out1.rows, 1);
2621 assert_eq!(out1.cols, data.len());
2622
2623 let builder = OttBatchBuilder::new()
2624 .kernel(kernel)
2625 .period_static(3)
2626 .percent_static(1.4)
2627 .ma_types(vec!["VAR".to_string()]);
2628
2629 let out2 = builder.apply_batch(&data)?;
2630 assert_eq!(out2.rows, 1);
2631 assert_eq!(out2.cols, data.len());
2632
2633 let params = OttParams {
2634 period: Some(3),
2635 percent: Some(1.4),
2636 ma_type: Some("VAR".to_string()),
2637 };
2638
2639 let row_idx = out2.row_for_params(¶ms);
2640 assert_eq!(row_idx, Some(0));
2641
2642 let values = out2.values_for(¶ms);
2643 assert!(values.is_some());
2644 assert_eq!(values.unwrap().len(), data.len());
2645
2646 let default_params = OttParams::default();
2647 let default_row_idx = out1.row_for_params(&default_params);
2648 assert_eq!(default_row_idx, Some(0));
2649
2650 let invalid_params = OttParams {
2651 period: Some(999),
2652 percent: Some(999.0),
2653 ma_type: Some("INVALID".to_string()),
2654 };
2655 assert_eq!(out2.row_for_params(&invalid_params), None);
2656 assert_eq!(out2.values_for(&invalid_params), None);
2657
2658 Ok(())
2659 }
2660
    // Instantiate the per-kernel batch tests for the helper / row-lookup check.
    gen_batch_tests!(check_batch_helpers_and_row_lookup);
2662}
2663
/// Parity test between the allocating `ott()` API and the variant that writes
/// into a caller-provided buffer.
///
/// A 256-point synthetic series (a sine wave plus a linear drift) is run through
/// both paths. Outside the wasm32 + `wasm` configuration the buffer path is
/// `ott_into`; under that configuration it is `ott_into_slice` with
/// `Kernel::Scalar`. Every position must be NaN in both outputs, exactly equal,
/// or within `1e-12` of each other.
#[cfg(test)]
#[test]
fn test_ott_into_matches_api() {
    const SAMPLES: usize = 256;

    let data: Vec<f64> = (0..SAMPLES)
        .map(|i| {
            let x = i as f64;
            (x * 0.1).sin() * 100.0 + x * 0.05
        })
        .collect();

    let input = OttInput::from_slice(&data, OttParams::default());
    let baseline = ott(&input).expect("baseline ott() failed");

    // Destination buffer for the write-into-slice API, same length as the input.
    let mut into_out = vec![0.0; data.len()];

    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
    {
        ott_into(&input, &mut into_out).expect("ott_into failed");
    }
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    {
        ott_into_slice(&mut into_out, &input, Kernel::Scalar).expect("ott_into_slice failed");
    }

    assert_eq!(baseline.values.len(), into_out.len());

    // Two values agree when both are NaN, when they compare equal, or when
    // they differ by at most 1e-12.
    fn values_agree(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            return true;
        }
        a == b || (a - b).abs() <= 1e-12
    }

    let paired = baseline.values.iter().zip(into_out.iter());
    for (i, (&a, &b)) in paired.enumerate() {
        assert!(
            values_agree(a, b),
            "parity mismatch at {}: vec_api={} vs into_api={}",
            i,
            a,
            b
        );
    }
}