1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24#[cfg(test)]
25use std::error::Error as StdError;
26use std::mem::ManuallyDrop;
27use thiserror::Error;
28
/// Default look-back (in bars) of the high/low range window.
const DEFAULT_DATA_LENGTH: usize = 10;
/// Default number of samples used for the log-range mean / stdev.
const DEFAULT_NORMALIZATION_LENGTH: usize = 100;
/// Default base level; `levelN` selects `mean + N * stdev` in log space.
const DEFAULT_BASE_LEVEL: &str = "level2";

/// Smallest accepted `data_length`.
const MIN_DATA_LENGTH: usize = 1;
/// Smallest accepted `normalization_length`.
const MIN_NORMALIZATION_LENGTH: usize = 10;

/// Encoded bias while the stop sits above price (`hlc3 + delta`).
const BIAS_BEARISH: u8 = 0;
/// Encoded bias while the stop sits below price (`hlc3 - delta`).
const BIAS_BULLISH: u8 = 1;
36
/// Typical price: the arithmetic mean of high, low and close.
#[inline(always)]
fn hlc3(high: f64, low: f64, close: f64) -> f64 {
    let total = high + low + close;
    total / 3.0
}

/// Returns `value` when it is strictly positive, otherwise `f64::MIN_POSITIVE`.
///
/// Used before taking a logarithm: zero, negative values and NaN (which fails
/// the `> 0.0` test) are all mapped to the smallest positive normal `f64`.
#[inline(always)]
fn floor_positive(value: f64) -> f64 {
    Some(value)
        .filter(|candidate| *candidate > 0.0)
        .unwrap_or(f64::MIN_POSITIVE)
}
50
51#[inline(always)]
52fn close_source(candles: &Candles) -> &[f64] {
53 &candles.close
54}
55
56#[inline(always)]
57fn high_source(candles: &Candles) -> &[f64] {
58 &candles.high
59}
60
61#[inline(always)]
62fn low_source(candles: &Candles) -> &[f64] {
63 &candles.low
64}
65
66#[derive(Debug, Clone)]
67pub enum StatisticalTrailingStopData<'a> {
68 Candles {
69 candles: &'a Candles,
70 },
71 Slices {
72 high: &'a [f64],
73 low: &'a [f64],
74 close: &'a [f64],
75 },
76}
77
/// Per-bar results of a single statistical-trailing-stop run.
///
/// Each vector has one entry per input bar; bars without a value (warm-up or
/// the warm-up after a non-finite bar) hold NaN.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct StatisticalTrailingStopOutput {
    /// Trailing stop level.
    pub level: Vec<f64>,
    /// Close recorded at the most recent bias flip.
    pub anchor: Vec<f64>,
    /// `0.0` while bearish, `1.0` while bullish.
    pub bias: Vec<f64>,
    /// `1.0` on bars where the bias flipped, `0.0` otherwise.
    pub changed: Vec<f64>,
}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
91#[cfg_attr(
92 all(target_arch = "wasm32", feature = "wasm"),
93 derive(Serialize, Deserialize)
94)]
95pub struct StatisticalTrailingStopParams {
96 pub data_length: Option<usize>,
97 pub normalization_length: Option<usize>,
98 pub base_level: Option<String>,
99}
100
101impl Default for StatisticalTrailingStopParams {
102 fn default() -> Self {
103 Self {
104 data_length: Some(DEFAULT_DATA_LENGTH),
105 normalization_length: Some(DEFAULT_NORMALIZATION_LENGTH),
106 base_level: Some(DEFAULT_BASE_LEVEL.to_string()),
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
112pub struct StatisticalTrailingStopInput<'a> {
113 pub data: StatisticalTrailingStopData<'a>,
114 pub params: StatisticalTrailingStopParams,
115}
116
117impl<'a> StatisticalTrailingStopInput<'a> {
118 #[inline(always)]
119 pub fn from_candles(candles: &'a Candles, params: StatisticalTrailingStopParams) -> Self {
120 Self {
121 data: StatisticalTrailingStopData::Candles { candles },
122 params,
123 }
124 }
125
126 #[inline(always)]
127 pub fn from_slices(
128 high: &'a [f64],
129 low: &'a [f64],
130 close: &'a [f64],
131 params: StatisticalTrailingStopParams,
132 ) -> Self {
133 Self {
134 data: StatisticalTrailingStopData::Slices { high, low, close },
135 params,
136 }
137 }
138
139 #[inline(always)]
140 pub fn with_default_candles(candles: &'a Candles) -> Self {
141 Self::from_candles(candles, StatisticalTrailingStopParams::default())
142 }
143
144 #[inline(always)]
145 pub fn get_data_length(&self) -> usize {
146 self.params.data_length.unwrap_or(DEFAULT_DATA_LENGTH)
147 }
148
149 #[inline(always)]
150 pub fn get_normalization_length(&self) -> usize {
151 self.params
152 .normalization_length
153 .unwrap_or(DEFAULT_NORMALIZATION_LENGTH)
154 }
155
156 #[inline(always)]
157 pub fn get_base_level(&self) -> &str {
158 self.params
159 .base_level
160 .as_deref()
161 .unwrap_or(DEFAULT_BASE_LEVEL)
162 }
163
164 #[inline(always)]
165 fn as_hlc(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
166 match &self.data {
167 StatisticalTrailingStopData::Candles { candles } => (
168 high_source(candles),
169 low_source(candles),
170 close_source(candles),
171 ),
172 StatisticalTrailingStopData::Slices { high, low, close } => (*high, *low, *close),
173 }
174 }
175}
176
177impl<'a> AsRef<[f64]> for StatisticalTrailingStopInput<'a> {
178 #[inline(always)]
179 fn as_ref(&self) -> &[f64] {
180 self.as_hlc().2
181 }
182}
183
184#[derive(Clone, Debug)]
185pub struct StatisticalTrailingStopBuilder {
186 data_length: Option<usize>,
187 normalization_length: Option<usize>,
188 base_level: Option<String>,
189 kernel: Kernel,
190}
191
192impl Default for StatisticalTrailingStopBuilder {
193 fn default() -> Self {
194 Self {
195 data_length: None,
196 normalization_length: None,
197 base_level: None,
198 kernel: Kernel::Auto,
199 }
200 }
201}
202
203impl StatisticalTrailingStopBuilder {
204 #[inline(always)]
205 pub fn new() -> Self {
206 Self::default()
207 }
208
209 #[inline(always)]
210 pub fn data_length(mut self, value: usize) -> Self {
211 self.data_length = Some(value);
212 self
213 }
214
215 #[inline(always)]
216 pub fn normalization_length(mut self, value: usize) -> Self {
217 self.normalization_length = Some(value);
218 self
219 }
220
221 #[inline(always)]
222 pub fn base_level<S: Into<String>>(mut self, value: S) -> Self {
223 self.base_level = Some(value.into());
224 self
225 }
226
227 #[inline(always)]
228 pub fn kernel(mut self, kernel: Kernel) -> Self {
229 self.kernel = kernel;
230 self
231 }
232
233 #[inline(always)]
234 fn params(self) -> StatisticalTrailingStopParams {
235 StatisticalTrailingStopParams {
236 data_length: self.data_length,
237 normalization_length: self.normalization_length,
238 base_level: self.base_level,
239 }
240 }
241
242 #[inline(always)]
243 pub fn apply(
244 self,
245 candles: &Candles,
246 ) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
247 let kernel = self.kernel;
248 let params = self.params();
249 statistical_trailing_stop_with_kernel(
250 &StatisticalTrailingStopInput::from_candles(candles, params),
251 kernel,
252 )
253 }
254
255 #[inline(always)]
256 pub fn apply_slices(
257 self,
258 high: &[f64],
259 low: &[f64],
260 close: &[f64],
261 ) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
262 let kernel = self.kernel;
263 let params = self.params();
264 statistical_trailing_stop_with_kernel(
265 &StatisticalTrailingStopInput::from_slices(high, low, close, params),
266 kernel,
267 )
268 }
269
270 #[inline(always)]
271 pub fn into_stream(
272 self,
273 ) -> Result<StatisticalTrailingStopStream, StatisticalTrailingStopError> {
274 StatisticalTrailingStopStream::try_new(self.params())
275 }
276}
277
278#[derive(Debug, Error)]
279pub enum StatisticalTrailingStopError {
280 #[error("statistical_trailing_stop: input data slice is empty.")]
281 EmptyInputData,
282 #[error("statistical_trailing_stop: all values are NaN.")]
283 AllValuesNaN,
284 #[error(
285 "statistical_trailing_stop: inconsistent data lengths - high = {high_len}, low = {low_len}, close = {close_len}"
286 )]
287 DataLengthMismatch {
288 high_len: usize,
289 low_len: usize,
290 close_len: usize,
291 },
292 #[error(
293 "statistical_trailing_stop: invalid period: data_length = {data_length}, normalization_length = {normalization_length}, data length = {data_len}"
294 )]
295 InvalidPeriod {
296 data_length: usize,
297 normalization_length: usize,
298 data_len: usize,
299 },
300 #[error(
301 "statistical_trailing_stop: invalid base level: {base_level}. expected one of level0, level1, level2, level3"
302 )]
303 InvalidBaseLevel { base_level: String },
304 #[error(
305 "statistical_trailing_stop: not enough valid data: needed = {needed}, valid = {valid}"
306 )]
307 NotEnoughValidData { needed: usize, valid: usize },
308 #[error(
309 "statistical_trailing_stop: output length mismatch: expected = {expected}, got = {got}"
310 )]
311 OutputLengthMismatch { expected: usize, got: usize },
312 #[error(
313 "statistical_trailing_stop: invalid range for {axis}: start = {start}, end = {end}, step = {step}"
314 )]
315 InvalidRange {
316 axis: &'static str,
317 start: String,
318 end: String,
319 step: String,
320 },
321 #[error("statistical_trailing_stop: invalid kernel for batch: {0:?}")]
322 InvalidKernelForBatch(Kernel),
323}
324
/// Input that passed `prepare_input` validation, with every parameter resolved.
#[derive(Copy, Clone, Debug)]
struct PreparedInput<'a> {
    /// High series (same length as `close`).
    high: &'a [f64],
    /// Low series (same length as `close`).
    low: &'a [f64],
    /// Close series.
    close: &'a [f64],
    /// Resolved range-window length.
    data_length: usize,
    /// Resolved normalization-window length.
    normalization_length: usize,
    /// Parsed base level (0..=3).
    base_level_index: usize,
    /// Index of the first bar that can carry a value.
    warmup: usize,
}
335
/// Monotonic deque of `(index, value)` pairs giving amortised O(1) sliding
/// window extremes: the front is the window maximum when `descending` is
/// `true`, and the window minimum otherwise.
///
/// Entries before `head` are logically removed from the front; they are
/// physically discarded in batches by `expire`.
#[derive(Clone, Debug)]
struct MonoDeque {
    /// Bar indices of retained entries, oldest first.
    idx: Vec<usize>,
    /// Values parallel to `idx`.
    vals: Vec<f64>,
    /// Position of the logical front in `idx` / `vals`.
    head: usize,
    /// `true` tracks the maximum, `false` the minimum.
    descending: bool,
}

impl MonoDeque {
    /// Empty deque with room for `cap` entries before reallocating.
    #[inline(always)]
    fn new(descending: bool, cap: usize) -> Self {
        let idx = Vec::with_capacity(cap);
        let vals = Vec::with_capacity(cap);
        Self {
            idx,
            vals,
            head: 0,
            descending,
        }
    }

    /// Drops every entry.
    #[inline(always)]
    fn clear(&mut self) {
        self.head = 0;
        self.vals.clear();
        self.idx.clear();
    }

    /// Appends `(index, value)`, first discarding tail entries that can no
    /// longer be the window extreme.
    #[inline(always)]
    fn push(&mut self, index: usize, value: f64) {
        while self.vals.len() > self.head {
            let tail = self.vals[self.vals.len() - 1];
            let dominated = if self.descending {
                tail <= value
            } else {
                tail >= value
            };
            if !dominated {
                break;
            }
            self.vals.pop();
            self.idx.pop();
        }
        self.idx.push(index);
        self.vals.push(value);
    }

    /// Advances the front past entries whose index is below `min_index`.
    #[inline(always)]
    fn expire(&mut self, min_index: usize) {
        let alive_offset = self.idx[self.head..]
            .iter()
            .position(|&i| i >= min_index)
            .unwrap_or(self.idx.len() - self.head);
        self.head += alive_offset;

        // Compact occasionally so the buffers do not grow without bound.
        if self.head > 64 && self.head * 2 >= self.idx.len() {
            let consumed = std::mem::take(&mut self.head);
            self.idx.drain(..consumed);
            self.vals.drain(..consumed);
        }
    }

    /// Current window extreme; panics if the deque is empty.
    #[inline(always)]
    fn front_value(&self) -> f64 {
        self.vals[self.head]
    }
}
398
/// Fixed-capacity circular buffer of recent values, read back from the newest end.
#[derive(Clone, Debug)]
struct RingHistory {
    /// Backing storage; always at least one slot.
    values: Vec<f64>,
    /// Slot the next `push` writes to.
    head: usize,
    /// Values stored since the last `clear`, capped at the capacity.
    count: usize,
}

impl RingHistory {
    /// Buffer holding up to `cap` values (a capacity of 0 is bumped to 1).
    #[inline(always)]
    fn new(cap: usize) -> Self {
        let slots = if cap == 0 { 1 } else { cap };
        Self {
            values: vec![0.0; slots],
            head: 0,
            count: 0,
        }
    }

    /// Forgets all stored values (storage is kept).
    #[inline(always)]
    fn clear(&mut self) {
        self.count = 0;
        self.head = 0;
    }

    /// Stores `value`, overwriting the oldest one once full.
    #[inline(always)]
    fn push(&mut self, value: f64) {
        let capacity = self.values.len();
        self.values[self.head] = value;
        self.head = (self.head + 1) % capacity;
        self.count = capacity.min(self.count + 1);
    }

    /// `offset == 1` is the newest value; `None` for 0 or offsets beyond what is stored.
    #[inline(always)]
    fn get_from_end(&self, offset: usize) -> Option<f64> {
        (1..=self.count).contains(&offset).then(|| {
            let capacity = self.values.len();
            self.values[(self.head + capacity - offset) % capacity]
        })
    }
}
444
/// Rolling mean and population standard deviation over the last `length`
/// samples, maintained with running sums (O(1) per update).
#[derive(Clone, Debug)]
struct RollingStats {
    /// Circular sample window (at least one slot).
    ring: Vec<f64>,
    /// Slot the next sample is written to.
    head: usize,
    /// Samples seen since the last `clear`, capped at the window size.
    count: usize,
    /// Running sum of the window.
    sum: f64,
    /// Running sum of squares of the window.
    sum_sq: f64,
}

impl RollingStats {
    /// Window of `length` samples (0 is bumped to 1).
    #[inline(always)]
    fn new(length: usize) -> Self {
        let window = length.max(1);
        Self {
            ring: vec![0.0; window],
            head: 0,
            count: 0,
            sum: 0.0,
            sum_sq: 0.0,
        }
    }

    /// Resets counters and sums; the ring contents are simply overwritten later.
    #[inline(always)]
    fn clear(&mut self) {
        self.sum = 0.0;
        self.sum_sq = 0.0;
        self.count = 0;
        self.head = 0;
    }

    /// Adds `value` and returns `(mean, stdev)` once the window is full.
    ///
    /// Variance is `E[x^2] - mean^2` clamped at 0 to absorb rounding error.
    #[inline(always)]
    fn push(&mut self, value: f64) -> Option<(f64, f64)> {
        let window = self.ring.len();
        let slot = self.head;
        if self.count >= window {
            let evicted = self.ring[slot];
            self.sum += value - evicted;
            self.sum_sq += value * value - evicted * evicted;
        } else {
            self.count += 1;
            self.sum += value;
            self.sum_sq += value * value;
        }
        self.ring[slot] = value;
        self.head = if slot + 1 == window { 0 } else { slot + 1 };

        if self.count < window {
            return None;
        }
        let n = window as f64;
        let mean = self.sum / n;
        let variance = (self.sum_sq / n - mean * mean).max(0.0);
        Some((mean, variance.sqrt()))
    }
}
504
/// Incremental state machine shared by `compute_row` and the streaming API.
///
/// Each `update` call consumes one bar. A bar with a non-finite high, low or
/// close resets the whole state, so a fresh warm-up starts after every gap.
#[derive(Clone, Debug)]
struct StatisticalTrailingStopState {
    /// Length of the high/low look-back window.
    data_length: usize,
    /// Multiplier `N` on the log-space stdev (`delta = exp(mean + N * stdev)`).
    base_level_index: usize,
    /// Consecutive finite bars seen since the last reset.
    valid_run: usize,
    /// Sliding-window maximum of `high`.
    max_high: MonoDeque,
    /// Sliding-window minimum of `low`.
    min_low: MonoDeque,
    /// Recent closes (capacity `data_length + 2`).
    close_history: RingHistory,
    /// Rolling mean/stdev of `ln(range)` over `normalization_length` samples.
    stats: RollingStats,
    /// Current direction: `BIAS_BEARISH` or `BIAS_BULLISH`.
    bias: u8,
    /// Most recent stop distance; written by `update` but not read by it.
    delta: f64,
    /// Current stop level; NaN until the first output bar after a (re)start.
    level: f64,
    /// Running extreme since the last flip; maintained but not emitted by `update`.
    extreme: f64,
    /// Close recorded at the most recent flip.
    anchor: f64,
    /// Bar index of the most recent flip; written but not read by `update`.
    index: usize,
}

impl StatisticalTrailingStopState {
    /// Fresh state; buffers are sized for `data_length + 2` bars of history.
    #[inline(always)]
    fn new(data_length: usize, normalization_length: usize, base_level_index: usize) -> Self {
        Self {
            data_length,
            base_level_index,
            valid_run: 0,
            max_high: MonoDeque::new(true, data_length + 2),
            min_low: MonoDeque::new(false, data_length + 2),
            close_history: RingHistory::new(data_length + 2),
            stats: RollingStats::new(normalization_length),
            bias: BIAS_BEARISH,
            delta: f64::NAN,
            level: f64::NAN,
            extreme: f64::NAN,
            anchor: f64::NAN,
            index: 0,
        }
    }

    /// Returns to the just-constructed state (bias back to bearish, all values NaN).
    #[inline(always)]
    fn reset(&mut self) {
        self.valid_run = 0;
        self.max_high.clear();
        self.min_low.clear();
        self.close_history.clear();
        self.stats.clear();
        self.bias = BIAS_BEARISH;
        self.delta = f64::NAN;
        self.level = f64::NAN;
        self.extreme = f64::NAN;
        self.anchor = f64::NAN;
        self.index = 0;
    }

    /// Feeds bar `index` and returns `(level, anchor, bias, changed)` once warmed up.
    ///
    /// Returns `None` for non-finite bars (after resetting) and until both
    /// `data_length + 2` consecutive valid bars and a full normalization window
    /// of range samples are available.
    #[inline(always)]
    fn update(
        &mut self,
        index: usize,
        high: f64,
        low: f64,
        close: f64,
    ) -> Option<(f64, f64, f64, f64)> {
        // Any non-finite input breaks the run: start over.
        if !high.is_finite() || !low.is_finite() || !close.is_finite() {
            self.reset();
            return None;
        }

        self.valid_run += 1;
        self.max_high.push(index, high);
        self.min_low.push(index, low);
        // Window covers the last `min(valid_run, data_length)` bars, never reaching before the reset.
        let window_start = index + 1 - self.valid_run.min(self.data_length);
        self.max_high.expire(window_start);
        self.min_low.expire(window_start);
        self.close_history.push(close);

        if self.valid_run < self.data_length + 2 {
            return None;
        }

        // Offset 1 is the current close, so this is the close `data_length + 1` bars back.
        // NOTE(review): confirm against the reference whether `data_length` bars back was intended.
        let previous_close = self.close_history.get_from_end(self.data_length + 2)?;
        let highest = self.max_high.front_value();
        let lowest = self.min_low.front_value();
        // True-range style width of the window, extended by the gap to `previous_close`.
        let tr = (highest - lowest)
            .max((highest - previous_close).abs())
            .max((lowest - previous_close).abs());
        // Statistics are taken on ln(range); a zero range is floored to avoid ln(0).
        let (mean, stdev) = self.stats.push(floor_positive(tr).ln())?;
        let delta = (mean + self.base_level_index as f64 * stdev).exp();
        self.delta = delta;

        let current_hlc3 = hlc3(high, low, close);
        // Seed the level on the first output bar (or after a reset).
        if !self.level.is_finite() {
            self.level = if self.bias == BIAS_BEARISH {
                current_hlc3 + delta
            } else {
                (current_hlc3 - delta).max(0.0)
            };
        }

        // Ratchet: a bearish stop only moves down, a bullish stop only moves up (floored at 0).
        if self.bias == BIAS_BEARISH {
            if self.extreme.is_finite() {
                self.extreme = self.extreme.min(low);
            }
            self.level = self.level.min(current_hlc3 + delta);
        } else {
            if self.extreme.is_finite() {
                self.extreme = self.extreme.max(high);
            }
            self.level = self.level.max((current_hlc3 - delta).max(0.0));
        }

        // Flip when the close crosses the stop in the adverse direction.
        let triggered = (self.bias == BIAS_BEARISH && close >= self.level)
            || (self.bias == BIAS_BULLISH && close <= self.level);
        let mut changed = 0.0;

        if triggered {
            self.anchor = close;
            self.index = index;
            self.bias = if self.bias == BIAS_BEARISH {
                BIAS_BULLISH
            } else {
                BIAS_BEARISH
            };
            // Re-seed the level on the opposite side of price for the new bias.
            self.level = if self.bias == BIAS_BEARISH {
                current_hlc3 + delta
            } else {
                (current_hlc3 - delta).max(0.0)
            };
            self.extreme = if self.bias == BIAS_BEARISH { low } else { high };
            changed = 1.0;
        }

        Some((self.level, self.anchor, self.bias as f64, changed))
    }
}
637
638#[derive(Clone, Debug)]
639pub struct StatisticalTrailingStopStream {
640 params: StatisticalTrailingStopParams,
641 state: StatisticalTrailingStopState,
642 index: usize,
643}
644
645impl StatisticalTrailingStopStream {
646 #[inline(always)]
647 pub fn try_new(
648 params: StatisticalTrailingStopParams,
649 ) -> Result<Self, StatisticalTrailingStopError> {
650 let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
651 let normalization_length = params
652 .normalization_length
653 .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
654 validate_periods(data_length, normalization_length, usize::MAX)?;
655 let base_level_index =
656 parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
657 Ok(Self {
658 state: StatisticalTrailingStopState::new(
659 data_length,
660 normalization_length,
661 base_level_index,
662 ),
663 params,
664 index: 0,
665 })
666 }
667
668 #[inline(always)]
669 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
670 let output = self.state.update(self.index, high, low, close);
671 self.index = self.index.saturating_add(1);
672 output
673 }
674
675 #[inline(always)]
676 pub fn params(&self) -> &StatisticalTrailingStopParams {
677 &self.params
678 }
679}
680
681#[derive(Clone, Debug)]
682pub struct StatisticalTrailingStopBatchRange {
683 pub data_length: (usize, usize, usize),
684 pub normalization_length: (usize, usize, usize),
685 pub base_level: (String, String, usize),
686}
687
688impl Default for StatisticalTrailingStopBatchRange {
689 fn default() -> Self {
690 Self {
691 data_length: (DEFAULT_DATA_LENGTH, DEFAULT_DATA_LENGTH, 0),
692 normalization_length: (
693 DEFAULT_NORMALIZATION_LENGTH,
694 DEFAULT_NORMALIZATION_LENGTH,
695 0,
696 ),
697 base_level: (
698 DEFAULT_BASE_LEVEL.to_string(),
699 DEFAULT_BASE_LEVEL.to_string(),
700 0,
701 ),
702 }
703 }
704}
705
706#[derive(Clone, Debug, Default)]
707pub struct StatisticalTrailingStopBatchBuilder {
708 range: StatisticalTrailingStopBatchRange,
709 kernel: Kernel,
710}
711
712#[derive(Clone, Debug)]
713pub struct StatisticalTrailingStopBatchOutput {
714 pub level: Vec<f64>,
715 pub anchor: Vec<f64>,
716 pub bias: Vec<f64>,
717 pub changed: Vec<f64>,
718 pub combos: Vec<StatisticalTrailingStopParams>,
719 pub rows: usize,
720 pub cols: usize,
721}
722
723impl StatisticalTrailingStopBatchBuilder {
724 #[inline(always)]
725 pub fn new() -> Self {
726 Self::default()
727 }
728
729 #[inline(always)]
730 pub fn kernel(mut self, kernel: Kernel) -> Self {
731 self.kernel = kernel;
732 self
733 }
734
735 #[inline(always)]
736 pub fn data_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
737 self.range.data_length = (start, end, step);
738 self
739 }
740
741 #[inline(always)]
742 pub fn normalization_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
743 self.range.normalization_length = (start, end, step);
744 self
745 }
746
747 #[inline(always)]
748 pub fn base_level_range<S: Into<String>>(mut self, start: S, end: S, step: usize) -> Self {
749 self.range.base_level = (start.into(), end.into(), step);
750 self
751 }
752
753 #[inline(always)]
754 pub fn apply_slices(
755 self,
756 high: &[f64],
757 low: &[f64],
758 close: &[f64],
759 ) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
760 statistical_trailing_stop_batch_with_kernel(high, low, close, &self.range, self.kernel)
761 }
762
763 #[inline(always)]
764 pub fn apply(
765 self,
766 candles: &Candles,
767 ) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
768 self.apply_slices(&candles.high, &candles.low, &candles.close)
769 }
770}
771
772#[inline(always)]
773fn validate_periods(
774 data_length: usize,
775 normalization_length: usize,
776 data_len: usize,
777) -> Result<(), StatisticalTrailingStopError> {
778 if data_length < MIN_DATA_LENGTH
779 || normalization_length < MIN_NORMALIZATION_LENGTH
780 || data_length + normalization_length + 1 > data_len
781 {
782 return Err(StatisticalTrailingStopError::InvalidPeriod {
783 data_length,
784 normalization_length,
785 data_len,
786 });
787 }
788 Ok(())
789}
790
791#[inline(always)]
792fn parse_base_level(value: &str) -> Result<usize, StatisticalTrailingStopError> {
793 let normalized = value
794 .trim()
795 .to_ascii_lowercase()
796 .replace([' ', '_', '-'], "");
797 match normalized.as_str() {
798 "level0" | "0" => Ok(0),
799 "level1" | "1" => Ok(1),
800 "level2" | "2" => Ok(2),
801 "level3" | "3" => Ok(3),
802 _ => Err(StatisticalTrailingStopError::InvalidBaseLevel {
803 base_level: value.to_string(),
804 }),
805 }
806}
807
808#[inline(always)]
809fn canonical_base_level(value: &str) -> Result<&'static str, StatisticalTrailingStopError> {
810 match parse_base_level(value)? {
811 0 => Ok("level0"),
812 1 => Ok("level1"),
813 2 => Ok("level2"),
814 _ => Ok("level3"),
815 }
816}
817
818#[inline(always)]
819fn analyze_valid_segments(
820 high: &[f64],
821 low: &[f64],
822 close: &[f64],
823) -> Result<(usize, usize), StatisticalTrailingStopError> {
824 let mut first_valid = None;
825 let mut current = 0usize;
826 let mut max_run = 0usize;
827 for i in 0..close.len() {
828 let valid = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
829 if valid {
830 if first_valid.is_none() {
831 first_valid = Some(i);
832 }
833 current += 1;
834 max_run = max_run.max(current);
835 } else {
836 current = 0;
837 }
838 }
839 let first_valid = first_valid.ok_or(StatisticalTrailingStopError::AllValuesNaN)?;
840 Ok((first_valid, max_run))
841}
842
843fn prepare_input<'a>(
844 input: &'a StatisticalTrailingStopInput<'a>,
845 _kernel: Kernel,
846) -> Result<PreparedInput<'a>, StatisticalTrailingStopError> {
847 let (high, low, close) = input.as_hlc();
848 if high.is_empty() || low.is_empty() || close.is_empty() {
849 return Err(StatisticalTrailingStopError::EmptyInputData);
850 }
851 if high.len() != low.len() || high.len() != close.len() {
852 return Err(StatisticalTrailingStopError::DataLengthMismatch {
853 high_len: high.len(),
854 low_len: low.len(),
855 close_len: close.len(),
856 });
857 }
858
859 let data_length = input.get_data_length();
860 let normalization_length = input.get_normalization_length();
861 validate_periods(data_length, normalization_length, close.len())?;
862 let base_level_index = parse_base_level(input.get_base_level())?;
863 let (first_valid, max_run) = analyze_valid_segments(high, low, close)?;
864 let needed = data_length + normalization_length + 1;
865 if max_run < needed {
866 return Err(StatisticalTrailingStopError::NotEnoughValidData {
867 needed,
868 valid: max_run,
869 });
870 }
871
872 Ok(PreparedInput {
873 high,
874 low,
875 close,
876 data_length,
877 normalization_length,
878 base_level_index,
879 warmup: first_valid + needed - 1,
880 })
881}
882
883#[inline(always)]
884fn compute_row(
885 high: &[f64],
886 low: &[f64],
887 close: &[f64],
888 params: &StatisticalTrailingStopParams,
889 level_out: &mut [f64],
890 anchor_out: &mut [f64],
891 bias_out: &mut [f64],
892 changed_out: &mut [f64],
893) -> Result<(), StatisticalTrailingStopError> {
894 let len = close.len();
895 let expected = len;
896 if level_out.len() != expected
897 || anchor_out.len() != expected
898 || bias_out.len() != expected
899 || changed_out.len() != expected
900 {
901 return Err(StatisticalTrailingStopError::OutputLengthMismatch {
902 expected,
903 got: level_out
904 .len()
905 .max(anchor_out.len())
906 .max(bias_out.len())
907 .max(changed_out.len()),
908 });
909 }
910
911 let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
912 let normalization_length = params
913 .normalization_length
914 .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
915 validate_periods(data_length, normalization_length, len)?;
916 let base_level_index =
917 parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
918
919 let mut state =
920 StatisticalTrailingStopState::new(data_length, normalization_length, base_level_index);
921 for i in 0..len {
922 if let Some((level, anchor, bias, changed)) = state.update(i, high[i], low[i], close[i]) {
923 level_out[i] = level;
924 anchor_out[i] = anchor;
925 bias_out[i] = bias;
926 changed_out[i] = changed;
927 } else {
928 level_out[i] = f64::NAN;
929 anchor_out[i] = f64::NAN;
930 bias_out[i] = f64::NAN;
931 changed_out[i] = f64::NAN;
932 }
933 }
934 Ok(())
935}
936
937#[inline]
938pub fn statistical_trailing_stop(
939 input: &StatisticalTrailingStopInput,
940) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
941 statistical_trailing_stop_with_kernel(input, Kernel::Auto)
942}
943
944pub fn statistical_trailing_stop_with_kernel(
945 input: &StatisticalTrailingStopInput,
946 kernel: Kernel,
947) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
948 let prepared = prepare_input(input, kernel)?;
949 let len = prepared.close.len();
950 let mut level = alloc_with_nan_prefix(len, prepared.warmup);
951 let mut anchor = alloc_with_nan_prefix(len, prepared.warmup);
952 let mut bias = alloc_with_nan_prefix(len, prepared.warmup);
953 let mut changed = alloc_with_nan_prefix(len, prepared.warmup);
954 compute_row(
955 prepared.high,
956 prepared.low,
957 prepared.close,
958 &StatisticalTrailingStopParams {
959 data_length: Some(prepared.data_length),
960 normalization_length: Some(prepared.normalization_length),
961 base_level: Some(
962 match prepared.base_level_index {
963 0 => "level0",
964 1 => "level1",
965 2 => "level2",
966 _ => "level3",
967 }
968 .to_string(),
969 ),
970 },
971 &mut level,
972 &mut anchor,
973 &mut bias,
974 &mut changed,
975 )?;
976 Ok(StatisticalTrailingStopOutput {
977 level,
978 anchor,
979 bias,
980 changed,
981 })
982}
983
984#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
985pub fn statistical_trailing_stop_into(
986 level_out: &mut [f64],
987 anchor_out: &mut [f64],
988 bias_out: &mut [f64],
989 changed_out: &mut [f64],
990 input: &StatisticalTrailingStopInput,
991) -> Result<(), StatisticalTrailingStopError> {
992 statistical_trailing_stop_into_slice(
993 level_out,
994 anchor_out,
995 bias_out,
996 changed_out,
997 input,
998 Kernel::Auto,
999 )
1000}
1001
1002pub fn statistical_trailing_stop_into_slice(
1003 level_out: &mut [f64],
1004 anchor_out: &mut [f64],
1005 bias_out: &mut [f64],
1006 changed_out: &mut [f64],
1007 input: &StatisticalTrailingStopInput,
1008 kernel: Kernel,
1009) -> Result<(), StatisticalTrailingStopError> {
1010 let prepared = prepare_input(input, kernel)?;
1011 compute_row(
1012 prepared.high,
1013 prepared.low,
1014 prepared.close,
1015 &StatisticalTrailingStopParams {
1016 data_length: Some(prepared.data_length),
1017 normalization_length: Some(prepared.normalization_length),
1018 base_level: Some(
1019 match prepared.base_level_index {
1020 0 => "level0",
1021 1 => "level1",
1022 2 => "level2",
1023 _ => "level3",
1024 }
1025 .to_string(),
1026 ),
1027 },
1028 level_out,
1029 anchor_out,
1030 bias_out,
1031 changed_out,
1032 )
1033}
1034
1035#[inline(always)]
1036pub fn expand_grid(
1037 sweep: &StatisticalTrailingStopBatchRange,
1038) -> Result<Vec<StatisticalTrailingStopParams>, StatisticalTrailingStopError> {
1039 fn axis_usize(
1040 axis: &'static str,
1041 (start, end, step): (usize, usize, usize),
1042 ) -> Result<Vec<usize>, StatisticalTrailingStopError> {
1043 if step == 0 || start == end {
1044 return Ok(vec![start]);
1045 }
1046 let mut out = Vec::new();
1047 if start < end {
1048 let mut value = start;
1049 let stride = step.max(1);
1050 while value <= end {
1051 out.push(value);
1052 match value.checked_add(stride) {
1053 Some(next) => value = next,
1054 None => break,
1055 }
1056 }
1057 } else {
1058 let mut value = start as isize;
1059 let stop = end as isize;
1060 let stride = step.max(1) as isize;
1061 while value >= stop {
1062 out.push(value as usize);
1063 value -= stride;
1064 }
1065 }
1066 if out.is_empty() {
1067 return Err(StatisticalTrailingStopError::InvalidRange {
1068 axis,
1069 start: start.to_string(),
1070 end: end.to_string(),
1071 step: step.to_string(),
1072 });
1073 }
1074 Ok(out)
1075 }
1076
1077 fn axis_base_level(
1078 axis: &'static str,
1079 (start, end, step): (String, String, usize),
1080 ) -> Result<Vec<String>, StatisticalTrailingStopError> {
1081 let start_idx = parse_base_level(&start)?;
1082 let end_idx = parse_base_level(&end)?;
1083 if step == 0 || start_idx == end_idx {
1084 return Ok(vec![canonical_base_level(&start)?.to_string()]);
1085 }
1086 let mut out = Vec::new();
1087 if start_idx < end_idx {
1088 let mut idx = start_idx;
1089 let stride = step.max(1);
1090 while idx <= end_idx {
1091 out.push(
1092 match idx {
1093 0 => "level0",
1094 1 => "level1",
1095 2 => "level2",
1096 _ => "level3",
1097 }
1098 .to_string(),
1099 );
1100 match idx.checked_add(stride) {
1101 Some(next) => idx = next,
1102 None => break,
1103 }
1104 }
1105 } else {
1106 let mut idx = start_idx as isize;
1107 let stop = end_idx as isize;
1108 let stride = step.max(1) as isize;
1109 while idx >= stop {
1110 out.push(
1111 match idx as usize {
1112 0 => "level0",
1113 1 => "level1",
1114 2 => "level2",
1115 _ => "level3",
1116 }
1117 .to_string(),
1118 );
1119 idx -= stride;
1120 }
1121 }
1122 if out.is_empty() {
1123 return Err(StatisticalTrailingStopError::InvalidRange {
1124 axis,
1125 start,
1126 end,
1127 step: step.to_string(),
1128 });
1129 }
1130 Ok(out)
1131 }
1132
1133 let data_lengths = axis_usize("data_length", sweep.data_length)?;
1134 let normalization_lengths = axis_usize("normalization_length", sweep.normalization_length)?;
1135 let base_levels = axis_base_level("base_level", sweep.base_level.clone())?;
1136
1137 let cap = data_lengths
1138 .len()
1139 .checked_mul(normalization_lengths.len())
1140 .and_then(|v| v.checked_mul(base_levels.len()))
1141 .ok_or(StatisticalTrailingStopError::InvalidRange {
1142 axis: "grid",
1143 start: "cap".to_string(),
1144 end: "overflow".to_string(),
1145 step: "mul".to_string(),
1146 })?;
1147
1148 let mut out = Vec::with_capacity(cap);
1149 for &data_length in &data_lengths {
1150 for &normalization_length in &normalization_lengths {
1151 for base_level in &base_levels {
1152 out.push(StatisticalTrailingStopParams {
1153 data_length: Some(data_length),
1154 normalization_length: Some(normalization_length),
1155 base_level: Some(base_level.clone()),
1156 });
1157 }
1158 }
1159 }
1160 Ok(out)
1161}
1162
1163fn statistical_trailing_stop_batch_inner_into(
1164 high: &[f64],
1165 low: &[f64],
1166 close: &[f64],
1167 sweep: &StatisticalTrailingStopBatchRange,
1168 parallel: bool,
1169 level_out: &mut [f64],
1170 anchor_out: &mut [f64],
1171 bias_out: &mut [f64],
1172 changed_out: &mut [f64],
1173) -> Result<Vec<StatisticalTrailingStopParams>, StatisticalTrailingStopError> {
1174 if high.is_empty() || low.is_empty() || close.is_empty() {
1175 return Err(StatisticalTrailingStopError::EmptyInputData);
1176 }
1177 if high.len() != low.len() || high.len() != close.len() {
1178 return Err(StatisticalTrailingStopError::DataLengthMismatch {
1179 high_len: high.len(),
1180 low_len: low.len(),
1181 close_len: close.len(),
1182 });
1183 }
1184 let combos = expand_grid(sweep)?;
1185 let rows = combos.len();
1186 let cols = close.len();
1187 let expected =
1188 rows.checked_mul(cols)
1189 .ok_or(StatisticalTrailingStopError::OutputLengthMismatch {
1190 expected: usize::MAX,
1191 got: level_out.len(),
1192 })?;
1193 if level_out.len() != expected
1194 || anchor_out.len() != expected
1195 || bias_out.len() != expected
1196 || changed_out.len() != expected
1197 {
1198 return Err(StatisticalTrailingStopError::OutputLengthMismatch {
1199 expected,
1200 got: level_out
1201 .len()
1202 .max(anchor_out.len())
1203 .max(bias_out.len())
1204 .max(changed_out.len()),
1205 });
1206 }
1207 let (_, max_run) = analyze_valid_segments(high, low, close)?;
1208
1209 for params in &combos {
1210 let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
1211 let normalization_length = params
1212 .normalization_length
1213 .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
1214 validate_periods(data_length, normalization_length, cols)?;
1215 let needed = data_length + normalization_length + 1;
1216 if max_run < needed {
1217 return Err(StatisticalTrailingStopError::NotEnoughValidData {
1218 needed,
1219 valid: max_run,
1220 });
1221 }
1222 let _ = parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
1223 }
1224
1225 let do_row = |row: usize,
1226 level_row: &mut [f64],
1227 anchor_row: &mut [f64],
1228 bias_row: &mut [f64],
1229 changed_row: &mut [f64]| {
1230 compute_row(
1231 high,
1232 low,
1233 close,
1234 &combos[row],
1235 level_row,
1236 anchor_row,
1237 bias_row,
1238 changed_row,
1239 )
1240 };
1241
1242 if parallel {
1243 #[cfg(not(target_arch = "wasm32"))]
1244 {
1245 level_out
1246 .par_chunks_mut(cols)
1247 .zip(anchor_out.par_chunks_mut(cols))
1248 .zip(bias_out.par_chunks_mut(cols))
1249 .zip(changed_out.par_chunks_mut(cols))
1250 .enumerate()
1251 .try_for_each(
1252 |(row, (((level_row, anchor_row), bias_row), changed_row))| {
1253 do_row(row, level_row, anchor_row, bias_row, changed_row)
1254 },
1255 )?;
1256 }
1257 #[cfg(target_arch = "wasm32")]
1258 {
1259 for (row, (((level_row, anchor_row), bias_row), changed_row)) in level_out
1260 .chunks_mut(cols)
1261 .zip(anchor_out.chunks_mut(cols))
1262 .zip(bias_out.chunks_mut(cols))
1263 .zip(changed_out.chunks_mut(cols))
1264 .enumerate()
1265 {
1266 do_row(row, level_row, anchor_row, bias_row, changed_row)?;
1267 }
1268 }
1269 } else {
1270 for (row, (((level_row, anchor_row), bias_row), changed_row)) in level_out
1271 .chunks_mut(cols)
1272 .zip(anchor_out.chunks_mut(cols))
1273 .zip(bias_out.chunks_mut(cols))
1274 .zip(changed_out.chunks_mut(cols))
1275 .enumerate()
1276 {
1277 do_row(row, level_row, anchor_row, bias_row, changed_row)?;
1278 }
1279 }
1280
1281 Ok(combos)
1282}
1283
1284pub fn statistical_trailing_stop_batch_with_kernel(
1285 high: &[f64],
1286 low: &[f64],
1287 close: &[f64],
1288 sweep: &StatisticalTrailingStopBatchRange,
1289 kernel: Kernel,
1290) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1291 match kernel {
1292 Kernel::Auto => {
1293 let _ = detect_best_batch_kernel();
1294 }
1295 k if !k.is_batch() => return Err(StatisticalTrailingStopError::InvalidKernelForBatch(k)),
1296 _ => {}
1297 }
1298 statistical_trailing_stop_batch_par_slice(high, low, close, sweep, Kernel::ScalarBatch)
1299}
1300
1301pub fn statistical_trailing_stop_batch_slice(
1302 high: &[f64],
1303 low: &[f64],
1304 close: &[f64],
1305 sweep: &StatisticalTrailingStopBatchRange,
1306 _kernel: Kernel,
1307) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1308 statistical_trailing_stop_batch_impl(high, low, close, sweep, false)
1309}
1310
1311pub fn statistical_trailing_stop_batch_par_slice(
1312 high: &[f64],
1313 low: &[f64],
1314 close: &[f64],
1315 sweep: &StatisticalTrailingStopBatchRange,
1316 _kernel: Kernel,
1317) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1318 statistical_trailing_stop_batch_impl(high, low, close, sweep, true)
1319}
1320
1321fn statistical_trailing_stop_batch_impl(
1322 high: &[f64],
1323 low: &[f64],
1324 close: &[f64],
1325 sweep: &StatisticalTrailingStopBatchRange,
1326 parallel: bool,
1327) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1328 let rows = expand_grid(sweep)?.len();
1329 let cols = close.len();
1330
1331 let mut level_mu = make_uninit_matrix(rows, cols);
1332 let mut anchor_mu = make_uninit_matrix(rows, cols);
1333 let mut bias_mu = make_uninit_matrix(rows, cols);
1334 let mut changed_mu = make_uninit_matrix(rows, cols);
1335
1336 let mut level_guard = ManuallyDrop::new(level_mu);
1337 let mut anchor_guard = ManuallyDrop::new(anchor_mu);
1338 let mut bias_guard = ManuallyDrop::new(bias_mu);
1339 let mut changed_guard = ManuallyDrop::new(changed_mu);
1340
1341 let level_out: &mut [f64] = unsafe {
1342 core::slice::from_raw_parts_mut(level_guard.as_mut_ptr() as *mut f64, level_guard.len())
1343 };
1344 let anchor_out: &mut [f64] = unsafe {
1345 core::slice::from_raw_parts_mut(anchor_guard.as_mut_ptr() as *mut f64, anchor_guard.len())
1346 };
1347 let bias_out: &mut [f64] = unsafe {
1348 core::slice::from_raw_parts_mut(bias_guard.as_mut_ptr() as *mut f64, bias_guard.len())
1349 };
1350 let changed_out: &mut [f64] = unsafe {
1351 core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1352 };
1353
1354 let combos = statistical_trailing_stop_batch_inner_into(
1355 high,
1356 low,
1357 close,
1358 sweep,
1359 parallel,
1360 level_out,
1361 anchor_out,
1362 bias_out,
1363 changed_out,
1364 )?;
1365
1366 let level = unsafe {
1367 Vec::from_raw_parts(
1368 level_guard.as_mut_ptr() as *mut f64,
1369 level_guard.len(),
1370 level_guard.capacity(),
1371 )
1372 };
1373 let anchor = unsafe {
1374 Vec::from_raw_parts(
1375 anchor_guard.as_mut_ptr() as *mut f64,
1376 anchor_guard.len(),
1377 anchor_guard.capacity(),
1378 )
1379 };
1380 let bias = unsafe {
1381 Vec::from_raw_parts(
1382 bias_guard.as_mut_ptr() as *mut f64,
1383 bias_guard.len(),
1384 bias_guard.capacity(),
1385 )
1386 };
1387 let changed = unsafe {
1388 Vec::from_raw_parts(
1389 changed_guard.as_mut_ptr() as *mut f64,
1390 changed_guard.len(),
1391 changed_guard.capacity(),
1392 )
1393 };
1394
1395 Ok(StatisticalTrailingStopBatchOutput {
1396 level,
1397 anchor,
1398 bias,
1399 changed,
1400 combos,
1401 rows,
1402 cols,
1403 })
1404}
1405
/// Python binding: computes the statistical trailing stop for a single
/// parameter set.
///
/// Returns a tuple `(level, anchor, bias, changed)` of 1-D float64 arrays.
/// Computation errors are raised as `ValueError`. The GIL is released while
/// the indicator runs.
#[cfg(feature = "python")]
#[pyfunction(name = "statistical_trailing_stop")]
#[pyo3(signature = (high, low, close, data_length=DEFAULT_DATA_LENGTH, normalization_length=DEFAULT_NORMALIZATION_LENGTH, base_level=DEFAULT_BASE_LEVEL, kernel=None))]
pub fn statistical_trailing_stop_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    // Slices are borrowed in high, low, close order so the first non-contiguous
    // input is the one reported.
    let (h, l, c) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let chosen_kernel = validate_kernel(kernel, false)?;

    let params = StatisticalTrailingStopParams {
        data_length: Some(data_length),
        normalization_length: Some(normalization_length),
        base_level: Some(base_level.to_string()),
    };
    let input = StatisticalTrailingStopInput::from_slices(h, l, c, params);

    let computed = py.allow_threads(|| statistical_trailing_stop_with_kernel(&input, chosen_kernel));
    let output = match computed {
        Ok(out) => out,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    let level = output.level.into_pyarray(py);
    let anchor = output.anchor.into_pyarray(py);
    let bias = output.bias.into_pyarray(py);
    let changed = output.changed.into_pyarray(py);
    Ok((level, anchor, bias, changed))
}
1448
/// Python binding: runs the statistical trailing stop over a parameter grid.
///
/// Each `*_range` argument is a `(start, end, step)` tuple that `expand_grid`
/// expands into a list of combinations. Returns a dict with:
/// - `level`, `anchor`, `bias`, `changed`: 2-D float64 arrays of shape
///   `(rows, cols)`, one row per combination;
/// - `data_lengths`, `normalization_lengths`: 1-D u64 arrays, one per row;
/// - `base_levels`: list of base-level strings, one per row;
/// - `rows`, `cols`.
///
/// Errors (bad ranges, invalid data, size overflow) are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "statistical_trailing_stop_batch")]
#[pyo3(signature = (high, low, close, data_length_range=(DEFAULT_DATA_LENGTH, DEFAULT_DATA_LENGTH, 0), normalization_length_range=(DEFAULT_NORMALIZATION_LENGTH, DEFAULT_NORMALIZATION_LENGTH, 0), base_level_range=(DEFAULT_BASE_LEVEL.to_string(), DEFAULT_BASE_LEVEL.to_string(), 0usize), kernel=None))]
pub fn statistical_trailing_stop_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    data_length_range: (usize, usize, usize),
    normalization_length_range: (usize, usize, usize),
    base_level_range: (String, String, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_slice = high.as_slice()?;
    let low_slice = low.as_slice()?;
    let close_slice = close.as_slice()?;
    let kernel = validate_kernel(kernel, true)?;
    let sweep = StatisticalTrailingStopBatchRange {
        data_length: data_length_range,
        normalization_length: normalization_length_range,
        base_level: base_level_range,
    };

    // Size the flat output buffers up front. The inner routine expands the grid
    // again and checks the buffer lengths against it.
    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = close_slice.len();
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err("rows*cols overflow in statistical_trailing_stop_batch")
    })?;

    // Allocate numpy arrays without initializing them (`false` = C order). The
    // inner routine writes directly into them, so no extra copy is needed.
    // NOTE(review): the contents are unspecified until the inner routine
    // succeeds; on error the arrays are simply dropped.
    let level_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let anchor_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bias_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let changed_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // These arrays were just created here and are not yet visible to Python,
    // so nothing else can access the mutable views while they are in use.
    let level_out = unsafe { level_arr.as_slice_mut()? };
    let anchor_out = unsafe { anchor_arr.as_slice_mut()? };
    let bias_out = unsafe { bias_arr.as_slice_mut()? };
    let changed_out = unsafe { changed_arr.as_slice_mut()? };

    // Release the GIL during computation. Only the scalar kernels select the
    // sequential path; every other kernel (including the validated default)
    // asks for row-parallel execution.
    let combos = py
        .allow_threads(|| {
            statistical_trailing_stop_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                !matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch),
                level_out,
                anchor_out,
                bias_out,
                changed_out,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Expose the flat row-major buffers as (rows, cols) matrices.
    let dict = PyDict::new(py);
    dict.set_item("level", level_arr.reshape((rows, cols))?)?;
    dict.set_item("anchor", anchor_arr.reshape((rows, cols))?)?;
    dict.set_item("bias", bias_arr.reshape((rows, cols))?)?;
    dict.set_item("changed", changed_arr.reshape((rows, cols))?)?;
    // Per-row parameter metadata, aligned with the matrix rows. Unset values
    // fall back to the module defaults.
    dict.set_item(
        "data_lengths",
        combos
            .iter()
            .map(|c| c.data_length.unwrap_or(DEFAULT_DATA_LENGTH) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "normalization_lengths",
        combos
            .iter()
            .map(|c| {
                c.normalization_length
                    .unwrap_or(DEFAULT_NORMALIZATION_LENGTH) as u64
            })
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    let base_levels = PyList::empty(py);
    for combo in &combos {
        base_levels.append(combo.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
    }
    dict.set_item("base_levels", base_levels)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1539
/// Python-visible wrapper (exported as `StatisticalTrailingStopStream`) around
/// the incremental `StatisticalTrailingStopStream`.
#[cfg(feature = "python")]
#[pyclass(name = "StatisticalTrailingStopStream")]
pub struct StatisticalTrailingStopStreamPy {
    stream: StatisticalTrailingStopStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StatisticalTrailingStopStreamPy {
    /// Builds a stream from the given parameters. Invalid parameters are
    /// raised as `ValueError`.
    #[new]
    #[pyo3(signature = (data_length=DEFAULT_DATA_LENGTH, normalization_length=DEFAULT_NORMALIZATION_LENGTH, base_level=DEFAULT_BASE_LEVEL))]
    fn new(data_length: usize, normalization_length: usize, base_level: &str) -> PyResult<Self> {
        let params = StatisticalTrailingStopParams {
            data_length: Some(data_length),
            normalization_length: Some(normalization_length),
            base_level: Some(base_level.to_string()),
        };
        StatisticalTrailingStopStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns `(level, anchor, bias, changed)`, or `None`
    /// when the underlying stream produces no value for this bar.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
        let bar = (high, low, close);
        self.stream.update(bar.0, bar.1, bar.2)
    }
}
1565
/// JS-side configuration for `statistical_trailing_stop_batch`, deserialized
/// with `serde_wasm_bindgen`.
///
/// Each range is a `(start, end, step)` tuple that is forwarded unchanged into
/// `StatisticalTrailingStopBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StatisticalTrailingStopBatchConfig {
    /// `(start, end, step)` sweep of `data_length`.
    pub data_length_range: (usize, usize, usize),
    /// `(start, end, step)` sweep of `normalization_length`.
    pub normalization_length_range: (usize, usize, usize),
    /// `(start, end, step)` sweep of base levels, given by name (e.g. `"level2"`).
    pub base_level_range: (String, String, usize),
}
1573
/// Serializable batch result returned to JavaScript by
/// `statistical_trailing_stop_batch`.
///
/// The series fields are flat, row-major `rows * cols` buffers. Row `r`
/// corresponds to `combos[r]`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StatisticalTrailingStopBatchJsOutput {
    pub level: Vec<f64>,
    pub anchor: Vec<f64>,
    pub bias: Vec<f64>,
    pub changed: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<StatisticalTrailingStopParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of input bars per row.
    pub cols: usize,
}
1585
/// WASM binding: computes the statistical trailing stop for one parameter set
/// and returns the output serialized through `serde_wasm_bindgen`.
///
/// Computation and serialization failures are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
) -> Result<JsValue, JsValue> {
    let params = StatisticalTrailingStopParams {
        data_length: Some(data_length),
        normalization_length: Some(normalization_length),
        base_level: Some(base_level.to_string()),
    };
    let input = StatisticalTrailingStopInput::from_slices(high, low, close, params);
    match statistical_trailing_stop_with_kernel(&input, Kernel::Auto) {
        Ok(output) => serde_wasm_bindgen::to_value(&output)
            .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1611
/// Allocates room for `len` f64 values in WASM memory and returns the pointer.
///
/// The buffer is not initialized. It is intended to be released with
/// `statistical_trailing_stop_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after return; the caller now owns it.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1620
/// Releases a buffer obtained from `statistical_trailing_stop_alloc(len)`.
///
/// `len` must be the exact value passed to the matching alloc call. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` was produced by `Vec::<f64>::with_capacity(len)`, which
    // allocates exactly `len` slots for a non-zero-sized type. The Vec is
    // rebuilt with length 0 so it only deallocates and never claims any element
    // is initialized; the caller may never have written the buffer. f64 has no
    // destructor, so skipping per-element drops loses nothing.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1630
/// WASM binding: computes the statistical trailing stop from raw input
/// pointers into four caller-owned output buffers of `len` f64 each.
///
/// If any output range overlaps an input range or another output range, the
/// result is first computed into fresh buffers and then copied out. This
/// avoids creating aliasing `&`/`&mut` slices over the same memory.
///
/// # Errors
/// Returns a string `JsValue` on a null pointer or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    level_ptr: *mut f64,
    anchor_ptr: *mut f64,
    bias_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || level_ptr.is_null()
        || anchor_ptr.is_null()
        || bias_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Two `len`-element buffers alias when their byte ranges intersect.
    // Identical start addresses count as aliased even when `len == 0`.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: usize, b: usize| {
        a == b || (a < b.saturating_add(span) && b < a.saturating_add(span))
    };
    let inputs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let outputs = [
        level_ptr as usize,
        anchor_ptr as usize,
        bias_ptr as usize,
        changed_ptr as usize,
    ];
    // Inputs may share memory with each other (they are only read); outputs
    // must not overlap inputs or one another.
    let aliased = inputs
        .iter()
        .any(|&i| outputs.iter().any(|&o| overlaps(i, o)))
        || outputs
            .iter()
            .enumerate()
            .any(|(k, &a)| outputs[k + 1..].iter().any(|&b| overlaps(a, b)));

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = StatisticalTrailingStopInput::from_slices(
            high,
            low,
            close,
            StatisticalTrailingStopParams {
                data_length: Some(data_length),
                normalization_length: Some(normalization_length),
                base_level: Some(base_level.to_string()),
            },
        );

        if aliased {
            // Compute into owned buffers first, then copy out one output at a
            // time, so no `&mut` view coexists with a live input borrow.
            let output = statistical_trailing_stop_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(level_ptr, len).copy_from_slice(&output.level);
            std::slice::from_raw_parts_mut(anchor_ptr, len).copy_from_slice(&output.anchor);
            std::slice::from_raw_parts_mut(bias_ptr, len).copy_from_slice(&output.bias);
            std::slice::from_raw_parts_mut(changed_ptr, len).copy_from_slice(&output.changed);
        } else {
            let level_out = std::slice::from_raw_parts_mut(level_ptr, len);
            let anchor_out = std::slice::from_raw_parts_mut(anchor_ptr, len);
            let bias_out = std::slice::from_raw_parts_mut(bias_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);
            statistical_trailing_stop_into_slice(
                level_out,
                anchor_out,
                bias_out,
                changed_out,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1720
/// WASM binding (exported as `statistical_trailing_stop_batch`): runs a
/// parameter sweep described by a JS config object and returns the serialized
/// batch result.
///
/// Config parse failures, computation errors and serialization failures are
/// returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = statistical_trailing_stop_batch)]
pub fn statistical_trailing_stop_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: StatisticalTrailingStopBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };
    let StatisticalTrailingStopBatchConfig {
        data_length_range,
        normalization_length_range,
        base_level_range,
    } = parsed;
    let sweep = StatisticalTrailingStopBatchRange {
        data_length: data_length_range,
        normalization_length: normalization_length_range,
        base_level: base_level_range,
    };

    let StatisticalTrailingStopBatchOutput {
        level,
        anchor,
        bias,
        changed,
        combos,
        rows,
        cols,
    } = statistical_trailing_stop_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = StatisticalTrailingStopBatchJsOutput {
        level,
        anchor,
        bias,
        changed,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1751
/// WASM binding: runs a parameter sweep from raw input pointers (`len` f64
/// each) into four caller-owned output buffers of `rows * len` f64 each.
///
/// Returns the number of rows (parameter combinations) written. If any output
/// range overlaps an input range or another output range, the batch is
/// computed into fresh buffers and copied out. This avoids aliasing `&`/`&mut`
/// slices over the same memory.
///
/// # Errors
/// Returns a string `JsValue` on a null pointer, an invalid sweep, a
/// `rows * len` overflow, or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    level_ptr: *mut f64,
    anchor_ptr: *mut f64,
    bias_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    data_length_start: usize,
    data_length_end: usize,
    data_length_step: usize,
    normalization_length_start: usize,
    normalization_length_end: usize,
    normalization_length_step: usize,
    base_level_start: &str,
    base_level_end: &str,
    base_level_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || level_ptr.is_null()
        || anchor_ptr.is_null()
        || bias_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    let sweep = StatisticalTrailingStopBatchRange {
        data_length: (data_length_start, data_length_end, data_length_step),
        normalization_length: (
            normalization_length_start,
            normalization_length_end,
            normalization_length_step,
        ),
        base_level: (
            base_level_start.to_string(),
            base_level_end.to_string(),
            base_level_step,
        ),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

    // Byte-range overlap test. Inputs span `len` f64s and outputs span `total`
    // f64s; identical start addresses count as aliased even for empty spans.
    let elem = std::mem::size_of::<f64>();
    let in_span = len.saturating_mul(elem);
    let out_span = total.saturating_mul(elem);
    let overlaps = |a: usize, a_span: usize, b: usize, b_span: usize| {
        a == b || (a < b.saturating_add(b_span) && b < a.saturating_add(a_span))
    };
    let inputs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let outputs = [
        level_ptr as usize,
        anchor_ptr as usize,
        bias_ptr as usize,
        changed_ptr as usize,
    ];
    let aliased = inputs
        .iter()
        .any(|&i| outputs.iter().any(|&o| overlaps(i, in_span, o, out_span)))
        || outputs.iter().enumerate().any(|(k, &a)| {
            outputs[k + 1..]
                .iter()
                .any(|&b| overlaps(a, out_span, b, out_span))
        });

    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        if aliased {
            // Compute into owned buffers (each `rows * len` long), then copy out
            // one output at a time once the input borrows are no longer used.
            let output = statistical_trailing_stop_batch_impl(high, low, close, &sweep, false)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(level_ptr, total).copy_from_slice(&output.level);
            std::slice::from_raw_parts_mut(anchor_ptr, total).copy_from_slice(&output.anchor);
            std::slice::from_raw_parts_mut(bias_ptr, total).copy_from_slice(&output.bias);
            std::slice::from_raw_parts_mut(changed_ptr, total).copy_from_slice(&output.changed);
        } else {
            let level_out = std::slice::from_raw_parts_mut(level_ptr, total);
            let anchor_out = std::slice::from_raw_parts_mut(anchor_ptr, total);
            let bias_out = std::slice::from_raw_parts_mut(bias_ptr, total);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, total);
            statistical_trailing_stop_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                false,
                level_out,
                anchor_out,
                bias_out,
                changed_out,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1827
// Unit tests for the statistical trailing stop. They cross-check the single,
// `into_slice`, streaming and batch entry points against each other, and cover
// reversal detection, NaN-gap handling and the error paths.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;

    // Synthetic uptrend of `size` bars: close rises by 0.9 per bar. High and
    // low sit about 1.0 above/below close, with small offsets driven by
    // `i % 3` and `i % 2`.
    fn trend_data(size: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(size);
        let mut low = Vec::with_capacity(size);
        let mut close = Vec::with_capacity(size);
        for i in 0..size {
            let base = 100.0 + i as f64 * 0.9;
            high.push(base + 1.0 + (i % 3) as f64 * 0.1);
            low.push(base - 1.0 - (i % 2) as f64 * 0.1);
            close.push(base);
        }
        (high, low, close)
    }

    // 30 bars: a decline, a sharp rally, then a renewed decline. High and low
    // are fixed at +/-0.8 around close.
    fn reversal_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let close = vec![
            100.0, 99.5, 99.0, 98.5, 98.0, 97.5, 97.0, 96.5, 96.0, 95.5, 97.0, 99.0, 101.0, 104.0,
            108.0, 111.0, 113.0, 112.0, 111.0, 109.0, 107.0, 105.0, 103.0, 101.0, 99.0, 97.0, 95.0,
            94.0, 93.5, 93.0,
        ];
        let high = close.iter().map(|v| v + 0.8).collect::<Vec<_>>();
        let low = close.iter().map(|v| v - 0.8).collect::<Vec<_>>();
        (high, low, close)
    }

    // Element-wise equality where NaN matches NaN and finite values match
    // within an absolute tolerance of 1e-12.
    fn arrays_eq_nan(a: &[f64], b: &[f64]) -> bool {
        a.len() == b.len()
            && a.iter().zip(b.iter()).all(|(x, y)| {
                (x.is_nan() && y.is_nan())
                    || (!x.is_nan() && !y.is_nan() && (*x - *y).abs() <= 1e-12)
            })
    }

    // The caller-buffer API must reproduce the allocating API exactly.
    #[test]
    fn statistical_trailing_stop_into_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(160);
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams::default(),
        );
        let single = statistical_trailing_stop(&input)?;

        let mut level = vec![0.0; close.len()];
        let mut anchor = vec![0.0; close.len()];
        let mut bias = vec![0.0; close.len()];
        let mut changed = vec![0.0; close.len()];
        statistical_trailing_stop_into_slice(
            &mut level,
            &mut anchor,
            &mut bias,
            &mut changed,
            &input,
            Kernel::Auto,
        )?;

        assert!(arrays_eq_nan(&single.level, &level));
        assert!(arrays_eq_nan(&single.anchor, &anchor));
        assert!(arrays_eq_nan(&single.bias, &bias));
        assert!(arrays_eq_nan(&single.changed, &changed));
        Ok(())
    }

    // Bar-by-bar streaming must match the whole-series computation. A `None`
    // from the stream is mapped to NaN so it lines up with the batch output.
    #[test]
    fn statistical_trailing_stop_stream_matches_batch() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let params = StatisticalTrailingStopParams::default();
        let input = StatisticalTrailingStopInput::from_slices(&high, &low, &close, params.clone());
        let batch = statistical_trailing_stop(&input)?;

        let mut stream = StatisticalTrailingStopStream::try_new(params)?;
        let mut level = Vec::with_capacity(close.len());
        let mut anchor = Vec::with_capacity(close.len());
        let mut bias = Vec::with_capacity(close.len());
        let mut changed = Vec::with_capacity(close.len());

        for i in 0..close.len() {
            if let Some((lvl, anc, bs, ch)) = stream.update(high[i], low[i], close[i]) {
                level.push(lvl);
                anchor.push(anc);
                bias.push(bs);
                changed.push(ch);
            } else {
                level.push(f64::NAN);
                anchor.push(f64::NAN);
                bias.push(f64::NAN);
                changed.push(f64::NAN);
            }
        }

        assert!(arrays_eq_nan(&batch.level, &level));
        assert!(arrays_eq_nan(&batch.anchor, &anchor));
        assert!(arrays_eq_nan(&batch.bias, &bias));
        assert!(arrays_eq_nan(&batch.changed, &changed));
        Ok(())
    }

    // On the decline/rally/decline series with short windows, at least one
    // change must be flagged. The first flagged bar must have a finite anchor
    // and a bias of 1.0.
    #[test]
    fn statistical_trailing_stop_reversal_behavior() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = reversal_data();
        let output = statistical_trailing_stop(&StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams {
                data_length: Some(3),
                normalization_length: Some(10),
                base_level: Some("level0".to_string()),
            },
        ))?;

        let changes = output
            .changed
            .iter()
            .enumerate()
            .filter_map(|(i, v)| if *v == 1.0 { Some(i) } else { None })
            .collect::<Vec<_>>();
        assert!(!changes.is_empty());
        let first = changes[0];
        assert!(output.anchor[first].is_finite());
        assert_eq!(output.bias[first], 1.0);
        Ok(())
    }

    // A single NaN bar at index 120 must reset the indicator. Level and bias
    // stay NaN from that bar through the full default warm-up
    // (data_length + normalization_length + 1 bars), clamped to the series end.
    #[test]
    fn statistical_trailing_stop_nan_gap_restarts() -> Result<(), Box<dyn StdError>> {
        let (mut high, mut low, mut close) = trend_data(170);
        high[120] = f64::NAN;
        low[120] = f64::NAN;
        close[120] = f64::NAN;
        let output = statistical_trailing_stop(&StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams::default(),
        ))?;

        let restart_end =
            (120 + DEFAULT_DATA_LENGTH + DEFAULT_NORMALIZATION_LENGTH + 1).min(output.level.len());
        for i in 120..restart_end {
            assert!(output.level[i].is_nan());
            assert!(output.bias[i].is_nan());
        }
        Ok(())
    }

    // A 2 x 2 x 2 sweep (data_length, normalization_length, base_level) must
    // give 8 rows. Each row must equal a single-series run with that row's combo.
    #[test]
    fn statistical_trailing_stop_batch_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let sweep = StatisticalTrailingStopBatchRange {
            data_length: (9, 10, 1),
            normalization_length: (10, 11, 1),
            base_level: ("level1".to_string(), "level2".to_string(), 1),
        };
        let batch = statistical_trailing_stop_batch_with_kernel(
            &high,
            &low,
            &close,
            &sweep,
            Kernel::ScalarBatch,
        )?;

        assert_eq!(batch.rows, 8);
        assert_eq!(batch.cols, close.len());
        for row in 0..batch.rows {
            let combo = &batch.combos[row];
            let single = statistical_trailing_stop(&StatisticalTrailingStopInput::from_slices(
                &high,
                &low,
                &close,
                combo.clone(),
            ))?;
            // Row-major layout: row `row` occupies [row * cols, (row + 1) * cols).
            let start = row * batch.cols;
            let end = start + batch.cols;
            assert!(arrays_eq_nan(
                &batch.level[start..end],
                single.level.as_slice()
            ));
            assert!(arrays_eq_nan(
                &batch.anchor[start..end],
                single.anchor.as_slice()
            ));
            assert!(arrays_eq_nan(
                &batch.bias[start..end],
                single.bias.as_slice()
            ));
            assert!(arrays_eq_nan(
                &batch.changed[start..end],
                single.changed.as_slice()
            ));
        }
        Ok(())
    }

    // An unrecognised base-level name must be rejected with InvalidBaseLevel.
    #[test]
    fn statistical_trailing_stop_invalid_base_level_errors() {
        let (high, low, close) = trend_data(160);
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams {
                data_length: Some(10),
                normalization_length: Some(100),
                base_level: Some("bad".to_string()),
            },
        );
        assert!(matches!(
            statistical_trailing_stop(&input),
            Err(StatisticalTrailingStopError::InvalidBaseLevel { .. })
        ));
    }

    // Input that is entirely NaN must be reported as AllValuesNaN.
    #[test]
    fn statistical_trailing_stop_all_nan_errors() {
        let high = vec![f64::NAN; 200];
        let low = vec![f64::NAN; 200];
        let close = vec![f64::NAN; 200];
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams::default(),
        );
        assert!(matches!(
            statistical_trailing_stop(&input),
            Err(StatisticalTrailingStopError::AllValuesNaN)
        ));
    }

    // Smoke test on the bundled CSV dataset (path is relative to the crate
    // root): every output series must be as long as the input.
    #[test]
    fn statistical_trailing_stop_default_candles_smoke() -> Result<(), Box<dyn StdError>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let output = statistical_trailing_stop(
            &StatisticalTrailingStopInput::with_default_candles(&candles),
        )?;
        assert_eq!(output.level.len(), candles.close.len());
        assert_eq!(output.anchor.len(), candles.close.len());
        assert_eq!(output.bias.len(), candles.close.len());
        assert_eq!(output.changed.len(), candles.close.len());
        Ok(())
    }
}