Skip to main content

vector_ta/indicators/
random_walk_index.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::collections::VecDeque;
25use std::mem::ManuallyDrop;
26use thiserror::Error;
27
/// Lookback length used whenever `RandomWalkIndexParams::length` is `None`.
const DEFAULT_LENGTH: usize = 14;
29
/// Source of the high/low/close series consumed by the random walk index.
///
/// Equal lengths across the three series are checked when the indicator is
/// computed, not at construction time.
#[derive(Debug, Clone)]
pub enum RandomWalkIndexData<'a> {
    /// Borrowed candle container; the `"high"`, `"low"` and `"close"`
    /// sources are looked up by name.
    Candles {
        candles: &'a Candles,
    },
    /// Three caller-provided series, indexed by bar.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
41
/// Output of a single random walk index run: one value per input bar per line.
///
/// Bars before the warm-up index (`first_valid + length - 1`) are `NaN`;
/// later bars may also be `NaN` when `atr * sqrt(length)` is zero or
/// non-finite.
#[derive(Debug, Clone)]
pub struct RandomWalkIndexOutput {
    /// `(high - low[t - length]) / (atr * sqrt(length))`.
    pub high: Vec<f64>,
    /// `(high[t - length] - low) / (atr * sqrt(length))`.
    pub low: Vec<f64>,
}
47
/// Tunable parameters for the random walk index.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RandomWalkIndexParams {
    /// Lookback length; `None` falls back to `DEFAULT_LENGTH`. A resolved
    /// value of zero is rejected by every entry point.
    pub length: Option<usize>,
}
56
57impl Default for RandomWalkIndexParams {
58    fn default() -> Self {
59        Self {
60            length: Some(DEFAULT_LENGTH),
61        }
62    }
63}
64
/// Input bundle (data source plus parameters) for the random walk index.
#[derive(Debug, Clone)]
pub struct RandomWalkIndexInput<'a> {
    /// Where the high/low/close series come from.
    pub data: RandomWalkIndexData<'a>,
    /// Indicator parameters (lookback length).
    pub params: RandomWalkIndexParams,
}
70
71impl<'a> RandomWalkIndexInput<'a> {
72    #[inline]
73    pub fn from_candles(candles: &'a Candles, params: RandomWalkIndexParams) -> Self {
74        Self {
75            data: RandomWalkIndexData::Candles { candles },
76            params,
77        }
78    }
79
80    #[inline]
81    pub fn from_slices(
82        high: &'a [f64],
83        low: &'a [f64],
84        close: &'a [f64],
85        params: RandomWalkIndexParams,
86    ) -> Self {
87        Self {
88            data: RandomWalkIndexData::Slices { high, low, close },
89            params,
90        }
91    }
92
93    #[inline]
94    pub fn with_default_candles(candles: &'a Candles) -> Self {
95        Self::from_candles(candles, RandomWalkIndexParams::default())
96    }
97
98    #[inline]
99    pub fn get_length(&self) -> usize {
100        self.params.length.unwrap_or(DEFAULT_LENGTH)
101    }
102}
103
/// Fluent configuration for one-shot or streaming random walk index runs.
#[derive(Copy, Clone, Debug)]
pub struct RandomWalkIndexBuilder {
    /// Lookback length; `None` defers to `DEFAULT_LENGTH`.
    length: Option<usize>,
    /// Kernel forwarded to `random_walk_index_with_kernel`.
    kernel: Kernel,
}
109
110impl Default for RandomWalkIndexBuilder {
111    fn default() -> Self {
112        Self {
113            length: None,
114            kernel: Kernel::Auto,
115        }
116    }
117}
118
119impl RandomWalkIndexBuilder {
120    #[inline(always)]
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    #[inline(always)]
126    pub fn length(mut self, value: usize) -> Self {
127        self.length = Some(value);
128        self
129    }
130
131    #[inline(always)]
132    pub fn kernel(mut self, value: Kernel) -> Self {
133        self.kernel = value;
134        self
135    }
136
137    #[inline(always)]
138    pub fn apply(self, candles: &Candles) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
139        let input = RandomWalkIndexInput::from_candles(
140            candles,
141            RandomWalkIndexParams {
142                length: self.length,
143            },
144        );
145        random_walk_index_with_kernel(&input, self.kernel)
146    }
147
148    #[inline(always)]
149    pub fn apply_slices(
150        self,
151        high: &[f64],
152        low: &[f64],
153        close: &[f64],
154    ) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
155        let input = RandomWalkIndexInput::from_slices(
156            high,
157            low,
158            close,
159            RandomWalkIndexParams {
160                length: self.length,
161            },
162        );
163        random_walk_index_with_kernel(&input, self.kernel)
164    }
165
166    #[inline(always)]
167    pub fn into_stream(self) -> Result<RandomWalkIndexStream, RandomWalkIndexError> {
168        RandomWalkIndexStream::try_new(RandomWalkIndexParams {
169            length: self.length,
170        })
171    }
172}
173
/// Errors produced by the random walk index entry points.
#[derive(Debug, Error)]
pub enum RandomWalkIndexError {
    /// At least one of the high/low/close inputs is empty.
    #[error("random_walk_index: Input data slice is empty.")]
    EmptyInputData,
    /// No bar has a finite high, low and close at the same time.
    #[error("random_walk_index: All values are NaN.")]
    AllValuesNaN,
    /// The high/low/close inputs differ in length.
    #[error("random_walk_index: Inconsistent slice lengths: high={high_len}, low={low_len}, close={close_len}")]
    InconsistentSliceLengths {
        high_len: usize,
        low_len: usize,
        close_len: usize,
    },
    /// `length` is zero or exceeds the input length (the stream constructor
    /// reports `data_len: 0`).
    #[error("random_walk_index: Invalid length: length={length}, data length={data_len}")]
    InvalidLength { length: usize, data_len: usize },
    /// Fewer than `needed` bars remain from the first fully-finite bar onward.
    #[error("random_walk_index: Not enough valid data: needed={needed}, valid={valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong number of elements.
    #[error("random_walk_index: Output length mismatch: expected={expected}, got={got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A length sweep could not be expanded. NOTE: the batch path also reuses
    /// this variant for a `rows * cols` overflow, with `start`/`end` holding
    /// rows/cols and `step` set to the literal `"rows*cols"`.
    #[error("random_walk_index: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("random_walk_index: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
201
202#[inline(always)]
203fn extract_hlc<'a>(
204    input: &'a RandomWalkIndexInput<'a>,
205) -> Result<(&'a [f64], &'a [f64], &'a [f64]), RandomWalkIndexError> {
206    let (high, low, close) = match &input.data {
207        RandomWalkIndexData::Candles { candles } => (
208            source_type(candles, "high"),
209            source_type(candles, "low"),
210            source_type(candles, "close"),
211        ),
212        RandomWalkIndexData::Slices { high, low, close } => (*high, *low, *close),
213    };
214
215    if high.is_empty() || low.is_empty() || close.is_empty() {
216        return Err(RandomWalkIndexError::EmptyInputData);
217    }
218    if high.len() != low.len() || high.len() != close.len() {
219        return Err(RandomWalkIndexError::InconsistentSliceLengths {
220            high_len: high.len(),
221            low_len: low.len(),
222            close_len: close.len(),
223        });
224    }
225    Ok((high, low, close))
226}
227
/// Index of the first bar whose high, low and close are all finite, or
/// `None` when no such bar exists.
#[inline(always)]
fn first_valid_hlc(high: &[f64], low: &[f64], close: &[f64]) -> Option<usize> {
    for idx in 0..high.len() {
        if high[idx].is_finite() && low[idx].is_finite() && close[idx].is_finite() {
            return Some(idx);
        }
    }
    None
}
232
233#[inline(always)]
234fn prepare<'a>(
235    input: &'a RandomWalkIndexInput<'a>,
236    kernel: Kernel,
237) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), RandomWalkIndexError> {
238    let (high, low, close) = extract_hlc(input)?;
239    let len = close.len();
240    let length = input.get_length();
241    if length == 0 || length > len {
242        return Err(RandomWalkIndexError::InvalidLength {
243            length,
244            data_len: len,
245        });
246    }
247    let first = first_valid_hlc(high, low, close).ok_or(RandomWalkIndexError::AllValuesNaN)?;
248    let valid = len.saturating_sub(first);
249    if valid < length {
250        return Err(RandomWalkIndexError::NotEnoughValidData {
251            needed: length,
252            valid,
253        });
254    }
255    Ok((high, low, close, length, first, kernel.to_non_batch()))
256}
257
/// Value `offset` bars before `idx` in `src`, or `0.0` when that bar is
/// before the start of the series or is non-finite (like Pine's `nz`).
#[inline(always)]
fn nz_history(src: &[f64], idx: usize, offset: usize) -> f64 {
    match idx.checked_sub(offset) {
        Some(lagged) if src[lagged].is_finite() => src[lagged],
        _ => 0.0,
    }
}
271
272#[inline(always)]
273fn compute_random_walk_index_into(
274    high: &[f64],
275    low: &[f64],
276    close: &[f64],
277    length: usize,
278    first: usize,
279    out_high: &mut [f64],
280    out_low: &mut [f64],
281) {
282    let n = close.len();
283    let warm = first + length - 1;
284    let sqrt_length = (length as f64).sqrt();
285    let alpha = 1.0 / length as f64;
286
287    let mut prev_close = close[first];
288    let mut sum_tr = high[first] - low[first];
289    let mut atr = f64::NAN;
290
291    if length == 1 {
292        atr = sum_tr;
293        let denom = atr * sqrt_length;
294        if denom.is_finite() && denom != 0.0 {
295            out_high[first] = (high[first] - nz_history(low, first, length)) / denom;
296            out_low[first] = (nz_history(high, first, length) - low[first]) / denom;
297        }
298    }
299
300    let mut i = first + 1;
301    while i < n {
302        let tr = (high[i] - low[i])
303            .max((high[i] - prev_close).abs())
304            .max((low[i] - prev_close).abs());
305
306        if i <= warm {
307            sum_tr += tr;
308            if i == warm {
309                atr = sum_tr / length as f64;
310            }
311        } else {
312            atr = alpha.mul_add(tr - atr, atr);
313        }
314
315        if i >= warm {
316            let denom = atr * sqrt_length;
317            if denom.is_finite() && denom != 0.0 {
318                out_high[i] = (high[i] - nz_history(low, i, length)) / denom;
319                out_low[i] = (nz_history(high, i, length) - low[i]) / denom;
320            } else {
321                out_high[i] = f64::NAN;
322                out_low[i] = f64::NAN;
323            }
324        }
325
326        prev_close = close[i];
327        i += 1;
328    }
329}
330
331#[inline]
332pub fn random_walk_index(
333    input: &RandomWalkIndexInput,
334) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
335    random_walk_index_with_kernel(input, Kernel::Auto)
336}
337
338pub fn random_walk_index_with_kernel(
339    input: &RandomWalkIndexInput,
340    kernel: Kernel,
341) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
342    let (high, low, close, length, first, _) = prepare(input, kernel)?;
343    let warm = first + length - 1;
344    let mut out_high = alloc_with_nan_prefix(close.len(), warm);
345    let mut out_low = alloc_with_nan_prefix(close.len(), warm);
346    compute_random_walk_index_into(high, low, close, length, first, &mut out_high, &mut out_low);
347    Ok(RandomWalkIndexOutput {
348        high: out_high,
349        low: out_low,
350    })
351}
352
353#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
354pub fn random_walk_index_into(
355    out_high: &mut [f64],
356    out_low: &mut [f64],
357    input: &RandomWalkIndexInput,
358    kernel: Kernel,
359) -> Result<(), RandomWalkIndexError> {
360    random_walk_index_into_slice(out_high, out_low, input, kernel)
361}
362
363pub fn random_walk_index_into_slice(
364    out_high: &mut [f64],
365    out_low: &mut [f64],
366    input: &RandomWalkIndexInput,
367    kernel: Kernel,
368) -> Result<(), RandomWalkIndexError> {
369    let (high, low, close, length, first, _) = prepare(input, kernel)?;
370    let expected = close.len();
371    if out_high.len() != expected || out_low.len() != expected {
372        return Err(RandomWalkIndexError::OutputLengthMismatch {
373            expected,
374            got: out_high.len().max(out_low.len()),
375        });
376    }
377    let warm = first + length - 1;
378    out_high[..warm.min(expected)].fill(f64::NAN);
379    out_low[..warm.min(expected)].fill(f64::NAN);
380    compute_random_walk_index_into(high, low, close, length, first, out_high, out_low);
381    Ok(())
382}
383
/// Incremental (bar-by-bar) random walk index calculator.
#[derive(Debug, Clone)]
pub struct RandomWalkIndexStream {
    /// Lookback length (>= 1, enforced by `try_new`).
    length: usize,
    /// Precomputed `sqrt(length)` used in the denominator.
    sqrt_length: f64,
    /// Number of finite bars consumed so far.
    count: usize,
    /// Running sum of true ranges during the first `length` bars.
    warm_sum: f64,
    /// Smoothed ATR; `NaN` until `length` bars have been seen.
    atr: f64,
    /// Close of the previous finite bar (`NaN` before the first one).
    prev_close: f64,
    /// Up to the last `length` highs, oldest at the front.
    history_high: VecDeque<f64>,
    /// Up to the last `length` lows, oldest at the front.
    history_low: VecDeque<f64>,
}
395
396impl RandomWalkIndexStream {
397    pub fn try_new(params: RandomWalkIndexParams) -> Result<Self, RandomWalkIndexError> {
398        let length = params.length.unwrap_or(DEFAULT_LENGTH);
399        if length == 0 {
400            return Err(RandomWalkIndexError::InvalidLength {
401                length,
402                data_len: 0,
403            });
404        }
405        Ok(Self {
406            length,
407            sqrt_length: (length as f64).sqrt(),
408            count: 0,
409            warm_sum: 0.0,
410            atr: f64::NAN,
411            prev_close: f64::NAN,
412            history_high: VecDeque::with_capacity(length),
413            history_low: VecDeque::with_capacity(length),
414        })
415    }
416
417    #[inline]
418    pub fn update(&mut self, high: f64, low: f64, close: f64) -> (f64, f64) {
419        if !high.is_finite() || !low.is_finite() || !close.is_finite() {
420            return (f64::NAN, f64::NAN);
421        }
422
423        let tr = if self.count == 0 {
424            high - low
425        } else {
426            (high - low)
427                .max((high - self.prev_close).abs())
428                .max((low - self.prev_close).abs())
429        };
430
431        if self.count < self.length {
432            self.warm_sum += tr;
433            self.count += 1;
434            if self.count == self.length {
435                self.atr = self.warm_sum / self.length as f64;
436            }
437        } else {
438            let alpha = 1.0 / self.length as f64;
439            self.atr = alpha.mul_add(tr - self.atr, self.atr);
440            self.count += 1;
441        }
442
443        let hist_high = if self.history_high.len() == self.length {
444            self.history_high.front().copied().unwrap_or(0.0)
445        } else {
446            0.0
447        };
448        let hist_low = if self.history_low.len() == self.length {
449            self.history_low.front().copied().unwrap_or(0.0)
450        } else {
451            0.0
452        };
453        let denom = self.atr * self.sqrt_length;
454        let out = if self.count >= self.length && denom.is_finite() && denom != 0.0 {
455            ((high - hist_low) / denom, (hist_high - low) / denom)
456        } else {
457            (f64::NAN, f64::NAN)
458        };
459
460        self.history_high.push_back(high);
461        self.history_low.push_back(low);
462        if self.history_high.len() > self.length {
463            self.history_high.pop_front();
464        }
465        if self.history_low.len() > self.length {
466            self.history_low.pop_front();
467        }
468        self.prev_close = close;
469
470        out
471    }
472}
473
/// Inclusive sweep over the lookback length.
#[derive(Debug, Clone)]
pub struct RandomWalkIndexBatchRange {
    /// `(start, end, step)`. `step == 0` selects just `start` and requires
    /// `start == end`; otherwise `start..=end` is walked in steps of `step`.
    /// `start` must be non-zero (see `expand_grid`).
    pub length: (usize, usize, usize),
}
478
/// Result of a length sweep.
///
/// `high` and `low` are row-major `rows x cols` matrices. Row `r` was computed
/// with `combos[r]` and occupies `[r * cols, (r + 1) * cols)`.
#[derive(Debug, Clone)]
pub struct RandomWalkIndexBatchOutput {
    pub high: Vec<f64>,
    pub low: Vec<f64>,
    /// Parameters for each row, in row order.
    pub combos: Vec<RandomWalkIndexParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of input bars.
    pub cols: usize,
}
487
/// Fluent configuration for a batch length sweep.
#[derive(Copy, Clone, Debug)]
pub struct RandomWalkIndexBatchBuilder {
    /// Inclusive `(start, end, step)` sweep over the lookback length.
    length: (usize, usize, usize),
    /// Batch kernel (or `Kernel::Auto`) passed to the batch entry point.
    kernel: Kernel,
}
493
494impl Default for RandomWalkIndexBatchBuilder {
495    fn default() -> Self {
496        Self {
497            length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
498            kernel: Kernel::Auto,
499        }
500    }
501}
502
503impl RandomWalkIndexBatchBuilder {
504    #[inline(always)]
505    pub fn new() -> Self {
506        Self::default()
507    }
508
509    #[inline(always)]
510    pub fn length_range(mut self, value: (usize, usize, usize)) -> Self {
511        self.length = value;
512        self
513    }
514
515    #[inline(always)]
516    pub fn kernel(mut self, value: Kernel) -> Self {
517        self.kernel = value;
518        self
519    }
520
521    #[inline(always)]
522    pub fn apply_candles(
523        self,
524        candles: &Candles,
525    ) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
526        random_walk_index_batch_with_kernel(
527            candles.high.as_slice(),
528            candles.low.as_slice(),
529            candles.close.as_slice(),
530            &RandomWalkIndexBatchRange {
531                length: self.length,
532            },
533            self.kernel,
534        )
535    }
536
537    #[inline(always)]
538    pub fn apply_slices(
539        self,
540        high: &[f64],
541        low: &[f64],
542        close: &[f64],
543    ) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
544        random_walk_index_batch_with_kernel(
545            high,
546            low,
547            close,
548            &RandomWalkIndexBatchRange {
549                length: self.length,
550            },
551            self.kernel,
552        )
553    }
554}
555
556pub fn expand_grid(
557    sweep: &RandomWalkIndexBatchRange,
558) -> Result<Vec<RandomWalkIndexParams>, RandomWalkIndexError> {
559    let (start, end, step) = sweep.length;
560    if start == 0 {
561        return Err(RandomWalkIndexError::InvalidRange {
562            start: start.to_string(),
563            end: end.to_string(),
564            step: step.to_string(),
565        });
566    }
567    let mut lengths = Vec::new();
568    if step == 0 {
569        if start != end {
570            return Err(RandomWalkIndexError::InvalidRange {
571                start: start.to_string(),
572                end: end.to_string(),
573                step: step.to_string(),
574            });
575        }
576        lengths.push(start);
577    } else {
578        if start > end {
579            return Err(RandomWalkIndexError::InvalidRange {
580                start: start.to_string(),
581                end: end.to_string(),
582                step: step.to_string(),
583            });
584        }
585        let mut current = start;
586        while current <= end {
587            lengths.push(current);
588            match current.checked_add(step) {
589                Some(next) => current = next,
590                None => break,
591            }
592        }
593    }
594
595    Ok(lengths
596        .into_iter()
597        .map(|length| RandomWalkIndexParams {
598            length: Some(length),
599        })
600        .collect())
601}
602
603pub fn random_walk_index_batch_with_kernel(
604    high: &[f64],
605    low: &[f64],
606    close: &[f64],
607    sweep: &RandomWalkIndexBatchRange,
608    kernel: Kernel,
609) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
610    let batch_kernel = match kernel {
611        Kernel::Auto => detect_best_batch_kernel(),
612        other if other.is_batch() => other,
613        _ => return Err(RandomWalkIndexError::InvalidKernelForBatch(kernel)),
614    };
615    random_walk_index_batch_par_slice(high, low, close, sweep, batch_kernel.to_non_batch())
616}
617
618#[inline(always)]
619pub fn random_walk_index_batch_slice(
620    high: &[f64],
621    low: &[f64],
622    close: &[f64],
623    sweep: &RandomWalkIndexBatchRange,
624    kernel: Kernel,
625) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
626    random_walk_index_batch_inner(high, low, close, sweep, kernel, false)
627}
628
629#[inline(always)]
630pub fn random_walk_index_batch_par_slice(
631    high: &[f64],
632    low: &[f64],
633    close: &[f64],
634    sweep: &RandomWalkIndexBatchRange,
635    kernel: Kernel,
636) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
637    random_walk_index_batch_inner(high, low, close, sweep, kernel, true)
638}
639
640fn validate_raw_slices(
641    high: &[f64],
642    low: &[f64],
643    close: &[f64],
644) -> Result<usize, RandomWalkIndexError> {
645    if high.is_empty() || low.is_empty() || close.is_empty() {
646        return Err(RandomWalkIndexError::EmptyInputData);
647    }
648    if high.len() != low.len() || high.len() != close.len() {
649        return Err(RandomWalkIndexError::InconsistentSliceLengths {
650            high_len: high.len(),
651            low_len: low.len(),
652            close_len: close.len(),
653        });
654    }
655    first_valid_hlc(high, low, close).ok_or(RandomWalkIndexError::AllValuesNaN)
656}
657
658fn random_walk_index_batch_inner(
659    high: &[f64],
660    low: &[f64],
661    close: &[f64],
662    sweep: &RandomWalkIndexBatchRange,
663    kernel: Kernel,
664    parallel: bool,
665) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
666    let combos = expand_grid(sweep)?;
667    let first = validate_raw_slices(high, low, close)?;
668    let max_length = combos
669        .iter()
670        .map(|combo| combo.length.unwrap())
671        .max()
672        .unwrap();
673    let valid = close.len().saturating_sub(first);
674    if valid < max_length {
675        return Err(RandomWalkIndexError::NotEnoughValidData {
676            needed: max_length,
677            valid,
678        });
679    }
680
681    let rows = combos.len();
682    let cols = close.len();
683    let warmups: Vec<usize> = combos
684        .iter()
685        .map(|combo| first + combo.length.unwrap() - 1)
686        .collect();
687
688    let mut high_buf = make_uninit_matrix(rows, cols);
689    init_matrix_prefixes(&mut high_buf, cols, &warmups);
690    let mut high_guard = ManuallyDrop::new(high_buf);
691    let out_high: &mut [f64] = unsafe {
692        core::slice::from_raw_parts_mut(high_guard.as_mut_ptr() as *mut f64, high_guard.len())
693    };
694
695    let mut low_buf = make_uninit_matrix(rows, cols);
696    init_matrix_prefixes(&mut low_buf, cols, &warmups);
697    let mut low_guard = ManuallyDrop::new(low_buf);
698    let out_low: &mut [f64] = unsafe {
699        core::slice::from_raw_parts_mut(low_guard.as_mut_ptr() as *mut f64, low_guard.len())
700    };
701
702    random_walk_index_batch_inner_into(
703        high, low, close, sweep, kernel, parallel, out_high, out_low,
704    )?;
705
706    let high_values = unsafe {
707        Vec::from_raw_parts(
708            high_guard.as_mut_ptr() as *mut f64,
709            high_guard.len(),
710            high_guard.capacity(),
711        )
712    };
713    let low_values = unsafe {
714        Vec::from_raw_parts(
715            low_guard.as_mut_ptr() as *mut f64,
716            low_guard.len(),
717            low_guard.capacity(),
718        )
719    };
720
721    Ok(RandomWalkIndexBatchOutput {
722        high: high_values,
723        low: low_values,
724        combos,
725        rows,
726        cols,
727    })
728}
729
730pub fn random_walk_index_batch_into_slice(
731    out_high: &mut [f64],
732    out_low: &mut [f64],
733    high: &[f64],
734    low: &[f64],
735    close: &[f64],
736    sweep: &RandomWalkIndexBatchRange,
737    kernel: Kernel,
738) -> Result<(), RandomWalkIndexError> {
739    random_walk_index_batch_inner_into(high, low, close, sweep, kernel, false, out_high, out_low)?;
740    Ok(())
741}
742
743fn random_walk_index_batch_inner_into(
744    high: &[f64],
745    low: &[f64],
746    close: &[f64],
747    sweep: &RandomWalkIndexBatchRange,
748    _kernel: Kernel,
749    parallel: bool,
750    out_high: &mut [f64],
751    out_low: &mut [f64],
752) -> Result<Vec<RandomWalkIndexParams>, RandomWalkIndexError> {
753    let combos = expand_grid(sweep)?;
754    let first = validate_raw_slices(high, low, close)?;
755    let rows = combos.len();
756    let cols = close.len();
757    let expected = rows
758        .checked_mul(cols)
759        .ok_or_else(|| RandomWalkIndexError::InvalidRange {
760            start: rows.to_string(),
761            end: cols.to_string(),
762            step: "rows*cols".to_string(),
763        })?;
764    if out_high.len() != expected || out_low.len() != expected {
765        return Err(RandomWalkIndexError::OutputLengthMismatch {
766            expected,
767            got: out_high.len().max(out_low.len()),
768        });
769    }
770    let max_length = combos
771        .iter()
772        .map(|combo| combo.length.unwrap())
773        .max()
774        .unwrap();
775    let valid = cols.saturating_sub(first);
776    if valid < max_length {
777        return Err(RandomWalkIndexError::NotEnoughValidData {
778            needed: max_length,
779            valid,
780        });
781    }
782
783    let do_row = |row: usize, dst_high: &mut [f64], dst_low: &mut [f64]| {
784        let length = combos[row].length.unwrap();
785        let warm = first + length - 1;
786        dst_high[..warm.min(cols)].fill(f64::NAN);
787        dst_low[..warm.min(cols)].fill(f64::NAN);
788        compute_random_walk_index_into(high, low, close, length, first, dst_high, dst_low);
789    };
790
791    if parallel {
792        #[cfg(not(target_arch = "wasm32"))]
793        {
794            out_high
795                .par_chunks_mut(cols)
796                .zip(out_low.par_chunks_mut(cols))
797                .enumerate()
798                .for_each(|(row, (dst_high, dst_low))| do_row(row, dst_high, dst_low));
799        }
800        #[cfg(target_arch = "wasm32")]
801        {
802            for ((row, dst_high), dst_low) in out_high
803                .chunks_mut(cols)
804                .enumerate()
805                .zip(out_low.chunks_mut(cols))
806            {
807                do_row(row, dst_high, dst_low);
808            }
809        }
810    } else {
811        for ((row, dst_high), dst_low) in out_high
812            .chunks_mut(cols)
813            .enumerate()
814            .zip(out_low.chunks_mut(cols))
815        {
816            do_row(row, dst_high, dst_low);
817        }
818    }
819
820    Ok(combos)
821}
822
/// Python entry point: random walk index over high/low/close arrays.
///
/// Returns a dict with `"high"` and `"low"` arrays; raises `ValueError` on
/// invalid inputs or parameters.
#[cfg(feature = "python")]
#[pyfunction(name = "random_walk_index")]
#[pyo3(signature = (high, low, close, length=14, kernel=None))]
pub fn random_walk_index_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_values = high.as_slice()?;
    let low_values = low.as_slice()?;
    let close_values = close.as_slice()?;
    let params = RandomWalkIndexParams {
        length: Some(length),
    };
    let input = RandomWalkIndexInput::from_slices(high_values, low_values, close_values, params);
    let kernel = validate_kernel(kernel, false)?;

    // The computation only reads borrowed slices, so the GIL is released.
    let output = py
        .allow_threads(|| random_walk_index_with_kernel(&input, kernel))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let result = PyDict::new(py);
    result.set_item("high", output.high.into_pyarray(py))?;
    result.set_item("low", output.low.into_pyarray(py))?;
    Ok(result)
}
854
/// Python wrapper around the incremental random walk index stream.
#[cfg(feature = "python")]
#[pyclass(name = "RandomWalkIndexStream")]
pub struct RandomWalkIndexStreamPy {
    stream: RandomWalkIndexStream,
}
860
#[cfg(feature = "python")]
#[pymethods]
impl RandomWalkIndexStreamPy {
    /// Creates a stream with lookback `length`; raises `ValueError` when it is zero.
    #[new]
    #[pyo3(signature = (length=14))]
    fn new(length: usize) -> PyResult<Self> {
        let params = RandomWalkIndexParams {
            length: Some(length),
        };
        RandomWalkIndexStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar and returns `(rwi_high, rwi_low)`.
    fn update(&mut self, high: f64, low: f64, close: f64) -> (f64, f64) {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
878
/// Python entry point: random walk index for every length in `length_range`.
///
/// Returns a dict with `"high"`/`"low"` as `(rows, cols)` arrays, `"lengths"`
/// (one `u64` per row), `"rows"` and `"cols"`. Raises `ValueError` for an
/// invalid sweep, invalid inputs, or `rows * cols` overflow.
#[cfg(feature = "python")]
#[pyfunction(name = "random_walk_index_batch")]
#[pyo3(signature = (high, low, close, length_range=(14,14,0), kernel=None))]
pub fn random_walk_index_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let close = close.as_slice()?;
    let sweep = RandomWalkIndexBatchRange {
        length: length_range,
    };
    // The sweep is expanded here only to size the output arrays; the inner
    // routine expands it again.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = close.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // NOTE(review): `PyArray1::new` hands out uninitialised storage (hence
    // `unsafe`). This relies on `random_walk_index_batch_inner_into` writing
    // every element before the arrays reach Python. On error the arrays are
    // dropped without being returned.
    let out_high = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_low = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let high_slice = unsafe { out_high.as_slice_mut()? };
    let low_slice = unsafe { out_low.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    // The GIL is released while the rows are computed (parallel = true).
    py.allow_threads(|| {
        let batch_kernel = match kernel {
            Kernel::Auto => detect_best_batch_kernel(),
            other => other,
        };
        random_walk_index_batch_inner_into(
            high,
            low,
            close,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            high_slice,
            low_slice,
        )
    })
    .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    // The flat buffers are row-major, so reshaping to (rows, cols) needs no copy logic.
    dict.set_item("high", out_high.reshape((rows, cols))?)?;
    dict.set_item("low", out_low.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
942
/// Registers the random walk index functions and stream class on `m`.
#[cfg(feature = "python")]
pub fn register_random_walk_index_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(random_walk_index_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(random_walk_index_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<RandomWalkIndexStreamPy>()
}
950
/// JavaScript-facing result of `random_walk_index_js`, serialised with serde.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexJsOutput {
    pub high: Vec<f64>,
    pub low: Vec<f64>,
}
957
/// WebAssembly entry point: random walk index with automatic kernel
/// selection, returned as `{ high, low }`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "random_walk_index_js")]
pub fn random_walk_index_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
) -> Result<JsValue, JsValue> {
    let params = RandomWalkIndexParams {
        length: Some(length),
    };
    let input = RandomWalkIndexInput::from_slices(high, low, close, params);
    let RandomWalkIndexOutput {
        high: rwi_high,
        low: rwi_low,
    } = random_walk_index_with_kernel(&input, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let payload = RandomWalkIndexJsOutput {
        high: rwi_high,
        low: rwi_low,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
982
/// Batch configuration accepted from JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexBatchConfig {
    /// `[start, end, step]` as JS numbers; each must be a finite,
    /// non-negative whole number (checked by `js_vec3_to_usize`).
    pub length_range: Vec<f64>,
}
988
/// JavaScript-facing result of a batch sweep.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexBatchJsOutput {
    /// Row-major `rows x cols` matrix of RWI-high values.
    pub high: Vec<f64>,
    /// Row-major `rows x cols` matrix of RWI-low values.
    pub low: Vec<f64>,
    /// Lookback length of each row, in row order.
    pub lengths: Vec<usize>,
    pub rows: usize,
    pub cols: usize,
}
998
/// Converts a JS `[start, end, step]` triple into `usize` values.
///
/// Each element must be finite, non-negative, within `1e-9` of a whole
/// number, and representable as `usize`.
///
/// # Errors
/// A descriptive `JsValue` string naming the offending element of `name`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
fn js_vec3_to_usize(name: &str, values: &[f64]) -> Result<(usize, usize, usize), JsValue> {
    if values.len() != 3 {
        return Err(JsValue::from_str(&format!(
            "Invalid config: {name} must have exactly 3 elements [start, end, step]"
        )));
    }
    let mut out = [0usize; 3];
    for (i, value) in values.iter().copied().enumerate() {
        if !value.is_finite() || value < 0.0 {
            return Err(JsValue::from_str(&format!(
                "Invalid config: {name}[{i}] must be a finite non-negative whole number"
            )));
        }
        let rounded = value.round();
        if (value - rounded).abs() > 1e-9 {
            return Err(JsValue::from_str(&format!(
                "Invalid config: {name}[{i}] must be a whole number"
            )));
        }
        // A float-to-int `as` cast saturates silently. Reject values that
        // would be clamped to `usize::MAX` instead of quietly changing them.
        // On wasm32, `usize::MAX as f64` is exact.
        if rounded > usize::MAX as f64 {
            return Err(JsValue::from_str(&format!(
                "Invalid config: {name}[{i}] exceeds the maximum supported value {}",
                usize::MAX
            )));
        }
        out[i] = rounded as usize;
    }
    Ok((out[0], out[1], out[2]))
}
1023
/// Runs a Random Walk Index parameter sweep and returns a JS object with the
/// flattened `high`/`low` matrices, the per-row `lengths`, and `rows`/`cols`.
///
/// `config` must deserialize into `RandomWalkIndexBatchConfig`.
///
/// # Errors
/// Returns a string `JsValue` in three cases:
/// - the config cannot be parsed or validated;
/// - the batch computation fails;
/// - serialization fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "random_walk_index_batch_js")]
pub fn random_walk_index_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let parsed: RandomWalkIndexBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let length = js_vec3_to_usize("length_range", &parsed.length_range)?;
    let sweep = RandomWalkIndexBatchRange { length };

    let batch = match random_walk_index_batch_with_kernel(high, low, close, &sweep, Kernel::Auto) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    // Report the effective length per row; unset lengths fall back to the default.
    let mut lengths: Vec<usize> = Vec::with_capacity(batch.combos.len());
    for combo in batch.combos.iter() {
        lengths.push(combo.length.unwrap_or(DEFAULT_LENGTH));
    }

    let payload = RandomWalkIndexBatchJsOutput {
        high: batch.high,
        low: batch.low,
        lengths,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1053
/// Allocates an uninitialized buffer with capacity for `len` `f64` values and
/// hands ownership to the caller.
///
/// The buffer must be released with `random_walk_index_free(ptr, len)` using
/// the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1062
/// Releases a buffer previously obtained from `random_walk_index_alloc(len)`.
///
/// A null `ptr` is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` must come from `random_walk_index_alloc(len)`, which
    // allocates exactly `len` f64 slots. `f64` has no drop glue, so
    // rebuilding the Vec and dropping it only frees the allocation.
    let reclaimed = unsafe { Vec::from_raw_parts(ptr, len, len) };
    drop(reclaimed);
}
1072
/// Computes the Random Walk Index for one `length` directly into caller-owned
/// memory. Each pointer must address `len` f64 values.
///
/// Input and output buffers may overlap, e.g. computing in place over one of
/// the inputs. In that case the result is computed into temporaries and then
/// copied, because a live `&[f64]` over memory also borrowed as `&mut [f64]`
/// would be undefined behaviour.
///
/// # Errors
/// Returns a string `JsValue` when:
/// - any pointer is null;
/// - the two output buffers overlap;
/// - the indicator computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_high_ptr: *mut f64,
    out_low_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsValue> {
    /// True when the `a_len`-element f64 region at `a` shares any element
    /// with the `b_len`-element region at `b`. Empty regions never overlap.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let (a0, b0) = (a as usize, b as usize);
        let a1 = a0.saturating_add(a_len.saturating_mul(elem));
        let b1 = b0.saturating_add(b_len.saturating_mul(elem));
        a0 < b1 && b0 < a1
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_high_ptr.is_null()
        || out_low_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    // Two outputs sharing memory would overwrite each other; no meaningful result exists.
    if overlaps(out_high_ptr, len, out_low_ptr, len) {
        return Err(JsValue::from_str(
            "out_high_ptr and out_low_ptr must not overlap in random_walk_index_into",
        ));
    }

    let outputs_alias_inputs = [high_ptr, low_ptr, close_ptr].iter().any(|&src| {
        overlaps(src, len, out_high_ptr, len) || overlaps(src, len, out_low_ptr, len)
    });

    if outputs_alias_inputs {
        // Compute into owned buffers while only shared borrows of the inputs exist.
        let computed = {
            // SAFETY: the caller guarantees each input pointer addresses `len`
            // readable f64 values. No `&mut` over that memory exists in this scope.
            let (high, low, close) = unsafe {
                (
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                    std::slice::from_raw_parts(close_ptr, len),
                )
            };
            let input = RandomWalkIndexInput::from_slices(
                high,
                low,
                close,
                RandomWalkIndexParams {
                    length: Some(length),
                },
            );
            random_walk_index_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?
        };
        if computed.high.len() != len || computed.low.len() != len {
            return Err(JsValue::from_str(
                "unexpected output length in random_walk_index_into",
            ));
        }
        // SAFETY: the caller guarantees the outputs address `len` writable
        // f64 values. The sources are freshly allocated Vecs, so they cannot
        // overlap the destinations. The input borrows ended with the block above.
        unsafe {
            std::ptr::copy_nonoverlapping(computed.high.as_ptr(), out_high_ptr, len);
            std::ptr::copy_nonoverlapping(computed.low.as_ptr(), out_low_ptr, len);
        }
        return Ok(());
    }

    // SAFETY: all five regions are pairwise non-overlapping (checked above),
    // and the caller guarantees each addresses `len` valid f64 values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out_high = std::slice::from_raw_parts_mut(out_high_ptr, len);
        let out_low = std::slice::from_raw_parts_mut(out_low_ptr, len);
        let input = RandomWalkIndexInput::from_slices(
            high,
            low,
            close,
            RandomWalkIndexParams {
                length: Some(length),
            },
        );
        random_walk_index_into_slice(out_high, out_low, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1110
/// Runs a Random Walk Index `length` sweep directly into caller-owned memory.
///
/// Buffer sizes:
/// - each input pointer addresses `len` f64 values;
/// - each output pointer must address `rows * len` f64 values, where `rows`
///   is the number of combinations produced by `expand_grid`.
///
/// Returns `rows` on success. If an output overlaps an input, the sweep is
/// computed into temporaries and then copied, avoiding aliased `&`/`&mut` slices.
///
/// # Errors
/// Returns a string `JsValue` when:
/// - any pointer is null;
/// - the sweep is invalid;
/// - `rows * len` overflows;
/// - the two outputs overlap;
/// - the computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_high_ptr: *mut f64,
    out_low_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    /// True when the `a_len`-element f64 region at `a` shares any element
    /// with the `b_len`-element region at `b`. Empty regions never overlap.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let (a0, b0) = (a as usize, b as usize);
        let a1 = a0.saturating_add(a_len.saturating_mul(elem));
        let b1 = b0.saturating_add(b_len.saturating_mul(elem));
        a0 < b1 && b0 < a1
    }

    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_high_ptr.is_null()
        || out_low_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to random_walk_index_batch_into",
        ));
    }

    let sweep = RandomWalkIndexBatchRange {
        length: (length_start, length_end, length_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows.checked_mul(len).ok_or_else(|| {
        JsValue::from_str("rows*cols overflow in random_walk_index_batch_into")
    })?;

    // Two outputs sharing memory would overwrite each other; no meaningful result exists.
    if overlaps(out_high_ptr, total, out_low_ptr, total) {
        return Err(JsValue::from_str(
            "out_high_ptr and out_low_ptr must not overlap in random_walk_index_batch_into",
        ));
    }

    let outputs_alias_inputs = [high_ptr, low_ptr, close_ptr].iter().any(|&src| {
        overlaps(src, len, out_high_ptr, total) || overlaps(src, len, out_low_ptr, total)
    });

    if outputs_alias_inputs {
        // Compute into owned buffers while only shared borrows of the inputs exist.
        let computed = {
            // SAFETY: the caller guarantees each input pointer addresses `len`
            // readable f64 values. No `&mut` over that memory exists in this scope.
            let (high, low, close) = unsafe {
                (
                    std::slice::from_raw_parts(high_ptr, len),
                    std::slice::from_raw_parts(low_ptr, len),
                    std::slice::from_raw_parts(close_ptr, len),
                )
            };
            random_walk_index_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?
        };
        if computed.high.len() != total || computed.low.len() != total {
            return Err(JsValue::from_str(
                "unexpected output length in random_walk_index_batch_into",
            ));
        }
        // SAFETY: the caller guarantees the outputs address `total` writable
        // f64 values. The sources are freshly allocated Vecs, so they cannot
        // overlap the destinations. The input borrows ended with the block above.
        unsafe {
            std::ptr::copy_nonoverlapping(computed.high.as_ptr(), out_high_ptr, total);
            std::ptr::copy_nonoverlapping(computed.low.as_ptr(), out_low_ptr, total);
        }
        return Ok(rows);
    }

    // SAFETY: inputs and outputs are pairwise non-overlapping (checked above).
    // The caller guarantees inputs hold `len` and outputs hold `total` f64 values.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let out_high = std::slice::from_raw_parts_mut(out_high_ptr, total);
        let out_low = std::slice::from_raw_parts_mut(out_low_ptr, total);
        random_walk_index_batch_into_slice(
            out_high,
            out_low,
            high,
            low,
            close,
            &sweep,
            Kernel::Auto,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1161
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference result: calls the core `compute_random_walk_index_into`
    /// directly on NaN-initialized outputs, starting at the first valid HLC index.
    fn manual_random_walk_index(
        high: &[f64],
        low: &[f64],
        close: &[f64],
        length: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        let n = close.len();
        let mut out_high = vec![f64::NAN; n];
        let mut out_low = vec![f64::NAN; n];
        let first = first_valid_hlc(high, low, close).unwrap();
        compute_random_walk_index_into(
            high,
            low,
            close,
            length,
            first,
            &mut out_high,
            &mut out_low,
        );
        (out_high, out_low)
    }

    /// Deterministic synthetic series of length `n` with `high > close > low`
    /// at every index.
    fn sample_hlc(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let close: Vec<f64> = (0..n)
            .map(|i| 100.0 + ((i as f64) * 0.19).sin() * 2.0 + (i as f64) * 0.03)
            .collect();
        let high: Vec<f64> = close
            .iter()
            .enumerate()
            .map(|(i, &c)| c + 1.5 + ((i as f64) * 0.11).cos().abs())
            .collect();
        let low: Vec<f64> = close
            .iter()
            .enumerate()
            .map(|(i, &c)| c - 1.3 - ((i as f64) * 0.07).sin().abs())
            .collect();
        (high, low, close)
    }

    /// Asserts element-wise equality within 1e-12.
    ///
    /// NaN matches NaN. A NaN against a number fails, because the NaN
    /// difference is never `<= 1e-12`.
    fn assert_close(lhs: &[f64], rhs: &[f64]) {
        assert_eq!(lhs.len(), rhs.len());
        for (idx, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            let diff = (a - b).abs();
            assert!(diff <= 1e-12, "mismatch at {idx}: {a} vs {b}");
        }
    }

    #[test]
    fn manual_reference_matches_api() {
        let (high, low, close) = sample_hlc(128);
        let input = RandomWalkIndexInput::from_slices(
            &high,
            &low,
            &close,
            RandomWalkIndexParams { length: Some(14) },
        );
        let out = random_walk_index(&input).unwrap();
        let (want_high, want_low) = manual_random_walk_index(&high, &low, &close, 14);
        assert_close(&out.high, &want_high);
        assert_close(&out.low, &want_low);
    }

    /// Feeding bars one at a time through the stream must reproduce the
    /// single-series API output.
    #[test]
    fn stream_matches_batch() {
        let (high, low, close) = sample_hlc(96);
        let input = RandomWalkIndexInput::from_slices(
            &high,
            &low,
            &close,
            RandomWalkIndexParams { length: Some(14) },
        );
        let out = random_walk_index(&input).unwrap();
        let mut stream =
            RandomWalkIndexStream::try_new(RandomWalkIndexParams { length: Some(14) }).unwrap();
        let mut got_high = Vec::with_capacity(high.len());
        let mut got_low = Vec::with_capacity(high.len());
        for i in 0..high.len() {
            let (h, l) = stream.update(high[i], low[i], close[i]);
            got_high.push(h);
            got_low.push(l);
        }
        assert_close(&out.high, &got_high);
        assert_close(&out.low, &got_low);
    }

    #[test]
    fn batch_first_row_matches_single() {
        let (high, low, close) = sample_hlc(80);
        let batch = random_walk_index_batch_with_kernel(
            &high,
            &low,
            &close,
            &RandomWalkIndexBatchRange {
                length: (14, 16, 2),
            },
            Kernel::Auto,
        )
        .unwrap();
        let input = RandomWalkIndexInput::from_slices(
            &high,
            &low,
            &close,
            RandomWalkIndexParams { length: Some(14) },
        );
        let single = random_walk_index(&input).unwrap();
        assert_eq!(batch.rows, 2);
        assert_close(&batch.high[..80], single.high.as_slice());
        assert_close(&batch.low[..80], single.low.as_slice());
    }

    /// Every batch row (not only the first) must equal the single-series
    /// result for that row's length. `rows`/`cols` must describe the
    /// flattened matrices.
    #[test]
    fn batch_every_row_matches_single() {
        let (high, low, close) = sample_hlc(80);
        let batch = random_walk_index_batch_with_kernel(
            &high,
            &low,
            &close,
            &RandomWalkIndexBatchRange {
                length: (10, 16, 3),
            },
            Kernel::Auto,
        )
        .unwrap();
        assert_eq!(batch.rows, batch.combos.len());
        assert_eq!(batch.cols, close.len());
        assert_eq!(batch.high.len(), batch.rows * batch.cols);
        assert_eq!(batch.low.len(), batch.rows * batch.cols);
        for (row, combo) in batch.combos.iter().enumerate() {
            let input = RandomWalkIndexInput::from_slices(
                &high,
                &low,
                &close,
                RandomWalkIndexParams {
                    length: combo.length,
                },
            );
            let single = random_walk_index(&input).unwrap();
            let start = row * batch.cols;
            let end = start + batch.cols;
            assert_close(&batch.high[start..end], single.high.as_slice());
            assert_close(&batch.low[start..end], single.low.as_slice());
        }
    }

    #[test]
    fn into_slice_matches_single() {
        let (high, low, close) = sample_hlc(72);
        let input = RandomWalkIndexInput::from_slices(
            &high,
            &low,
            &close,
            RandomWalkIndexParams { length: Some(14) },
        );
        let single = random_walk_index(&input).unwrap();
        let mut out_high = vec![0.0; close.len()];
        let mut out_low = vec![0.0; close.len()];
        random_walk_index_into_slice(&mut out_high, &mut out_low, &input, Kernel::Auto).unwrap();
        assert_close(&single.high, &out_high);
        assert_close(&single.low, &out_low);
    }

    #[test]
    fn invalid_length_is_rejected() {
        let (high, low, close) = sample_hlc(8);
        let input = RandomWalkIndexInput::from_slices(
            &high,
            &low,
            &close,
            RandomWalkIndexParams { length: Some(0) },
        );
        assert!(matches!(
            random_walk_index(&input),
            Err(RandomWalkIndexError::InvalidLength { .. })
        ));
    }
}