Skip to main content

vector_ta/indicators/
statistical_trailing_stop.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::Candles;
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24#[cfg(test)]
25use std::error::Error as StdError;
26use std::mem::ManuallyDrop;
27use thiserror::Error;
28
/// Default look-back for the highest-high / lowest-low window.
const DEFAULT_DATA_LENGTH: usize = 10;
/// Default number of log-range samples feeding the rolling mean/stdev.
const DEFAULT_NORMALIZATION_LENGTH: usize = 100;
/// Default number of standard deviations added to the log mean, as a level name.
const DEFAULT_BASE_LEVEL: &str = "level2";
/// Smallest accepted `data_length`.
const MIN_DATA_LENGTH: usize = 1;
/// Smallest accepted `normalization_length`.
const MIN_NORMALIZATION_LENGTH: usize = 10;
/// Bias code while the stop sits above price (level = typical price + distance).
const BIAS_BEARISH: u8 = 0;
/// Bias code while the stop sits below price (level = typical price - distance).
const BIAS_BULLISH: u8 = 1;
36
/// Typical price: the arithmetic mean of high, low and close.
#[inline(always)]
fn hlc3(high: f64, low: f64, close: f64) -> f64 {
    let total = high + low + close;
    total / 3.0
}
41
/// Returns `value` when it is strictly positive, otherwise the smallest positive
/// normal `f64`. A following `ln()` therefore never sees zero, a negative number or NaN.
#[inline(always)]
fn floor_positive(value: f64) -> f64 {
    if value > 0.0 {
        return value;
    }
    f64::MIN_POSITIVE
}
50
51#[inline(always)]
52fn close_source(candles: &Candles) -> &[f64] {
53    &candles.close
54}
55
56#[inline(always)]
57fn high_source(candles: &Candles) -> &[f64] {
58    &candles.high
59}
60
61#[inline(always)]
62fn low_source(candles: &Candles) -> &[f64] {
63    &candles.low
64}
65
66#[derive(Debug, Clone)]
67pub enum StatisticalTrailingStopData<'a> {
68    Candles {
69        candles: &'a Candles,
70    },
71    Slices {
72        high: &'a [f64],
73        low: &'a [f64],
74        close: &'a [f64],
75    },
76}
77
/// Per-bar results of a single run; every vector has the input's length.
///
/// Bars without a result (warm-up, or a run interrupted by non-finite input) hold `NaN`.
#[derive(Debug, Clone)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct StatisticalTrailingStopOutput {
    /// Trailing stop level.
    pub level: Vec<f64>,
    /// Close price at the most recent bias flip (`NaN` before the first flip).
    pub anchor: Vec<f64>,
    /// `0.0` while bearish, `1.0` while bullish.
    pub bias: Vec<f64>,
    /// `1.0` on bars where the bias flipped, `0.0` otherwise.
    pub changed: Vec<f64>,
}
89
/// User-facing parameters; a `None` field falls back to its default
/// (`data_length = 10`, `normalization_length = 100`, `base_level = "level2"`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct StatisticalTrailingStopParams {
    /// Highest-high / lowest-low window (minimum 1).
    pub data_length: Option<usize>,
    /// Number of log-range samples in the rolling mean/stdev (minimum 10).
    pub normalization_length: Option<usize>,
    /// `level0`..`level3` (or the bare digit); case, spaces, `_` and `-` are ignored.
    pub base_level: Option<String>,
}
100
101impl Default for StatisticalTrailingStopParams {
102    fn default() -> Self {
103        Self {
104            data_length: Some(DEFAULT_DATA_LENGTH),
105            normalization_length: Some(DEFAULT_NORMALIZATION_LENGTH),
106            base_level: Some(DEFAULT_BASE_LEVEL.to_string()),
107        }
108    }
109}
110
111#[derive(Debug, Clone)]
112pub struct StatisticalTrailingStopInput<'a> {
113    pub data: StatisticalTrailingStopData<'a>,
114    pub params: StatisticalTrailingStopParams,
115}
116
117impl<'a> StatisticalTrailingStopInput<'a> {
118    #[inline(always)]
119    pub fn from_candles(candles: &'a Candles, params: StatisticalTrailingStopParams) -> Self {
120        Self {
121            data: StatisticalTrailingStopData::Candles { candles },
122            params,
123        }
124    }
125
126    #[inline(always)]
127    pub fn from_slices(
128        high: &'a [f64],
129        low: &'a [f64],
130        close: &'a [f64],
131        params: StatisticalTrailingStopParams,
132    ) -> Self {
133        Self {
134            data: StatisticalTrailingStopData::Slices { high, low, close },
135            params,
136        }
137    }
138
139    #[inline(always)]
140    pub fn with_default_candles(candles: &'a Candles) -> Self {
141        Self::from_candles(candles, StatisticalTrailingStopParams::default())
142    }
143
144    #[inline(always)]
145    pub fn get_data_length(&self) -> usize {
146        self.params.data_length.unwrap_or(DEFAULT_DATA_LENGTH)
147    }
148
149    #[inline(always)]
150    pub fn get_normalization_length(&self) -> usize {
151        self.params
152            .normalization_length
153            .unwrap_or(DEFAULT_NORMALIZATION_LENGTH)
154    }
155
156    #[inline(always)]
157    pub fn get_base_level(&self) -> &str {
158        self.params
159            .base_level
160            .as_deref()
161            .unwrap_or(DEFAULT_BASE_LEVEL)
162    }
163
164    #[inline(always)]
165    fn as_hlc(&self) -> (&'a [f64], &'a [f64], &'a [f64]) {
166        match &self.data {
167            StatisticalTrailingStopData::Candles { candles } => (
168                high_source(candles),
169                low_source(candles),
170                close_source(candles),
171            ),
172            StatisticalTrailingStopData::Slices { high, low, close } => (*high, *low, *close),
173        }
174    }
175}
176
177impl<'a> AsRef<[f64]> for StatisticalTrailingStopInput<'a> {
178    #[inline(always)]
179    fn as_ref(&self) -> &[f64] {
180        self.as_hlc().2
181    }
182}
183
184#[derive(Clone, Debug)]
185pub struct StatisticalTrailingStopBuilder {
186    data_length: Option<usize>,
187    normalization_length: Option<usize>,
188    base_level: Option<String>,
189    kernel: Kernel,
190}
191
192impl Default for StatisticalTrailingStopBuilder {
193    fn default() -> Self {
194        Self {
195            data_length: None,
196            normalization_length: None,
197            base_level: None,
198            kernel: Kernel::Auto,
199        }
200    }
201}
202
203impl StatisticalTrailingStopBuilder {
204    #[inline(always)]
205    pub fn new() -> Self {
206        Self::default()
207    }
208
209    #[inline(always)]
210    pub fn data_length(mut self, value: usize) -> Self {
211        self.data_length = Some(value);
212        self
213    }
214
215    #[inline(always)]
216    pub fn normalization_length(mut self, value: usize) -> Self {
217        self.normalization_length = Some(value);
218        self
219    }
220
221    #[inline(always)]
222    pub fn base_level<S: Into<String>>(mut self, value: S) -> Self {
223        self.base_level = Some(value.into());
224        self
225    }
226
227    #[inline(always)]
228    pub fn kernel(mut self, kernel: Kernel) -> Self {
229        self.kernel = kernel;
230        self
231    }
232
233    #[inline(always)]
234    fn params(self) -> StatisticalTrailingStopParams {
235        StatisticalTrailingStopParams {
236            data_length: self.data_length,
237            normalization_length: self.normalization_length,
238            base_level: self.base_level,
239        }
240    }
241
242    #[inline(always)]
243    pub fn apply(
244        self,
245        candles: &Candles,
246    ) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
247        let kernel = self.kernel;
248        let params = self.params();
249        statistical_trailing_stop_with_kernel(
250            &StatisticalTrailingStopInput::from_candles(candles, params),
251            kernel,
252        )
253    }
254
255    #[inline(always)]
256    pub fn apply_slices(
257        self,
258        high: &[f64],
259        low: &[f64],
260        close: &[f64],
261    ) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
262        let kernel = self.kernel;
263        let params = self.params();
264        statistical_trailing_stop_with_kernel(
265            &StatisticalTrailingStopInput::from_slices(high, low, close, params),
266            kernel,
267        )
268    }
269
270    #[inline(always)]
271    pub fn into_stream(
272        self,
273    ) -> Result<StatisticalTrailingStopStream, StatisticalTrailingStopError> {
274        StatisticalTrailingStopStream::try_new(self.params())
275    }
276}
277
278#[derive(Debug, Error)]
279pub enum StatisticalTrailingStopError {
280    #[error("statistical_trailing_stop: input data slice is empty.")]
281    EmptyInputData,
282    #[error("statistical_trailing_stop: all values are NaN.")]
283    AllValuesNaN,
284    #[error(
285        "statistical_trailing_stop: inconsistent data lengths - high = {high_len}, low = {low_len}, close = {close_len}"
286    )]
287    DataLengthMismatch {
288        high_len: usize,
289        low_len: usize,
290        close_len: usize,
291    },
292    #[error(
293        "statistical_trailing_stop: invalid period: data_length = {data_length}, normalization_length = {normalization_length}, data length = {data_len}"
294    )]
295    InvalidPeriod {
296        data_length: usize,
297        normalization_length: usize,
298        data_len: usize,
299    },
300    #[error(
301        "statistical_trailing_stop: invalid base level: {base_level}. expected one of level0, level1, level2, level3"
302    )]
303    InvalidBaseLevel { base_level: String },
304    #[error(
305        "statistical_trailing_stop: not enough valid data: needed = {needed}, valid = {valid}"
306    )]
307    NotEnoughValidData { needed: usize, valid: usize },
308    #[error(
309        "statistical_trailing_stop: output length mismatch: expected = {expected}, got = {got}"
310    )]
311    OutputLengthMismatch { expected: usize, got: usize },
312    #[error(
313        "statistical_trailing_stop: invalid range for {axis}: start = {start}, end = {end}, step = {step}"
314    )]
315    InvalidRange {
316        axis: &'static str,
317        start: String,
318        end: String,
319        step: String,
320    },
321    #[error("statistical_trailing_stop: invalid kernel for batch: {0:?}")]
322    InvalidKernelForBatch(Kernel),
323}
324
/// Validated, fully resolved view of an input, ready for the kernels.
#[derive(Clone, Copy, Debug)]
struct PreparedInput<'a> {
    high: &'a [f64],
    low: &'a [f64],
    close: &'a [f64],
    /// Resolved highest-high / lowest-low window.
    data_length: usize,
    /// Resolved rolling-statistics window.
    normalization_length: usize,
    /// Parsed base level (0..=3).
    base_level_index: usize,
    /// Length of the guaranteed-`NaN` prefix: no output can appear before this index.
    warmup: usize,
}
335
/// Monotonic queue of `(index, value)` pairs giving amortised O(1) sliding-window
/// extrema.
///
/// With `descending = true` the front holds the window maximum; with `false`, the
/// minimum. Expired entries are skipped by advancing `head` instead of shifting the
/// vectors, and the dead prefix is compacted once it grows large.
#[derive(Clone, Debug)]
struct MonoDeque {
    /// Bar indices, parallel to `vals`.
    idx: Vec<usize>,
    /// Values, monotonic from `head` to the end.
    vals: Vec<f64>,
    /// First live position; entries before it have expired.
    head: usize,
    /// `true` for a max-queue, `false` for a min-queue.
    descending: bool,
}

impl MonoDeque {
    #[inline(always)]
    fn new(descending: bool, cap: usize) -> Self {
        let idx = Vec::with_capacity(cap);
        let vals = Vec::with_capacity(cap);
        Self {
            idx,
            vals,
            head: 0,
            descending,
        }
    }

    /// Drops every entry.
    #[inline(always)]
    fn clear(&mut self) {
        self.head = 0;
        self.vals.clear();
        self.idx.clear();
    }

    /// Appends `value` at `index` after discarding the live tail entries it dominates.
    /// Ties are discarded as well, so the newest of equal values survives.
    #[inline(always)]
    fn push(&mut self, index: usize, value: f64) {
        let descending = self.descending;
        while self.idx.len() > self.head {
            let tail = self.vals[self.vals.len() - 1];
            let dominated = if descending {
                tail <= value
            } else {
                tail >= value
            };
            if !dominated {
                break;
            }
            self.vals.pop();
            self.idx.pop();
        }
        self.vals.push(value);
        self.idx.push(index);
    }

    /// Expires entries whose index is below `min_index`. Storage is compacted once more
    /// than 64 dead slots make up at least half of the buffer.
    #[inline(always)]
    fn expire(&mut self, min_index: usize) {
        let stale = self.idx[self.head..]
            .iter()
            .take_while(|&&i| i < min_index)
            .count();
        self.head += stale;

        let len = self.idx.len();
        if self.head > 64 && self.head * 2 >= len {
            self.vals.drain(..self.head);
            self.idx.drain(..self.head);
            self.head = 0;
        }
    }

    /// Current window extremum; panics when no live entry remains.
    #[inline(always)]
    fn front_value(&self) -> f64 {
        self.vals[self.head]
    }
}
398
/// Fixed-capacity circular buffer of recent values, read back by distance from the
/// newest entry.
#[derive(Clone, Debug)]
struct RingHistory {
    /// Storage; its length is the capacity (at least 1).
    values: Vec<f64>,
    /// Slot the next push writes to.
    head: usize,
    /// Number of filled slots, saturating at the capacity.
    count: usize,
}

impl RingHistory {
    #[inline(always)]
    fn new(cap: usize) -> Self {
        let slots = if cap == 0 { 1 } else { cap };
        Self {
            values: vec![0.0; slots],
            head: 0,
            count: 0,
        }
    }

    /// Forgets all stored values (storage is reused).
    #[inline(always)]
    fn clear(&mut self) {
        self.count = 0;
        self.head = 0;
    }

    /// Stores `value`, overwriting the oldest entry once full.
    #[inline(always)]
    fn push(&mut self, value: f64) {
        let cap = self.values.len();
        self.values[self.head] = value;
        self.head = (self.head + 1) % cap;
        self.count = (self.count + 1).min(cap);
    }

    /// Value stored `offset - 1` pushes ago (`offset == 1` is the newest), or `None`
    /// when `offset` is 0 or exceeds the number of stored values.
    #[inline(always)]
    fn get_from_end(&self, offset: usize) -> Option<f64> {
        if !(1..=self.count).contains(&offset) {
            return None;
        }
        let cap = self.values.len();
        Some(self.values[(self.head + cap - offset) % cap])
    }
}
444
/// Fixed-window population mean and standard deviation, maintained from running sums
/// of the values and of their squares.
#[derive(Clone, Debug)]
struct RollingStats {
    /// Window samples; its length is the window size (at least 1).
    ring: Vec<f64>,
    /// Slot the next sample goes into (the oldest sample once full).
    head: usize,
    /// Samples seen, saturating at the window size.
    count: usize,
    sum: f64,
    sum_sq: f64,
}

impl RollingStats {
    #[inline(always)]
    fn new(length: usize) -> Self {
        let window = length.max(1);
        Self {
            ring: vec![0.0; window],
            head: 0,
            count: 0,
            sum: 0.0,
            sum_sq: 0.0,
        }
    }

    /// Empties the window and zeroes the running sums.
    #[inline(always)]
    fn clear(&mut self) {
        self.sum_sq = 0.0;
        self.sum = 0.0;
        self.count = 0;
        self.head = 0;
    }

    /// Adds `value`, evicting the oldest sample once the window is full.
    ///
    /// Returns `(mean, stdev)` when the window is full and `None` before that. The
    /// variance `E[x^2] - E[x]^2` is clamped at zero to absorb rounding error.
    #[inline(always)]
    fn push(&mut self, value: f64) -> Option<(f64, f64)> {
        let window = self.ring.len();
        let slot = self.head;
        if self.count < window {
            self.count += 1;
            self.sum += value;
            self.sum_sq += value * value;
        } else {
            let evicted = self.ring[slot];
            self.sum += value - evicted;
            self.sum_sq += value * value - evicted * evicted;
        }
        self.ring[slot] = value;
        self.head = if slot + 1 == window { 0 } else { slot + 1 };

        if self.count < window {
            return None;
        }
        let n = window as f64;
        let mean = self.sum / n;
        let variance = (self.sum_sq / n - mean * mean).max(0.0);
        Some((mean, variance.sqrt()))
    }
}
504
505#[derive(Clone, Debug)]
506struct StatisticalTrailingStopState {
507    data_length: usize,
508    base_level_index: usize,
509    valid_run: usize,
510    max_high: MonoDeque,
511    min_low: MonoDeque,
512    close_history: RingHistory,
513    stats: RollingStats,
514    bias: u8,
515    delta: f64,
516    level: f64,
517    extreme: f64,
518    anchor: f64,
519    index: usize,
520}
521
522impl StatisticalTrailingStopState {
523    #[inline(always)]
524    fn new(data_length: usize, normalization_length: usize, base_level_index: usize) -> Self {
525        Self {
526            data_length,
527            base_level_index,
528            valid_run: 0,
529            max_high: MonoDeque::new(true, data_length + 2),
530            min_low: MonoDeque::new(false, data_length + 2),
531            close_history: RingHistory::new(data_length + 2),
532            stats: RollingStats::new(normalization_length),
533            bias: BIAS_BEARISH,
534            delta: f64::NAN,
535            level: f64::NAN,
536            extreme: f64::NAN,
537            anchor: f64::NAN,
538            index: 0,
539        }
540    }
541
542    #[inline(always)]
543    fn reset(&mut self) {
544        self.valid_run = 0;
545        self.max_high.clear();
546        self.min_low.clear();
547        self.close_history.clear();
548        self.stats.clear();
549        self.bias = BIAS_BEARISH;
550        self.delta = f64::NAN;
551        self.level = f64::NAN;
552        self.extreme = f64::NAN;
553        self.anchor = f64::NAN;
554        self.index = 0;
555    }
556
557    #[inline(always)]
558    fn update(
559        &mut self,
560        index: usize,
561        high: f64,
562        low: f64,
563        close: f64,
564    ) -> Option<(f64, f64, f64, f64)> {
565        if !high.is_finite() || !low.is_finite() || !close.is_finite() {
566            self.reset();
567            return None;
568        }
569
570        self.valid_run += 1;
571        self.max_high.push(index, high);
572        self.min_low.push(index, low);
573        let window_start = index + 1 - self.valid_run.min(self.data_length);
574        self.max_high.expire(window_start);
575        self.min_low.expire(window_start);
576        self.close_history.push(close);
577
578        if self.valid_run < self.data_length + 2 {
579            return None;
580        }
581
582        let previous_close = self.close_history.get_from_end(self.data_length + 2)?;
583        let highest = self.max_high.front_value();
584        let lowest = self.min_low.front_value();
585        let tr = (highest - lowest)
586            .max((highest - previous_close).abs())
587            .max((lowest - previous_close).abs());
588        let (mean, stdev) = self.stats.push(floor_positive(tr).ln())?;
589        let delta = (mean + self.base_level_index as f64 * stdev).exp();
590        self.delta = delta;
591
592        let current_hlc3 = hlc3(high, low, close);
593        if !self.level.is_finite() {
594            self.level = if self.bias == BIAS_BEARISH {
595                current_hlc3 + delta
596            } else {
597                (current_hlc3 - delta).max(0.0)
598            };
599        }
600
601        if self.bias == BIAS_BEARISH {
602            if self.extreme.is_finite() {
603                self.extreme = self.extreme.min(low);
604            }
605            self.level = self.level.min(current_hlc3 + delta);
606        } else {
607            if self.extreme.is_finite() {
608                self.extreme = self.extreme.max(high);
609            }
610            self.level = self.level.max((current_hlc3 - delta).max(0.0));
611        }
612
613        let triggered = (self.bias == BIAS_BEARISH && close >= self.level)
614            || (self.bias == BIAS_BULLISH && close <= self.level);
615        let mut changed = 0.0;
616
617        if triggered {
618            self.anchor = close;
619            self.index = index;
620            self.bias = if self.bias == BIAS_BEARISH {
621                BIAS_BULLISH
622            } else {
623                BIAS_BEARISH
624            };
625            self.level = if self.bias == BIAS_BEARISH {
626                current_hlc3 + delta
627            } else {
628                (current_hlc3 - delta).max(0.0)
629            };
630            self.extreme = if self.bias == BIAS_BEARISH { low } else { high };
631            changed = 1.0;
632        }
633
634        Some((self.level, self.anchor, self.bias as f64, changed))
635    }
636}
637
638#[derive(Clone, Debug)]
639pub struct StatisticalTrailingStopStream {
640    params: StatisticalTrailingStopParams,
641    state: StatisticalTrailingStopState,
642    index: usize,
643}
644
645impl StatisticalTrailingStopStream {
646    #[inline(always)]
647    pub fn try_new(
648        params: StatisticalTrailingStopParams,
649    ) -> Result<Self, StatisticalTrailingStopError> {
650        let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
651        let normalization_length = params
652            .normalization_length
653            .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
654        validate_periods(data_length, normalization_length, usize::MAX)?;
655        let base_level_index =
656            parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
657        Ok(Self {
658            state: StatisticalTrailingStopState::new(
659                data_length,
660                normalization_length,
661                base_level_index,
662            ),
663            params,
664            index: 0,
665        })
666    }
667
668    #[inline(always)]
669    pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
670        let output = self.state.update(self.index, high, low, close);
671        self.index = self.index.saturating_add(1);
672        output
673    }
674
675    #[inline(always)]
676    pub fn params(&self) -> &StatisticalTrailingStopParams {
677        &self.params
678    }
679}
680
681#[derive(Clone, Debug)]
682pub struct StatisticalTrailingStopBatchRange {
683    pub data_length: (usize, usize, usize),
684    pub normalization_length: (usize, usize, usize),
685    pub base_level: (String, String, usize),
686}
687
688impl Default for StatisticalTrailingStopBatchRange {
689    fn default() -> Self {
690        Self {
691            data_length: (DEFAULT_DATA_LENGTH, DEFAULT_DATA_LENGTH, 0),
692            normalization_length: (
693                DEFAULT_NORMALIZATION_LENGTH,
694                DEFAULT_NORMALIZATION_LENGTH,
695                0,
696            ),
697            base_level: (
698                DEFAULT_BASE_LEVEL.to_string(),
699                DEFAULT_BASE_LEVEL.to_string(),
700                0,
701            ),
702        }
703    }
704}
705
706#[derive(Clone, Debug, Default)]
707pub struct StatisticalTrailingStopBatchBuilder {
708    range: StatisticalTrailingStopBatchRange,
709    kernel: Kernel,
710}
711
712#[derive(Clone, Debug)]
713pub struct StatisticalTrailingStopBatchOutput {
714    pub level: Vec<f64>,
715    pub anchor: Vec<f64>,
716    pub bias: Vec<f64>,
717    pub changed: Vec<f64>,
718    pub combos: Vec<StatisticalTrailingStopParams>,
719    pub rows: usize,
720    pub cols: usize,
721}
722
723impl StatisticalTrailingStopBatchBuilder {
724    #[inline(always)]
725    pub fn new() -> Self {
726        Self::default()
727    }
728
729    #[inline(always)]
730    pub fn kernel(mut self, kernel: Kernel) -> Self {
731        self.kernel = kernel;
732        self
733    }
734
735    #[inline(always)]
736    pub fn data_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
737        self.range.data_length = (start, end, step);
738        self
739    }
740
741    #[inline(always)]
742    pub fn normalization_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
743        self.range.normalization_length = (start, end, step);
744        self
745    }
746
747    #[inline(always)]
748    pub fn base_level_range<S: Into<String>>(mut self, start: S, end: S, step: usize) -> Self {
749        self.range.base_level = (start.into(), end.into(), step);
750        self
751    }
752
753    #[inline(always)]
754    pub fn apply_slices(
755        self,
756        high: &[f64],
757        low: &[f64],
758        close: &[f64],
759    ) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
760        statistical_trailing_stop_batch_with_kernel(high, low, close, &self.range, self.kernel)
761    }
762
763    #[inline(always)]
764    pub fn apply(
765        self,
766        candles: &Candles,
767    ) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
768        self.apply_slices(&candles.high, &candles.low, &candles.close)
769    }
770}
771
772#[inline(always)]
773fn validate_periods(
774    data_length: usize,
775    normalization_length: usize,
776    data_len: usize,
777) -> Result<(), StatisticalTrailingStopError> {
778    if data_length < MIN_DATA_LENGTH
779        || normalization_length < MIN_NORMALIZATION_LENGTH
780        || data_length + normalization_length + 1 > data_len
781    {
782        return Err(StatisticalTrailingStopError::InvalidPeriod {
783            data_length,
784            normalization_length,
785            data_len,
786        });
787    }
788    Ok(())
789}
790
791#[inline(always)]
792fn parse_base_level(value: &str) -> Result<usize, StatisticalTrailingStopError> {
793    let normalized = value
794        .trim()
795        .to_ascii_lowercase()
796        .replace([' ', '_', '-'], "");
797    match normalized.as_str() {
798        "level0" | "0" => Ok(0),
799        "level1" | "1" => Ok(1),
800        "level2" | "2" => Ok(2),
801        "level3" | "3" => Ok(3),
802        _ => Err(StatisticalTrailingStopError::InvalidBaseLevel {
803            base_level: value.to_string(),
804        }),
805    }
806}
807
808#[inline(always)]
809fn canonical_base_level(value: &str) -> Result<&'static str, StatisticalTrailingStopError> {
810    match parse_base_level(value)? {
811        0 => Ok("level0"),
812        1 => Ok("level1"),
813        2 => Ok("level2"),
814        _ => Ok("level3"),
815    }
816}
817
818#[inline(always)]
819fn analyze_valid_segments(
820    high: &[f64],
821    low: &[f64],
822    close: &[f64],
823) -> Result<(usize, usize), StatisticalTrailingStopError> {
824    let mut first_valid = None;
825    let mut current = 0usize;
826    let mut max_run = 0usize;
827    for i in 0..close.len() {
828        let valid = high[i].is_finite() && low[i].is_finite() && close[i].is_finite();
829        if valid {
830            if first_valid.is_none() {
831                first_valid = Some(i);
832            }
833            current += 1;
834            max_run = max_run.max(current);
835        } else {
836            current = 0;
837        }
838    }
839    let first_valid = first_valid.ok_or(StatisticalTrailingStopError::AllValuesNaN)?;
840    Ok((first_valid, max_run))
841}
842
843fn prepare_input<'a>(
844    input: &'a StatisticalTrailingStopInput<'a>,
845    _kernel: Kernel,
846) -> Result<PreparedInput<'a>, StatisticalTrailingStopError> {
847    let (high, low, close) = input.as_hlc();
848    if high.is_empty() || low.is_empty() || close.is_empty() {
849        return Err(StatisticalTrailingStopError::EmptyInputData);
850    }
851    if high.len() != low.len() || high.len() != close.len() {
852        return Err(StatisticalTrailingStopError::DataLengthMismatch {
853            high_len: high.len(),
854            low_len: low.len(),
855            close_len: close.len(),
856        });
857    }
858
859    let data_length = input.get_data_length();
860    let normalization_length = input.get_normalization_length();
861    validate_periods(data_length, normalization_length, close.len())?;
862    let base_level_index = parse_base_level(input.get_base_level())?;
863    let (first_valid, max_run) = analyze_valid_segments(high, low, close)?;
864    let needed = data_length + normalization_length + 1;
865    if max_run < needed {
866        return Err(StatisticalTrailingStopError::NotEnoughValidData {
867            needed,
868            valid: max_run,
869        });
870    }
871
872    Ok(PreparedInput {
873        high,
874        low,
875        close,
876        data_length,
877        normalization_length,
878        base_level_index,
879        warmup: first_valid + needed - 1,
880    })
881}
882
883#[inline(always)]
884fn compute_row(
885    high: &[f64],
886    low: &[f64],
887    close: &[f64],
888    params: &StatisticalTrailingStopParams,
889    level_out: &mut [f64],
890    anchor_out: &mut [f64],
891    bias_out: &mut [f64],
892    changed_out: &mut [f64],
893) -> Result<(), StatisticalTrailingStopError> {
894    let len = close.len();
895    let expected = len;
896    if level_out.len() != expected
897        || anchor_out.len() != expected
898        || bias_out.len() != expected
899        || changed_out.len() != expected
900    {
901        return Err(StatisticalTrailingStopError::OutputLengthMismatch {
902            expected,
903            got: level_out
904                .len()
905                .max(anchor_out.len())
906                .max(bias_out.len())
907                .max(changed_out.len()),
908        });
909    }
910
911    let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
912    let normalization_length = params
913        .normalization_length
914        .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
915    validate_periods(data_length, normalization_length, len)?;
916    let base_level_index =
917        parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
918
919    let mut state =
920        StatisticalTrailingStopState::new(data_length, normalization_length, base_level_index);
921    for i in 0..len {
922        if let Some((level, anchor, bias, changed)) = state.update(i, high[i], low[i], close[i]) {
923            level_out[i] = level;
924            anchor_out[i] = anchor;
925            bias_out[i] = bias;
926            changed_out[i] = changed;
927        } else {
928            level_out[i] = f64::NAN;
929            anchor_out[i] = f64::NAN;
930            bias_out[i] = f64::NAN;
931            changed_out[i] = f64::NAN;
932        }
933    }
934    Ok(())
935}
936
937#[inline]
938pub fn statistical_trailing_stop(
939    input: &StatisticalTrailingStopInput,
940) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
941    statistical_trailing_stop_with_kernel(input, Kernel::Auto)
942}
943
944pub fn statistical_trailing_stop_with_kernel(
945    input: &StatisticalTrailingStopInput,
946    kernel: Kernel,
947) -> Result<StatisticalTrailingStopOutput, StatisticalTrailingStopError> {
948    let prepared = prepare_input(input, kernel)?;
949    let len = prepared.close.len();
950    let mut level = alloc_with_nan_prefix(len, prepared.warmup);
951    let mut anchor = alloc_with_nan_prefix(len, prepared.warmup);
952    let mut bias = alloc_with_nan_prefix(len, prepared.warmup);
953    let mut changed = alloc_with_nan_prefix(len, prepared.warmup);
954    compute_row(
955        prepared.high,
956        prepared.low,
957        prepared.close,
958        &StatisticalTrailingStopParams {
959            data_length: Some(prepared.data_length),
960            normalization_length: Some(prepared.normalization_length),
961            base_level: Some(
962                match prepared.base_level_index {
963                    0 => "level0",
964                    1 => "level1",
965                    2 => "level2",
966                    _ => "level3",
967                }
968                .to_string(),
969            ),
970        },
971        &mut level,
972        &mut anchor,
973        &mut bias,
974        &mut changed,
975    )?;
976    Ok(StatisticalTrailingStopOutput {
977        level,
978        anchor,
979        bias,
980        changed,
981    })
982}
983
984#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
985pub fn statistical_trailing_stop_into(
986    level_out: &mut [f64],
987    anchor_out: &mut [f64],
988    bias_out: &mut [f64],
989    changed_out: &mut [f64],
990    input: &StatisticalTrailingStopInput,
991) -> Result<(), StatisticalTrailingStopError> {
992    statistical_trailing_stop_into_slice(
993        level_out,
994        anchor_out,
995        bias_out,
996        changed_out,
997        input,
998        Kernel::Auto,
999    )
1000}
1001
1002pub fn statistical_trailing_stop_into_slice(
1003    level_out: &mut [f64],
1004    anchor_out: &mut [f64],
1005    bias_out: &mut [f64],
1006    changed_out: &mut [f64],
1007    input: &StatisticalTrailingStopInput,
1008    kernel: Kernel,
1009) -> Result<(), StatisticalTrailingStopError> {
1010    let prepared = prepare_input(input, kernel)?;
1011    compute_row(
1012        prepared.high,
1013        prepared.low,
1014        prepared.close,
1015        &StatisticalTrailingStopParams {
1016            data_length: Some(prepared.data_length),
1017            normalization_length: Some(prepared.normalization_length),
1018            base_level: Some(
1019                match prepared.base_level_index {
1020                    0 => "level0",
1021                    1 => "level1",
1022                    2 => "level2",
1023                    _ => "level3",
1024                }
1025                .to_string(),
1026            ),
1027        },
1028        level_out,
1029        anchor_out,
1030        bias_out,
1031        changed_out,
1032    )
1033}
1034
1035#[inline(always)]
1036pub fn expand_grid(
1037    sweep: &StatisticalTrailingStopBatchRange,
1038) -> Result<Vec<StatisticalTrailingStopParams>, StatisticalTrailingStopError> {
1039    fn axis_usize(
1040        axis: &'static str,
1041        (start, end, step): (usize, usize, usize),
1042    ) -> Result<Vec<usize>, StatisticalTrailingStopError> {
1043        if step == 0 || start == end {
1044            return Ok(vec![start]);
1045        }
1046        let mut out = Vec::new();
1047        if start < end {
1048            let mut value = start;
1049            let stride = step.max(1);
1050            while value <= end {
1051                out.push(value);
1052                match value.checked_add(stride) {
1053                    Some(next) => value = next,
1054                    None => break,
1055                }
1056            }
1057        } else {
1058            let mut value = start as isize;
1059            let stop = end as isize;
1060            let stride = step.max(1) as isize;
1061            while value >= stop {
1062                out.push(value as usize);
1063                value -= stride;
1064            }
1065        }
1066        if out.is_empty() {
1067            return Err(StatisticalTrailingStopError::InvalidRange {
1068                axis,
1069                start: start.to_string(),
1070                end: end.to_string(),
1071                step: step.to_string(),
1072            });
1073        }
1074        Ok(out)
1075    }
1076
1077    fn axis_base_level(
1078        axis: &'static str,
1079        (start, end, step): (String, String, usize),
1080    ) -> Result<Vec<String>, StatisticalTrailingStopError> {
1081        let start_idx = parse_base_level(&start)?;
1082        let end_idx = parse_base_level(&end)?;
1083        if step == 0 || start_idx == end_idx {
1084            return Ok(vec![canonical_base_level(&start)?.to_string()]);
1085        }
1086        let mut out = Vec::new();
1087        if start_idx < end_idx {
1088            let mut idx = start_idx;
1089            let stride = step.max(1);
1090            while idx <= end_idx {
1091                out.push(
1092                    match idx {
1093                        0 => "level0",
1094                        1 => "level1",
1095                        2 => "level2",
1096                        _ => "level3",
1097                    }
1098                    .to_string(),
1099                );
1100                match idx.checked_add(stride) {
1101                    Some(next) => idx = next,
1102                    None => break,
1103                }
1104            }
1105        } else {
1106            let mut idx = start_idx as isize;
1107            let stop = end_idx as isize;
1108            let stride = step.max(1) as isize;
1109            while idx >= stop {
1110                out.push(
1111                    match idx as usize {
1112                        0 => "level0",
1113                        1 => "level1",
1114                        2 => "level2",
1115                        _ => "level3",
1116                    }
1117                    .to_string(),
1118                );
1119                idx -= stride;
1120            }
1121        }
1122        if out.is_empty() {
1123            return Err(StatisticalTrailingStopError::InvalidRange {
1124                axis,
1125                start,
1126                end,
1127                step: step.to_string(),
1128            });
1129        }
1130        Ok(out)
1131    }
1132
1133    let data_lengths = axis_usize("data_length", sweep.data_length)?;
1134    let normalization_lengths = axis_usize("normalization_length", sweep.normalization_length)?;
1135    let base_levels = axis_base_level("base_level", sweep.base_level.clone())?;
1136
1137    let cap = data_lengths
1138        .len()
1139        .checked_mul(normalization_lengths.len())
1140        .and_then(|v| v.checked_mul(base_levels.len()))
1141        .ok_or(StatisticalTrailingStopError::InvalidRange {
1142            axis: "grid",
1143            start: "cap".to_string(),
1144            end: "overflow".to_string(),
1145            step: "mul".to_string(),
1146        })?;
1147
1148    let mut out = Vec::with_capacity(cap);
1149    for &data_length in &data_lengths {
1150        for &normalization_length in &normalization_lengths {
1151            for base_level in &base_levels {
1152                out.push(StatisticalTrailingStopParams {
1153                    data_length: Some(data_length),
1154                    normalization_length: Some(normalization_length),
1155                    base_level: Some(base_level.clone()),
1156                });
1157            }
1158        }
1159    }
1160    Ok(out)
1161}
1162
1163fn statistical_trailing_stop_batch_inner_into(
1164    high: &[f64],
1165    low: &[f64],
1166    close: &[f64],
1167    sweep: &StatisticalTrailingStopBatchRange,
1168    parallel: bool,
1169    level_out: &mut [f64],
1170    anchor_out: &mut [f64],
1171    bias_out: &mut [f64],
1172    changed_out: &mut [f64],
1173) -> Result<Vec<StatisticalTrailingStopParams>, StatisticalTrailingStopError> {
1174    if high.is_empty() || low.is_empty() || close.is_empty() {
1175        return Err(StatisticalTrailingStopError::EmptyInputData);
1176    }
1177    if high.len() != low.len() || high.len() != close.len() {
1178        return Err(StatisticalTrailingStopError::DataLengthMismatch {
1179            high_len: high.len(),
1180            low_len: low.len(),
1181            close_len: close.len(),
1182        });
1183    }
1184    let combos = expand_grid(sweep)?;
1185    let rows = combos.len();
1186    let cols = close.len();
1187    let expected =
1188        rows.checked_mul(cols)
1189            .ok_or(StatisticalTrailingStopError::OutputLengthMismatch {
1190                expected: usize::MAX,
1191                got: level_out.len(),
1192            })?;
1193    if level_out.len() != expected
1194        || anchor_out.len() != expected
1195        || bias_out.len() != expected
1196        || changed_out.len() != expected
1197    {
1198        return Err(StatisticalTrailingStopError::OutputLengthMismatch {
1199            expected,
1200            got: level_out
1201                .len()
1202                .max(anchor_out.len())
1203                .max(bias_out.len())
1204                .max(changed_out.len()),
1205        });
1206    }
1207    let (_, max_run) = analyze_valid_segments(high, low, close)?;
1208
1209    for params in &combos {
1210        let data_length = params.data_length.unwrap_or(DEFAULT_DATA_LENGTH);
1211        let normalization_length = params
1212            .normalization_length
1213            .unwrap_or(DEFAULT_NORMALIZATION_LENGTH);
1214        validate_periods(data_length, normalization_length, cols)?;
1215        let needed = data_length + normalization_length + 1;
1216        if max_run < needed {
1217            return Err(StatisticalTrailingStopError::NotEnoughValidData {
1218                needed,
1219                valid: max_run,
1220            });
1221        }
1222        let _ = parse_base_level(params.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))?;
1223    }
1224
1225    let do_row = |row: usize,
1226                  level_row: &mut [f64],
1227                  anchor_row: &mut [f64],
1228                  bias_row: &mut [f64],
1229                  changed_row: &mut [f64]| {
1230        compute_row(
1231            high,
1232            low,
1233            close,
1234            &combos[row],
1235            level_row,
1236            anchor_row,
1237            bias_row,
1238            changed_row,
1239        )
1240    };
1241
1242    if parallel {
1243        #[cfg(not(target_arch = "wasm32"))]
1244        {
1245            level_out
1246                .par_chunks_mut(cols)
1247                .zip(anchor_out.par_chunks_mut(cols))
1248                .zip(bias_out.par_chunks_mut(cols))
1249                .zip(changed_out.par_chunks_mut(cols))
1250                .enumerate()
1251                .try_for_each(
1252                    |(row, (((level_row, anchor_row), bias_row), changed_row))| {
1253                        do_row(row, level_row, anchor_row, bias_row, changed_row)
1254                    },
1255                )?;
1256        }
1257        #[cfg(target_arch = "wasm32")]
1258        {
1259            for (row, (((level_row, anchor_row), bias_row), changed_row)) in level_out
1260                .chunks_mut(cols)
1261                .zip(anchor_out.chunks_mut(cols))
1262                .zip(bias_out.chunks_mut(cols))
1263                .zip(changed_out.chunks_mut(cols))
1264                .enumerate()
1265            {
1266                do_row(row, level_row, anchor_row, bias_row, changed_row)?;
1267            }
1268        }
1269    } else {
1270        for (row, (((level_row, anchor_row), bias_row), changed_row)) in level_out
1271            .chunks_mut(cols)
1272            .zip(anchor_out.chunks_mut(cols))
1273            .zip(bias_out.chunks_mut(cols))
1274            .zip(changed_out.chunks_mut(cols))
1275            .enumerate()
1276        {
1277            do_row(row, level_row, anchor_row, bias_row, changed_row)?;
1278        }
1279    }
1280
1281    Ok(combos)
1282}
1283
1284pub fn statistical_trailing_stop_batch_with_kernel(
1285    high: &[f64],
1286    low: &[f64],
1287    close: &[f64],
1288    sweep: &StatisticalTrailingStopBatchRange,
1289    kernel: Kernel,
1290) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1291    match kernel {
1292        Kernel::Auto => {
1293            let _ = detect_best_batch_kernel();
1294        }
1295        k if !k.is_batch() => return Err(StatisticalTrailingStopError::InvalidKernelForBatch(k)),
1296        _ => {}
1297    }
1298    statistical_trailing_stop_batch_par_slice(high, low, close, sweep, Kernel::ScalarBatch)
1299}
1300
1301pub fn statistical_trailing_stop_batch_slice(
1302    high: &[f64],
1303    low: &[f64],
1304    close: &[f64],
1305    sweep: &StatisticalTrailingStopBatchRange,
1306    _kernel: Kernel,
1307) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1308    statistical_trailing_stop_batch_impl(high, low, close, sweep, false)
1309}
1310
1311pub fn statistical_trailing_stop_batch_par_slice(
1312    high: &[f64],
1313    low: &[f64],
1314    close: &[f64],
1315    sweep: &StatisticalTrailingStopBatchRange,
1316    _kernel: Kernel,
1317) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1318    statistical_trailing_stop_batch_impl(high, low, close, sweep, true)
1319}
1320
1321fn statistical_trailing_stop_batch_impl(
1322    high: &[f64],
1323    low: &[f64],
1324    close: &[f64],
1325    sweep: &StatisticalTrailingStopBatchRange,
1326    parallel: bool,
1327) -> Result<StatisticalTrailingStopBatchOutput, StatisticalTrailingStopError> {
1328    let rows = expand_grid(sweep)?.len();
1329    let cols = close.len();
1330
1331    let mut level_mu = make_uninit_matrix(rows, cols);
1332    let mut anchor_mu = make_uninit_matrix(rows, cols);
1333    let mut bias_mu = make_uninit_matrix(rows, cols);
1334    let mut changed_mu = make_uninit_matrix(rows, cols);
1335
1336    let mut level_guard = ManuallyDrop::new(level_mu);
1337    let mut anchor_guard = ManuallyDrop::new(anchor_mu);
1338    let mut bias_guard = ManuallyDrop::new(bias_mu);
1339    let mut changed_guard = ManuallyDrop::new(changed_mu);
1340
1341    let level_out: &mut [f64] = unsafe {
1342        core::slice::from_raw_parts_mut(level_guard.as_mut_ptr() as *mut f64, level_guard.len())
1343    };
1344    let anchor_out: &mut [f64] = unsafe {
1345        core::slice::from_raw_parts_mut(anchor_guard.as_mut_ptr() as *mut f64, anchor_guard.len())
1346    };
1347    let bias_out: &mut [f64] = unsafe {
1348        core::slice::from_raw_parts_mut(bias_guard.as_mut_ptr() as *mut f64, bias_guard.len())
1349    };
1350    let changed_out: &mut [f64] = unsafe {
1351        core::slice::from_raw_parts_mut(changed_guard.as_mut_ptr() as *mut f64, changed_guard.len())
1352    };
1353
1354    let combos = statistical_trailing_stop_batch_inner_into(
1355        high,
1356        low,
1357        close,
1358        sweep,
1359        parallel,
1360        level_out,
1361        anchor_out,
1362        bias_out,
1363        changed_out,
1364    )?;
1365
1366    let level = unsafe {
1367        Vec::from_raw_parts(
1368            level_guard.as_mut_ptr() as *mut f64,
1369            level_guard.len(),
1370            level_guard.capacity(),
1371        )
1372    };
1373    let anchor = unsafe {
1374        Vec::from_raw_parts(
1375            anchor_guard.as_mut_ptr() as *mut f64,
1376            anchor_guard.len(),
1377            anchor_guard.capacity(),
1378        )
1379    };
1380    let bias = unsafe {
1381        Vec::from_raw_parts(
1382            bias_guard.as_mut_ptr() as *mut f64,
1383            bias_guard.len(),
1384            bias_guard.capacity(),
1385        )
1386    };
1387    let changed = unsafe {
1388        Vec::from_raw_parts(
1389            changed_guard.as_mut_ptr() as *mut f64,
1390            changed_guard.len(),
1391            changed_guard.capacity(),
1392        )
1393    };
1394
1395    Ok(StatisticalTrailingStopBatchOutput {
1396        level,
1397        anchor,
1398        bias,
1399        changed,
1400        combos,
1401        rows,
1402        cols,
1403    })
1404}
1405
/// Python binding: computes the statistical trailing stop for one parameter set.
///
/// The GIL is released while the indicator runs. Returns the tuple
/// `(level, anchor, bias, changed)` as NumPy arrays. Raises `ValueError` on an
/// invalid kernel string or an indicator error.
#[cfg(feature = "python")]
#[pyfunction(name = "statistical_trailing_stop")]
#[pyo3(signature = (high, low, close, data_length=DEFAULT_DATA_LENGTH, normalization_length=DEFAULT_NORMALIZATION_LENGTH, base_level=DEFAULT_BASE_LEVEL, kernel=None))]
pub fn statistical_trailing_stop_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kernel = validate_kernel(kernel, false)?;

    let params = StatisticalTrailingStopParams {
        data_length: Some(data_length),
        normalization_length: Some(normalization_length),
        base_level: Some(base_level.to_string()),
    };
    let input =
        StatisticalTrailingStopInput::from_slices(high_slice, low_slice, close_slice, params);

    let StatisticalTrailingStopOutput {
        level,
        anchor,
        bias,
        changed,
    } = py
        .allow_threads(|| statistical_trailing_stop_with_kernel(&input, kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        level.into_pyarray(py),
        anchor.into_pyarray(py),
        bias.into_pyarray(py),
        changed.into_pyarray(py),
    ))
}
1448
/// Python binding: runs a parameter sweep and returns a dict.
///
/// The dict has these keys, in this order:
/// * `level`, `anchor`, `bias`, `changed`: `(rows, cols)` NumPy arrays;
/// * `data_lengths`, `normalization_lengths`: uint64 arrays, one entry per row;
/// * `base_levels`: list of names, one per row;
/// * `rows`, `cols`.
///
/// Output arrays are allocated uninitialized and filled directly while the GIL
/// is released. Rows are computed in parallel unless the kernel is `Scalar` or
/// `ScalarBatch`. Raises `ValueError` on invalid input or an indicator error.
#[cfg(feature = "python")]
#[pyfunction(name = "statistical_trailing_stop_batch")]
#[pyo3(signature = (high, low, close, data_length_range=(DEFAULT_DATA_LENGTH, DEFAULT_DATA_LENGTH, 0), normalization_length_range=(DEFAULT_NORMALIZATION_LENGTH, DEFAULT_NORMALIZATION_LENGTH, 0), base_level_range=(DEFAULT_BASE_LEVEL.to_string(), DEFAULT_BASE_LEVEL.to_string(), 0usize), kernel=None))]
pub fn statistical_trailing_stop_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    data_length_range: (usize, usize, usize),
    normalization_length_range: (usize, usize, usize),
    base_level_range: (String, String, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (high_slice, low_slice, close_slice) =
        (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let kernel = validate_kernel(kernel, true)?;
    let use_parallel = !matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch);
    let sweep = StatisticalTrailingStopBatchRange {
        data_length: data_length_range,
        normalization_length: normalization_length_range,
        base_level: base_level_range,
    };
    let to_value_error = |e: StatisticalTrailingStopError| PyValueError::new_err(e.to_string());

    let rows = expand_grid(&sweep).map_err(to_value_error)?.len();
    let cols = close_slice.len();
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err("rows*cols overflow in statistical_trailing_stop_batch")
    })?;

    // SAFETY: the arrays are fully written by the batch fill before being
    // exposed to Python. On error they are dropped without being returned.
    let level_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let anchor_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let bias_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let changed_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };

    // SAFETY: freshly created arrays that nothing else references yet.
    let level_out = unsafe { level_arr.as_slice_mut()? };
    let anchor_out = unsafe { anchor_arr.as_slice_mut()? };
    let bias_out = unsafe { bias_arr.as_slice_mut()? };
    let changed_out = unsafe { changed_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            statistical_trailing_stop_batch_inner_into(
                high_slice,
                low_slice,
                close_slice,
                &sweep,
                use_parallel,
                level_out,
                anchor_out,
                bias_out,
                changed_out,
            )
        })
        .map_err(to_value_error)?;

    let dict = PyDict::new(py);
    for (key, flat) in [
        ("level", &level_arr),
        ("anchor", &anchor_arr),
        ("bias", &bias_arr),
        ("changed", &changed_arr),
    ] {
        dict.set_item(key, flat.reshape((rows, cols))?)?;
    }

    let data_lengths: Vec<u64> = combos
        .iter()
        .map(|c| c.data_length.unwrap_or(DEFAULT_DATA_LENGTH) as u64)
        .collect();
    dict.set_item("data_lengths", data_lengths.into_pyarray(py))?;

    let normalization_lengths: Vec<u64> = combos
        .iter()
        .map(|c| c.normalization_length.unwrap_or(DEFAULT_NORMALIZATION_LENGTH) as u64)
        .collect();
    dict.set_item("normalization_lengths", normalization_lengths.into_pyarray(py))?;

    let base_levels = PyList::empty(py);
    for name in combos
        .iter()
        .map(|c| c.base_level.as_deref().unwrap_or(DEFAULT_BASE_LEVEL))
    {
        base_levels.append(name)?;
    }
    dict.set_item("base_levels", base_levels)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1539
/// Python wrapper around the incremental `StatisticalTrailingStopStream`.
#[cfg(feature = "python")]
#[pyclass(name = "StatisticalTrailingStopStream")]
pub struct StatisticalTrailingStopStreamPy {
    stream: StatisticalTrailingStopStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl StatisticalTrailingStopStreamPy {
    /// Creates a stream, raising `ValueError` if the parameters are rejected.
    #[new]
    #[pyo3(signature = (data_length=DEFAULT_DATA_LENGTH, normalization_length=DEFAULT_NORMALIZATION_LENGTH, base_level=DEFAULT_BASE_LEVEL))]
    fn new(data_length: usize, normalization_length: usize, base_level: &str) -> PyResult<Self> {
        let params = StatisticalTrailingStopParams {
            data_length: Some(data_length),
            normalization_length: Some(normalization_length),
            base_level: Some(base_level.to_string()),
        };
        StatisticalTrailingStopStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar.
    ///
    /// Returns `(level, anchor, bias, changed)` when the underlying stream
    /// produces a value for this bar, and `None` otherwise.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64, f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
1565
/// Batch sweep configuration deserialized from the JS `config` object passed to
/// the wasm `statistical_trailing_stop_batch` export.
///
/// Each field is a `(start, end, step)` triple and maps 1:1 onto
/// `StatisticalTrailingStopBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StatisticalTrailingStopBatchConfig {
    /// `(start, end, step)` for `data_length`.
    pub data_length_range: (usize, usize, usize),
    /// `(start, end, step)` for `normalization_length`.
    pub normalization_length_range: (usize, usize, usize),
    /// `(start_name, end_name, step)` for the base level (names such as `"level2"`).
    pub base_level_range: (String, String, usize),
}
1573
/// Serializable batch result returned to JS by the wasm
/// `statistical_trailing_stop_batch` export.
///
/// The four series are flattened row-major matrices of `rows * cols` values,
/// with one row per entry of `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct StatisticalTrailingStopBatchJsOutput {
    pub level: Vec<f64>,
    pub anchor: Vec<f64>,
    pub bias: Vec<f64>,
    pub changed: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<StatisticalTrailingStopParams>,
    /// Number of combinations (matrix rows).
    pub rows: usize,
    /// Input length (matrix columns).
    pub cols: usize,
}
1585
/// wasm export: computes the indicator for one parameter set and returns the
/// serialized output object. Errors are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
) -> Result<JsValue, JsValue> {
    let params = StatisticalTrailingStopParams {
        data_length: Some(data_length),
        normalization_length: Some(normalization_length),
        base_level: Some(base_level.to_string()),
    };
    let input = StatisticalTrailingStopInput::from_slices(high, low, close, params);
    let output = match statistical_trailing_stop_with_kernel(&input, Kernel::Auto) {
        Ok(computed) => computed,
        Err(err) => return Err(JsValue::from_str(&err.to_string())),
    };
    serde_wasm_bindgen::to_value(&output)
        .map_err(|err| JsValue::from_str(&format!("Serialization error: {}", err)))
}
1611
/// wasm export: allocates an uninitialized buffer with capacity for `len` f64s
/// and returns its pointer. Release it with `statistical_trailing_stop_free`
/// using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_alloc(len: usize) -> *mut f64 {
    // Ownership of the allocation passes to the caller; it is never dropped here.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1620
/// wasm export: releases a buffer obtained from `statistical_trailing_stop_alloc`.
///
/// `len` must be the value passed to `statistical_trailing_stop_alloc`. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: caller contract is that `ptr` came from
    // `statistical_trailing_stop_alloc(len)` and is freed only once. That function
    // allocates via `Vec::with_capacity(len)`, which yields capacity exactly
    // `len`. Length 0 is used so the reconstructed Vec never claims that
    // possibly-uninitialized elements are initialized; only the allocation is
    // released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1630
/// wasm export: computes the indicator from raw input pointers into four raw
/// output buffers, each holding `len` f64 values.
///
/// Any overlap between an input buffer and an output buffer, or between two
/// output buffers, is detected by byte-range intersection (not just identical
/// start pointers). In that case results are computed into temporaries and then
/// copied out, so no overlapping `&`/`&mut` slices are ever created. Otherwise
/// results are written in place.
///
/// # Errors
/// Returns a JS string for null pointers or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    level_ptr: *mut f64,
    anchor_ptr: *mut f64,
    bias_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    data_length: usize,
    normalization_length: usize,
    base_level: &str,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || level_ptr.is_null()
        || anchor_ptr.is_null()
        || bias_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Every buffer spans `len` f64s. Two buffers conflict when their byte ranges
    // intersect. Zero-length buffers never conflict.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps = |a: usize, b: usize| {
        span != 0 && a < b.saturating_add(span) && b < a.saturating_add(span)
    };
    let inputs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let outputs = [
        level_ptr as usize,
        anchor_ptr as usize,
        bias_ptr as usize,
        changed_ptr as usize,
    ];
    let aliased = inputs
        .iter()
        .any(|&inp| outputs.iter().any(|&out| overlaps(inp, out)))
        || outputs
            .iter()
            .enumerate()
            .any(|(k, &a)| outputs[k + 1..].iter().any(|&b| overlaps(a, b)));

    // SAFETY: the caller guarantees each pointer addresses `len` valid f64s.
    // Mutable output slices are only created while no overlapping slice is in
    // use: either none overlap (`!aliased`), or the inputs are no longer used
    // when the copies are made.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = StatisticalTrailingStopInput::from_slices(
            high,
            low,
            close,
            StatisticalTrailingStopParams {
                data_length: Some(data_length),
                normalization_length: Some(normalization_length),
                base_level: Some(base_level.to_string()),
            },
        );

        if aliased {
            let output = statistical_trailing_stop_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(level_ptr, len).copy_from_slice(&output.level);
            std::slice::from_raw_parts_mut(anchor_ptr, len).copy_from_slice(&output.anchor);
            std::slice::from_raw_parts_mut(bias_ptr, len).copy_from_slice(&output.bias);
            std::slice::from_raw_parts_mut(changed_ptr, len).copy_from_slice(&output.changed);
        } else {
            let level_out = std::slice::from_raw_parts_mut(level_ptr, len);
            let anchor_out = std::slice::from_raw_parts_mut(anchor_ptr, len);
            let bias_out = std::slice::from_raw_parts_mut(bias_ptr, len);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, len);
            statistical_trailing_stop_into_slice(
                level_out,
                anchor_out,
                bias_out,
                changed_out,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1720
/// wasm export (`statistical_trailing_stop_batch`): deserializes a
/// `StatisticalTrailingStopBatchConfig` from `config`, runs the sweep with
/// `Kernel::Auto`, and returns a serialized `StatisticalTrailingStopBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = statistical_trailing_stop_batch)]
pub fn statistical_trailing_stop_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let StatisticalTrailingStopBatchConfig {
        data_length_range,
        normalization_length_range,
        base_level_range,
    } = serde_wasm_bindgen::from_value::<StatisticalTrailingStopBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = StatisticalTrailingStopBatchRange {
        data_length: data_length_range,
        normalization_length: normalization_length_range,
        base_level: base_level_range,
    };

    let StatisticalTrailingStopBatchOutput {
        level,
        anchor,
        bias,
        changed,
        combos,
        rows,
        cols,
    } = statistical_trailing_stop_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let js_output = StatisticalTrailingStopBatchJsOutput {
        level,
        anchor,
        bias,
        changed,
        combos,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&js_output)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1751
/// wasm export: runs a batch sweep from raw input pointers (`len` f64s each)
/// into four raw output buffers of `rows * len` f64s each. Rows are computed
/// sequentially.
///
/// Returns the number of rows (parameter combinations).
///
/// If any output buffer overlaps an input buffer or another output buffer, the
/// batch is computed into temporaries and copied out. This avoids creating
/// overlapping `&`/`&mut` slices, which the previous version did unconditionally.
///
/// # Errors
/// Returns a JS string for null pointers, an invalid sweep, a `rows * len`
/// overflow, or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn statistical_trailing_stop_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    level_ptr: *mut f64,
    anchor_ptr: *mut f64,
    bias_ptr: *mut f64,
    changed_ptr: *mut f64,
    len: usize,
    data_length_start: usize,
    data_length_end: usize,
    data_length_step: usize,
    normalization_length_start: usize,
    normalization_length_end: usize,
    normalization_length_step: usize,
    base_level_start: &str,
    base_level_end: &str,
    base_level_step: usize,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || level_ptr.is_null()
        || anchor_ptr.is_null()
        || bias_ptr.is_null()
        || changed_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    let sweep = StatisticalTrailingStopBatchRange {
        data_length: (data_length_start, data_length_end, data_length_step),
        normalization_length: (
            normalization_length_start,
            normalization_length_end,
            normalization_length_step,
        ),
        base_level: (
            base_level_start.to_string(),
            base_level_end.to_string(),
            base_level_step,
        ),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*len overflow"))?;

    // Byte-range overlap test. Inputs span `len` f64s; outputs span `total` f64s.
    let f64_size = std::mem::size_of::<f64>();
    let in_span = len.saturating_mul(f64_size);
    let out_span = total.saturating_mul(f64_size);
    let overlaps = |a: usize, a_span: usize, b: usize, b_span: usize| {
        a_span != 0 && b_span != 0 && a < b.saturating_add(b_span) && b < a.saturating_add(a_span)
    };
    let inputs = [high_ptr as usize, low_ptr as usize, close_ptr as usize];
    let outputs = [
        level_ptr as usize,
        anchor_ptr as usize,
        bias_ptr as usize,
        changed_ptr as usize,
    ];
    let aliased = inputs
        .iter()
        .any(|&inp| outputs.iter().any(|&out| overlaps(inp, in_span, out, out_span)))
        || outputs.iter().enumerate().any(|(k, &a)| {
            outputs[k + 1..]
                .iter()
                .any(|&b| overlaps(a, out_span, b, out_span))
        });

    // SAFETY: the caller guarantees the inputs address `len` valid f64s and the
    // outputs address `total` valid f64s. Mutable output slices are created only
    // when nothing overlaps them, or after the inputs are no longer used.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);

        if aliased {
            // Same sequential path as below, but into owned buffers of `total` values.
            let output = statistical_trailing_stop_batch_slice(
                high,
                low,
                close,
                &sweep,
                Kernel::ScalarBatch,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(level_ptr, total).copy_from_slice(&output.level);
            std::slice::from_raw_parts_mut(anchor_ptr, total).copy_from_slice(&output.anchor);
            std::slice::from_raw_parts_mut(bias_ptr, total).copy_from_slice(&output.bias);
            std::slice::from_raw_parts_mut(changed_ptr, total).copy_from_slice(&output.changed);
        } else {
            let level_out = std::slice::from_raw_parts_mut(level_ptr, total);
            let anchor_out = std::slice::from_raw_parts_mut(anchor_ptr, total);
            let bias_out = std::slice::from_raw_parts_mut(bias_ptr, total);
            let changed_out = std::slice::from_raw_parts_mut(changed_ptr, total);
            statistical_trailing_stop_batch_inner_into(
                high,
                low,
                close,
                &sweep,
                false,
                level_out,
                anchor_out,
                bias_out,
                changed_out,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1827
#[cfg(test)]
mod tests {
    // Unit tests for the statistical trailing stop. They cover:
    // - agreement between the allocating, in-place, streaming and batch entry points;
    // - reversal detection, NaN-gap handling and the error paths.

    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;

    /// Absolute tolerance used when comparing two non-NaN outputs.
    const TOLERANCE: f64 = 1e-12;

    /// Builds a steadily rising synthetic series of `size` bars.
    ///
    /// The close rises by 0.9 per bar. High and low sit about one point above
    /// and below the close, with a small periodic jitter so the bar ranges vary.
    fn trend_data(size: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(size);
        let mut low = Vec::with_capacity(size);
        let mut close = Vec::with_capacity(size);
        for i in 0..size {
            let base = 100.0 + i as f64 * 0.9;
            high.push(base + 1.0 + (i % 3) as f64 * 0.1);
            low.push(base - 1.0 - (i % 2) as f64 * 0.1);
            close.push(base);
        }
        (high, low, close)
    }

    /// Builds a 30-bar series that drifts down, rallies sharply, then sells off.
    ///
    /// High and low are the close plus or minus 0.8.
    fn reversal_data() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let close = vec![
            100.0, 99.5, 99.0, 98.5, 98.0, 97.5, 97.0, 96.5, 96.0, 95.5, 97.0, 99.0, 101.0, 104.0,
            108.0, 111.0, 113.0, 112.0, 111.0, 109.0, 107.0, 105.0, 103.0, 101.0, 99.0, 97.0, 95.0,
            94.0, 93.5, 93.0,
        ];
        let high = close.iter().map(|v| v + 0.8).collect::<Vec<_>>();
        let low = close.iter().map(|v| v - 0.8).collect::<Vec<_>>();
        (high, low, close)
    }

    /// NaN-aware scalar comparison.
    ///
    /// Two NaNs match, and a NaN never matches a non-NaN value. Two non-NaN
    /// values match when `|x - y| <= TOLERANCE`.
    fn values_eq_nan(x: f64, y: f64) -> bool {
        match (x.is_nan(), y.is_nan()) {
            (true, true) => true,
            (false, false) => (x - y).abs() <= TOLERANCE,
            _ => false,
        }
    }

    /// Asserts that `actual` matches `expected` element-wise under [`values_eq_nan`].
    ///
    /// # Panics
    /// Panics when the lengths differ or an element does not match. The message
    /// names the series (`name`) and gives the first mismatching index and both
    /// values.
    #[track_caller]
    fn assert_arrays_eq_nan(name: &str, actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "{}: length mismatch", name);
        if let Some(idx) = actual
            .iter()
            .zip(expected)
            .position(|(&a, &e)| !values_eq_nan(a, e))
        {
            panic!(
                "{}: mismatch at index {}: got {}, expected {}",
                name, idx, actual[idx], expected[idx]
            );
        }
    }

    #[test]
    fn statistical_trailing_stop_into_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(160);
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams::default(),
        );
        let single = statistical_trailing_stop(&input)?;

        // Pre-fill with 0.0 so warm-up NaNs must actually be written by the call.
        let mut level = vec![0.0; close.len()];
        let mut anchor = vec![0.0; close.len()];
        let mut bias = vec![0.0; close.len()];
        let mut changed = vec![0.0; close.len()];
        statistical_trailing_stop_into_slice(
            &mut level,
            &mut anchor,
            &mut bias,
            &mut changed,
            &input,
            Kernel::Auto,
        )?;

        assert_arrays_eq_nan("level", &level, &single.level);
        assert_arrays_eq_nan("anchor", &anchor, &single.anchor);
        assert_arrays_eq_nan("bias", &bias, &single.bias);
        assert_arrays_eq_nan("changed", &changed, &single.changed);
        Ok(())
    }

    #[test]
    fn statistical_trailing_stop_stream_matches_batch() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let params = StatisticalTrailingStopParams::default();
        let input = StatisticalTrailingStopInput::from_slices(&high, &low, &close, params.clone());
        let batch = statistical_trailing_stop(&input)?;

        let mut stream = StatisticalTrailingStopStream::try_new(params)?;
        let mut level = Vec::with_capacity(close.len());
        let mut anchor = Vec::with_capacity(close.len());
        let mut bias = Vec::with_capacity(close.len());
        let mut changed = Vec::with_capacity(close.len());

        for ((&h, &l), &c) in high.iter().zip(&low).zip(&close) {
            // The stream yields `None` during warm-up. Record NaN for that bar
            // to mirror the batch output's warm-up prefix.
            let (lvl, anc, bs, ch) = stream
                .update(h, l, c)
                .unwrap_or((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
            level.push(lvl);
            anchor.push(anc);
            bias.push(bs);
            changed.push(ch);
        }

        assert_arrays_eq_nan("level", &level, &batch.level);
        assert_arrays_eq_nan("anchor", &anchor, &batch.anchor);
        assert_arrays_eq_nan("bias", &bias, &batch.bias);
        assert_arrays_eq_nan("changed", &changed, &batch.changed);
        Ok(())
    }

    #[test]
    fn statistical_trailing_stop_reversal_behavior() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = reversal_data();
        let output = statistical_trailing_stop(&StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams {
                data_length: Some(3),
                normalization_length: Some(10),
                base_level: Some("level0".to_string()),
            },
        ))?;

        let first = output
            .changed
            .iter()
            .position(|&v| v == 1.0)
            .expect("expected at least one bias change in the reversal series");
        assert!(
            output.anchor[first].is_finite(),
            "anchor at first change (index {}) should be finite",
            first
        );
        assert_eq!(output.bias[first], 1.0);
        Ok(())
    }

    #[test]
    fn statistical_trailing_stop_nan_gap_restarts() -> Result<(), Box<dyn StdError>> {
        // Index of the single all-NaN bar injected into the trend series.
        const GAP: usize = 120;

        // Runs default parameters on a `size`-bar trend with bar `GAP` set to NaN.
        let run_with_gap = |size: usize| -> Result<_, Box<dyn StdError>> {
            let (mut high, mut low, mut close) = trend_data(size);
            high[GAP] = f64::NAN;
            low[GAP] = f64::NAN;
            close[GAP] = f64::NAN;
            Ok(statistical_trailing_stop(
                &StatisticalTrailingStopInput::from_slices(
                    &high,
                    &low,
                    &close,
                    StatisticalTrailingStopParams::default(),
                ),
            )?)
        };

        // From the gap onward, output must be NaN while the indicator re-warms.
        let short = run_with_gap(170)?;
        let restart_end =
            (GAP + DEFAULT_DATA_LENGTH + DEFAULT_NORMALIZATION_LENGTH + 1).min(short.level.len());
        for i in GAP..restart_end {
            assert!(short.level[i].is_nan(), "level[{}] should be NaN after the gap", i);
            assert!(short.bias[i].is_nan(), "bias[{}] should be NaN after the gap", i);
        }

        // With enough bars after the gap, the indicator must actually restart.
        let long = run_with_gap(400)?;
        assert!(
            long.level[GAP + 1..].iter().any(|v| v.is_finite()),
            "level never resumed after the NaN gap at index {}",
            GAP
        );
        Ok(())
    }

    #[test]
    fn statistical_trailing_stop_batch_matches_single() -> Result<(), Box<dyn StdError>> {
        let (high, low, close) = trend_data(170);
        let sweep = StatisticalTrailingStopBatchRange {
            data_length: (9, 10, 1),
            normalization_length: (10, 11, 1),
            base_level: ("level1".to_string(), "level2".to_string(), 1),
        };
        let batch = statistical_trailing_stop_batch_with_kernel(
            &high,
            &low,
            &close,
            &sweep,
            Kernel::ScalarBatch,
        )?;

        // 2 data lengths x 2 normalization lengths x 2 base levels.
        assert_eq!(batch.rows, 8);
        assert_eq!(batch.cols, close.len());
        for row in 0..batch.rows {
            let combo = &batch.combos[row];
            let single = statistical_trailing_stop(&StatisticalTrailingStopInput::from_slices(
                &high,
                &low,
                &close,
                combo.clone(),
            ))?;
            // Batch outputs are row-major: row `row` occupies cols [start, end).
            let start = row * batch.cols;
            let end = start + batch.cols;
            assert_arrays_eq_nan(
                &format!("level (row {})", row),
                &batch.level[start..end],
                &single.level,
            );
            assert_arrays_eq_nan(
                &format!("anchor (row {})", row),
                &batch.anchor[start..end],
                &single.anchor,
            );
            assert_arrays_eq_nan(
                &format!("bias (row {})", row),
                &batch.bias[start..end],
                &single.bias,
            );
            assert_arrays_eq_nan(
                &format!("changed (row {})", row),
                &batch.changed[start..end],
                &single.changed,
            );
        }
        Ok(())
    }

    #[test]
    fn statistical_trailing_stop_invalid_base_level_errors() {
        let (high, low, close) = trend_data(160);
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams {
                data_length: Some(10),
                normalization_length: Some(100),
                base_level: Some("bad".to_string()),
            },
        );
        assert!(matches!(
            statistical_trailing_stop(&input),
            Err(StatisticalTrailingStopError::InvalidBaseLevel { .. })
        ));
    }

    #[test]
    fn statistical_trailing_stop_all_nan_errors() {
        let high = vec![f64::NAN; 200];
        let low = vec![f64::NAN; 200];
        let close = vec![f64::NAN; 200];
        let input = StatisticalTrailingStopInput::from_slices(
            &high,
            &low,
            &close,
            StatisticalTrailingStopParams::default(),
        );
        assert!(matches!(
            statistical_trailing_stop(&input),
            Err(StatisticalTrailingStopError::AllValuesNaN)
        ));
    }

    #[test]
    fn statistical_trailing_stop_default_candles_smoke() -> Result<(), Box<dyn StdError>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let output = statistical_trailing_stop(
            &StatisticalTrailingStopInput::with_default_candles(&candles),
        )?;
        assert_eq!(output.level.len(), candles.close.len());
        assert_eq!(output.anchor.len(), candles.close.len());
        assert_eq!(output.bias.len(), candles.close.len());
        assert_eq!(output.changed.len(), candles.close.len());
        Ok(())
    }
}