1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
19};
20#[cfg(feature = "python")]
21use crate::utilities::kernel_validation::validate_kernel;
22#[cfg(not(target_arch = "wasm32"))]
23use rayon::prelude::*;
24use std::collections::VecDeque;
25use std::mem::ManuallyDrop;
26use thiserror::Error;
27
28const DEFAULT_LENGTH: usize = 14;
29
30#[derive(Debug, Clone)]
31pub enum RandomWalkIndexData<'a> {
32 Candles {
33 candles: &'a Candles,
34 },
35 Slices {
36 high: &'a [f64],
37 low: &'a [f64],
38 close: &'a [f64],
39 },
40}
41
/// Result of a single Random Walk Index computation.
///
/// Both vectors have one entry per input bar.
#[derive(Clone, Debug)]
pub struct RandomWalkIndexOutput {
    /// The "high" output line.
    pub high: Vec<f64>,
    /// The "low" output line.
    pub low: Vec<f64>,
}
47
/// Tunable parameters of the Random Walk Index.
///
/// Serde support is derived only for the wasm32 build with the `wasm` feature.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct RandomWalkIndexParams {
    /// Look-back length; `None` means "use the default length".
    pub length: Option<usize>,
}
56
57impl Default for RandomWalkIndexParams {
58 fn default() -> Self {
59 Self {
60 length: Some(DEFAULT_LENGTH),
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
66pub struct RandomWalkIndexInput<'a> {
67 pub data: RandomWalkIndexData<'a>,
68 pub params: RandomWalkIndexParams,
69}
70
71impl<'a> RandomWalkIndexInput<'a> {
72 #[inline]
73 pub fn from_candles(candles: &'a Candles, params: RandomWalkIndexParams) -> Self {
74 Self {
75 data: RandomWalkIndexData::Candles { candles },
76 params,
77 }
78 }
79
80 #[inline]
81 pub fn from_slices(
82 high: &'a [f64],
83 low: &'a [f64],
84 close: &'a [f64],
85 params: RandomWalkIndexParams,
86 ) -> Self {
87 Self {
88 data: RandomWalkIndexData::Slices { high, low, close },
89 params,
90 }
91 }
92
93 #[inline]
94 pub fn with_default_candles(candles: &'a Candles) -> Self {
95 Self::from_candles(candles, RandomWalkIndexParams::default())
96 }
97
98 #[inline]
99 pub fn get_length(&self) -> usize {
100 self.params.length.unwrap_or(DEFAULT_LENGTH)
101 }
102}
103
104#[derive(Copy, Clone, Debug)]
105pub struct RandomWalkIndexBuilder {
106 length: Option<usize>,
107 kernel: Kernel,
108}
109
110impl Default for RandomWalkIndexBuilder {
111 fn default() -> Self {
112 Self {
113 length: None,
114 kernel: Kernel::Auto,
115 }
116 }
117}
118
119impl RandomWalkIndexBuilder {
120 #[inline(always)]
121 pub fn new() -> Self {
122 Self::default()
123 }
124
125 #[inline(always)]
126 pub fn length(mut self, value: usize) -> Self {
127 self.length = Some(value);
128 self
129 }
130
131 #[inline(always)]
132 pub fn kernel(mut self, value: Kernel) -> Self {
133 self.kernel = value;
134 self
135 }
136
137 #[inline(always)]
138 pub fn apply(self, candles: &Candles) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
139 let input = RandomWalkIndexInput::from_candles(
140 candles,
141 RandomWalkIndexParams {
142 length: self.length,
143 },
144 );
145 random_walk_index_with_kernel(&input, self.kernel)
146 }
147
148 #[inline(always)]
149 pub fn apply_slices(
150 self,
151 high: &[f64],
152 low: &[f64],
153 close: &[f64],
154 ) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
155 let input = RandomWalkIndexInput::from_slices(
156 high,
157 low,
158 close,
159 RandomWalkIndexParams {
160 length: self.length,
161 },
162 );
163 random_walk_index_with_kernel(&input, self.kernel)
164 }
165
166 #[inline(always)]
167 pub fn into_stream(self) -> Result<RandomWalkIndexStream, RandomWalkIndexError> {
168 RandomWalkIndexStream::try_new(RandomWalkIndexParams {
169 length: self.length,
170 })
171 }
172}
173
174#[derive(Debug, Error)]
175pub enum RandomWalkIndexError {
176 #[error("random_walk_index: Input data slice is empty.")]
177 EmptyInputData,
178 #[error("random_walk_index: All values are NaN.")]
179 AllValuesNaN,
180 #[error("random_walk_index: Inconsistent slice lengths: high={high_len}, low={low_len}, close={close_len}")]
181 InconsistentSliceLengths {
182 high_len: usize,
183 low_len: usize,
184 close_len: usize,
185 },
186 #[error("random_walk_index: Invalid length: length={length}, data length={data_len}")]
187 InvalidLength { length: usize, data_len: usize },
188 #[error("random_walk_index: Not enough valid data: needed={needed}, valid={valid}")]
189 NotEnoughValidData { needed: usize, valid: usize },
190 #[error("random_walk_index: Output length mismatch: expected={expected}, got={got}")]
191 OutputLengthMismatch { expected: usize, got: usize },
192 #[error("random_walk_index: Invalid range: start={start}, end={end}, step={step}")]
193 InvalidRange {
194 start: String,
195 end: String,
196 step: String,
197 },
198 #[error("random_walk_index: Invalid kernel for batch: {0:?}")]
199 InvalidKernelForBatch(Kernel),
200}
201
202#[inline(always)]
203fn extract_hlc<'a>(
204 input: &'a RandomWalkIndexInput<'a>,
205) -> Result<(&'a [f64], &'a [f64], &'a [f64]), RandomWalkIndexError> {
206 let (high, low, close) = match &input.data {
207 RandomWalkIndexData::Candles { candles } => (
208 source_type(candles, "high"),
209 source_type(candles, "low"),
210 source_type(candles, "close"),
211 ),
212 RandomWalkIndexData::Slices { high, low, close } => (*high, *low, *close),
213 };
214
215 if high.is_empty() || low.is_empty() || close.is_empty() {
216 return Err(RandomWalkIndexError::EmptyInputData);
217 }
218 if high.len() != low.len() || high.len() != close.len() {
219 return Err(RandomWalkIndexError::InconsistentSliceLengths {
220 high_len: high.len(),
221 low_len: low.len(),
222 close_len: close.len(),
223 });
224 }
225 Ok((high, low, close))
226}
227
/// Returns the index of the first bar whose high, low and close are all
/// finite, or `None` when no such bar exists.
///
/// Iterates over `high`; `low` and `close` are indexed with the same
/// position, so they are expected to be at least as long as `high`.
#[inline(always)]
fn first_valid_hlc(high: &[f64], low: &[f64], close: &[f64]) -> Option<usize> {
    high.iter()
        .enumerate()
        .position(|(bar, h)| h.is_finite() && low[bar].is_finite() && close[bar].is_finite())
}
232
233#[inline(always)]
234fn prepare<'a>(
235 input: &'a RandomWalkIndexInput<'a>,
236 kernel: Kernel,
237) -> Result<(&'a [f64], &'a [f64], &'a [f64], usize, usize, Kernel), RandomWalkIndexError> {
238 let (high, low, close) = extract_hlc(input)?;
239 let len = close.len();
240 let length = input.get_length();
241 if length == 0 || length > len {
242 return Err(RandomWalkIndexError::InvalidLength {
243 length,
244 data_len: len,
245 });
246 }
247 let first = first_valid_hlc(high, low, close).ok_or(RandomWalkIndexError::AllValuesNaN)?;
248 let valid = len.saturating_sub(first);
249 if valid < length {
250 return Err(RandomWalkIndexError::NotEnoughValidData {
251 needed: length,
252 valid,
253 });
254 }
255 Ok((high, low, close, length, first, kernel.to_non_batch()))
256}
257
/// Returns `src[idx - offset]` when that position exists and holds a finite
/// value, otherwise `0.0`.
#[inline(always)]
fn nz_history(src: &[f64], idx: usize, offset: usize) -> f64 {
    match idx.checked_sub(offset).map(|at| src[at]) {
        Some(value) if value.is_finite() => value,
        _ => 0.0,
    }
}

/// Core Random Walk Index kernel writing into caller-provided buffers.
///
/// The true range is averaged over the first `length` bars starting at
/// `first`, then smoothed with factor `1 / length`. From the warm-up index
/// `first + length - 1` onward each output is
/// `(high[i] - low[i - length]) / (atr * sqrt(length))` for the high line and
/// `(high[i - length] - low[i]) / (atr * sqrt(length))` for the low line,
/// with missing or non-finite history treated as `0.0`.
///
/// Every index from the warm-up index to the end is written: a finite,
/// non-zero denominator yields a value, anything else yields NaN. Indices
/// before the warm-up index are left untouched.
///
/// The function does no validation of its own; it indexes the slices
/// directly and so expects `length >= 1`, `first < close.len()`, and all
/// five slices to cover `close.len()` elements.
#[inline(always)]
fn compute_random_walk_index_into(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    first: usize,
    out_high: &mut [f64],
    out_low: &mut [f64],
) {
    let n = close.len();
    let warm = first + length - 1;
    let sqrt_length = (length as f64).sqrt();
    let alpha = 1.0 / length as f64;

    let mut prev_close = close[first];
    // The first bar has no previous close, so its true range is just the bar range.
    let mut sum_tr = high[first] - low[first];
    let mut atr = f64::NAN;

    if length == 1 {
        // With a one-bar window the warm-up index is `first` itself, which the
        // main loop (starting at `first + 1`) never visits.
        atr = sum_tr;
        let denom = atr * sqrt_length;
        if denom.is_finite() && denom != 0.0 {
            out_high[first] = (high[first] - nz_history(low, first, length)) / denom;
            out_low[first] = (nz_history(high, first, length) - low[first]) / denom;
        } else {
            // Mirror the main loop: an unusable denominator must still write
            // NaN, otherwise this slot keeps whatever the buffer held before.
            out_high[first] = f64::NAN;
            out_low[first] = f64::NAN;
        }
    }

    let mut i = first + 1;
    while i < n {
        let tr = (high[i] - low[i])
            .max((high[i] - prev_close).abs())
            .max((low[i] - prev_close).abs());

        if i <= warm {
            // Still accumulating the simple average that seeds the ATR.
            sum_tr += tr;
            if i == warm {
                atr = sum_tr / length as f64;
            }
        } else {
            atr = alpha.mul_add(tr - atr, atr);
        }

        if i >= warm {
            let denom = atr * sqrt_length;
            if denom.is_finite() && denom != 0.0 {
                out_high[i] = (high[i] - nz_history(low, i, length)) / denom;
                out_low[i] = (nz_history(high, i, length) - low[i]) / denom;
            } else {
                out_high[i] = f64::NAN;
                out_low[i] = f64::NAN;
            }
        }

        prev_close = close[i];
        i += 1;
    }
}
330
331#[inline]
332pub fn random_walk_index(
333 input: &RandomWalkIndexInput,
334) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
335 random_walk_index_with_kernel(input, Kernel::Auto)
336}
337
338pub fn random_walk_index_with_kernel(
339 input: &RandomWalkIndexInput,
340 kernel: Kernel,
341) -> Result<RandomWalkIndexOutput, RandomWalkIndexError> {
342 let (high, low, close, length, first, _) = prepare(input, kernel)?;
343 let warm = first + length - 1;
344 let mut out_high = alloc_with_nan_prefix(close.len(), warm);
345 let mut out_low = alloc_with_nan_prefix(close.len(), warm);
346 compute_random_walk_index_into(high, low, close, length, first, &mut out_high, &mut out_low);
347 Ok(RandomWalkIndexOutput {
348 high: out_high,
349 low: out_low,
350 })
351}
352
353#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
354pub fn random_walk_index_into(
355 out_high: &mut [f64],
356 out_low: &mut [f64],
357 input: &RandomWalkIndexInput,
358 kernel: Kernel,
359) -> Result<(), RandomWalkIndexError> {
360 random_walk_index_into_slice(out_high, out_low, input, kernel)
361}
362
363pub fn random_walk_index_into_slice(
364 out_high: &mut [f64],
365 out_low: &mut [f64],
366 input: &RandomWalkIndexInput,
367 kernel: Kernel,
368) -> Result<(), RandomWalkIndexError> {
369 let (high, low, close, length, first, _) = prepare(input, kernel)?;
370 let expected = close.len();
371 if out_high.len() != expected || out_low.len() != expected {
372 return Err(RandomWalkIndexError::OutputLengthMismatch {
373 expected,
374 got: out_high.len().max(out_low.len()),
375 });
376 }
377 let warm = first + length - 1;
378 out_high[..warm.min(expected)].fill(f64::NAN);
379 out_low[..warm.min(expected)].fill(f64::NAN);
380 compute_random_walk_index_into(high, low, close, length, first, out_high, out_low);
381 Ok(())
382}
383
/// Incremental Random Walk Index calculator fed one bar at a time.
#[derive(Clone, Debug)]
pub struct RandomWalkIndexStream {
    /// Look-back length.
    length: usize,
    /// Cached square root of `length`.
    sqrt_length: f64,
    /// Number of accepted (fully finite) bars so far.
    count: usize,
    /// Running true-range sum used during warm-up.
    warm_sum: f64,
    /// Current smoothed true range; NaN until warm-up completes.
    atr: f64,
    /// Close of the most recently accepted bar.
    prev_close: f64,
    /// Most recent accepted highs, at most `length` of them.
    history_high: VecDeque<f64>,
    /// Most recent accepted lows, at most `length` of them.
    history_low: VecDeque<f64>,
}
395
396impl RandomWalkIndexStream {
397 pub fn try_new(params: RandomWalkIndexParams) -> Result<Self, RandomWalkIndexError> {
398 let length = params.length.unwrap_or(DEFAULT_LENGTH);
399 if length == 0 {
400 return Err(RandomWalkIndexError::InvalidLength {
401 length,
402 data_len: 0,
403 });
404 }
405 Ok(Self {
406 length,
407 sqrt_length: (length as f64).sqrt(),
408 count: 0,
409 warm_sum: 0.0,
410 atr: f64::NAN,
411 prev_close: f64::NAN,
412 history_high: VecDeque::with_capacity(length),
413 history_low: VecDeque::with_capacity(length),
414 })
415 }
416
417 #[inline]
418 pub fn update(&mut self, high: f64, low: f64, close: f64) -> (f64, f64) {
419 if !high.is_finite() || !low.is_finite() || !close.is_finite() {
420 return (f64::NAN, f64::NAN);
421 }
422
423 let tr = if self.count == 0 {
424 high - low
425 } else {
426 (high - low)
427 .max((high - self.prev_close).abs())
428 .max((low - self.prev_close).abs())
429 };
430
431 if self.count < self.length {
432 self.warm_sum += tr;
433 self.count += 1;
434 if self.count == self.length {
435 self.atr = self.warm_sum / self.length as f64;
436 }
437 } else {
438 let alpha = 1.0 / self.length as f64;
439 self.atr = alpha.mul_add(tr - self.atr, self.atr);
440 self.count += 1;
441 }
442
443 let hist_high = if self.history_high.len() == self.length {
444 self.history_high.front().copied().unwrap_or(0.0)
445 } else {
446 0.0
447 };
448 let hist_low = if self.history_low.len() == self.length {
449 self.history_low.front().copied().unwrap_or(0.0)
450 } else {
451 0.0
452 };
453 let denom = self.atr * self.sqrt_length;
454 let out = if self.count >= self.length && denom.is_finite() && denom != 0.0 {
455 ((high - hist_low) / denom, (hist_high - low) / denom)
456 } else {
457 (f64::NAN, f64::NAN)
458 };
459
460 self.history_high.push_back(high);
461 self.history_low.push_back(low);
462 if self.history_high.len() > self.length {
463 self.history_high.pop_front();
464 }
465 if self.history_low.len() > self.length {
466 self.history_low.pop_front();
467 }
468 self.prev_close = close;
469
470 out
471 }
472}
473
/// Parameter sweep description for batch runs.
#[derive(Clone, Debug)]
pub struct RandomWalkIndexBatchRange {
    /// Length sweep as `(start, end, step)`.
    pub length: (usize, usize, usize),
}
478
479#[derive(Debug, Clone)]
480pub struct RandomWalkIndexBatchOutput {
481 pub high: Vec<f64>,
482 pub low: Vec<f64>,
483 pub combos: Vec<RandomWalkIndexParams>,
484 pub rows: usize,
485 pub cols: usize,
486}
487
488#[derive(Copy, Clone, Debug)]
489pub struct RandomWalkIndexBatchBuilder {
490 length: (usize, usize, usize),
491 kernel: Kernel,
492}
493
494impl Default for RandomWalkIndexBatchBuilder {
495 fn default() -> Self {
496 Self {
497 length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
498 kernel: Kernel::Auto,
499 }
500 }
501}
502
503impl RandomWalkIndexBatchBuilder {
504 #[inline(always)]
505 pub fn new() -> Self {
506 Self::default()
507 }
508
509 #[inline(always)]
510 pub fn length_range(mut self, value: (usize, usize, usize)) -> Self {
511 self.length = value;
512 self
513 }
514
515 #[inline(always)]
516 pub fn kernel(mut self, value: Kernel) -> Self {
517 self.kernel = value;
518 self
519 }
520
521 #[inline(always)]
522 pub fn apply_candles(
523 self,
524 candles: &Candles,
525 ) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
526 random_walk_index_batch_with_kernel(
527 candles.high.as_slice(),
528 candles.low.as_slice(),
529 candles.close.as_slice(),
530 &RandomWalkIndexBatchRange {
531 length: self.length,
532 },
533 self.kernel,
534 )
535 }
536
537 #[inline(always)]
538 pub fn apply_slices(
539 self,
540 high: &[f64],
541 low: &[f64],
542 close: &[f64],
543 ) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
544 random_walk_index_batch_with_kernel(
545 high,
546 low,
547 close,
548 &RandomWalkIndexBatchRange {
549 length: self.length,
550 },
551 self.kernel,
552 )
553 }
554}
555
556pub fn expand_grid(
557 sweep: &RandomWalkIndexBatchRange,
558) -> Result<Vec<RandomWalkIndexParams>, RandomWalkIndexError> {
559 let (start, end, step) = sweep.length;
560 if start == 0 {
561 return Err(RandomWalkIndexError::InvalidRange {
562 start: start.to_string(),
563 end: end.to_string(),
564 step: step.to_string(),
565 });
566 }
567 let mut lengths = Vec::new();
568 if step == 0 {
569 if start != end {
570 return Err(RandomWalkIndexError::InvalidRange {
571 start: start.to_string(),
572 end: end.to_string(),
573 step: step.to_string(),
574 });
575 }
576 lengths.push(start);
577 } else {
578 if start > end {
579 return Err(RandomWalkIndexError::InvalidRange {
580 start: start.to_string(),
581 end: end.to_string(),
582 step: step.to_string(),
583 });
584 }
585 let mut current = start;
586 while current <= end {
587 lengths.push(current);
588 match current.checked_add(step) {
589 Some(next) => current = next,
590 None => break,
591 }
592 }
593 }
594
595 Ok(lengths
596 .into_iter()
597 .map(|length| RandomWalkIndexParams {
598 length: Some(length),
599 })
600 .collect())
601}
602
603pub fn random_walk_index_batch_with_kernel(
604 high: &[f64],
605 low: &[f64],
606 close: &[f64],
607 sweep: &RandomWalkIndexBatchRange,
608 kernel: Kernel,
609) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
610 let batch_kernel = match kernel {
611 Kernel::Auto => detect_best_batch_kernel(),
612 other if other.is_batch() => other,
613 _ => return Err(RandomWalkIndexError::InvalidKernelForBatch(kernel)),
614 };
615 random_walk_index_batch_par_slice(high, low, close, sweep, batch_kernel.to_non_batch())
616}
617
618#[inline(always)]
619pub fn random_walk_index_batch_slice(
620 high: &[f64],
621 low: &[f64],
622 close: &[f64],
623 sweep: &RandomWalkIndexBatchRange,
624 kernel: Kernel,
625) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
626 random_walk_index_batch_inner(high, low, close, sweep, kernel, false)
627}
628
629#[inline(always)]
630pub fn random_walk_index_batch_par_slice(
631 high: &[f64],
632 low: &[f64],
633 close: &[f64],
634 sweep: &RandomWalkIndexBatchRange,
635 kernel: Kernel,
636) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
637 random_walk_index_batch_inner(high, low, close, sweep, kernel, true)
638}
639
640fn validate_raw_slices(
641 high: &[f64],
642 low: &[f64],
643 close: &[f64],
644) -> Result<usize, RandomWalkIndexError> {
645 if high.is_empty() || low.is_empty() || close.is_empty() {
646 return Err(RandomWalkIndexError::EmptyInputData);
647 }
648 if high.len() != low.len() || high.len() != close.len() {
649 return Err(RandomWalkIndexError::InconsistentSliceLengths {
650 high_len: high.len(),
651 low_len: low.len(),
652 close_len: close.len(),
653 });
654 }
655 first_valid_hlc(high, low, close).ok_or(RandomWalkIndexError::AllValuesNaN)
656}
657
658fn random_walk_index_batch_inner(
659 high: &[f64],
660 low: &[f64],
661 close: &[f64],
662 sweep: &RandomWalkIndexBatchRange,
663 kernel: Kernel,
664 parallel: bool,
665) -> Result<RandomWalkIndexBatchOutput, RandomWalkIndexError> {
666 let combos = expand_grid(sweep)?;
667 let first = validate_raw_slices(high, low, close)?;
668 let max_length = combos
669 .iter()
670 .map(|combo| combo.length.unwrap())
671 .max()
672 .unwrap();
673 let valid = close.len().saturating_sub(first);
674 if valid < max_length {
675 return Err(RandomWalkIndexError::NotEnoughValidData {
676 needed: max_length,
677 valid,
678 });
679 }
680
681 let rows = combos.len();
682 let cols = close.len();
683 let warmups: Vec<usize> = combos
684 .iter()
685 .map(|combo| first + combo.length.unwrap() - 1)
686 .collect();
687
688 let mut high_buf = make_uninit_matrix(rows, cols);
689 init_matrix_prefixes(&mut high_buf, cols, &warmups);
690 let mut high_guard = ManuallyDrop::new(high_buf);
691 let out_high: &mut [f64] = unsafe {
692 core::slice::from_raw_parts_mut(high_guard.as_mut_ptr() as *mut f64, high_guard.len())
693 };
694
695 let mut low_buf = make_uninit_matrix(rows, cols);
696 init_matrix_prefixes(&mut low_buf, cols, &warmups);
697 let mut low_guard = ManuallyDrop::new(low_buf);
698 let out_low: &mut [f64] = unsafe {
699 core::slice::from_raw_parts_mut(low_guard.as_mut_ptr() as *mut f64, low_guard.len())
700 };
701
702 random_walk_index_batch_inner_into(
703 high, low, close, sweep, kernel, parallel, out_high, out_low,
704 )?;
705
706 let high_values = unsafe {
707 Vec::from_raw_parts(
708 high_guard.as_mut_ptr() as *mut f64,
709 high_guard.len(),
710 high_guard.capacity(),
711 )
712 };
713 let low_values = unsafe {
714 Vec::from_raw_parts(
715 low_guard.as_mut_ptr() as *mut f64,
716 low_guard.len(),
717 low_guard.capacity(),
718 )
719 };
720
721 Ok(RandomWalkIndexBatchOutput {
722 high: high_values,
723 low: low_values,
724 combos,
725 rows,
726 cols,
727 })
728}
729
730pub fn random_walk_index_batch_into_slice(
731 out_high: &mut [f64],
732 out_low: &mut [f64],
733 high: &[f64],
734 low: &[f64],
735 close: &[f64],
736 sweep: &RandomWalkIndexBatchRange,
737 kernel: Kernel,
738) -> Result<(), RandomWalkIndexError> {
739 random_walk_index_batch_inner_into(high, low, close, sweep, kernel, false, out_high, out_low)?;
740 Ok(())
741}
742
743fn random_walk_index_batch_inner_into(
744 high: &[f64],
745 low: &[f64],
746 close: &[f64],
747 sweep: &RandomWalkIndexBatchRange,
748 _kernel: Kernel,
749 parallel: bool,
750 out_high: &mut [f64],
751 out_low: &mut [f64],
752) -> Result<Vec<RandomWalkIndexParams>, RandomWalkIndexError> {
753 let combos = expand_grid(sweep)?;
754 let first = validate_raw_slices(high, low, close)?;
755 let rows = combos.len();
756 let cols = close.len();
757 let expected = rows
758 .checked_mul(cols)
759 .ok_or_else(|| RandomWalkIndexError::InvalidRange {
760 start: rows.to_string(),
761 end: cols.to_string(),
762 step: "rows*cols".to_string(),
763 })?;
764 if out_high.len() != expected || out_low.len() != expected {
765 return Err(RandomWalkIndexError::OutputLengthMismatch {
766 expected,
767 got: out_high.len().max(out_low.len()),
768 });
769 }
770 let max_length = combos
771 .iter()
772 .map(|combo| combo.length.unwrap())
773 .max()
774 .unwrap();
775 let valid = cols.saturating_sub(first);
776 if valid < max_length {
777 return Err(RandomWalkIndexError::NotEnoughValidData {
778 needed: max_length,
779 valid,
780 });
781 }
782
783 let do_row = |row: usize, dst_high: &mut [f64], dst_low: &mut [f64]| {
784 let length = combos[row].length.unwrap();
785 let warm = first + length - 1;
786 dst_high[..warm.min(cols)].fill(f64::NAN);
787 dst_low[..warm.min(cols)].fill(f64::NAN);
788 compute_random_walk_index_into(high, low, close, length, first, dst_high, dst_low);
789 };
790
791 if parallel {
792 #[cfg(not(target_arch = "wasm32"))]
793 {
794 out_high
795 .par_chunks_mut(cols)
796 .zip(out_low.par_chunks_mut(cols))
797 .enumerate()
798 .for_each(|(row, (dst_high, dst_low))| do_row(row, dst_high, dst_low));
799 }
800 #[cfg(target_arch = "wasm32")]
801 {
802 for ((row, dst_high), dst_low) in out_high
803 .chunks_mut(cols)
804 .enumerate()
805 .zip(out_low.chunks_mut(cols))
806 {
807 do_row(row, dst_high, dst_low);
808 }
809 }
810 } else {
811 for ((row, dst_high), dst_low) in out_high
812 .chunks_mut(cols)
813 .enumerate()
814 .zip(out_low.chunks_mut(cols))
815 {
816 do_row(row, dst_high, dst_low);
817 }
818 }
819
820 Ok(combos)
821}
822
/// Python entry point `random_walk_index`.
///
/// Returns a dict with NumPy arrays under `"high"` and `"low"`. The
/// computation runs inside `allow_threads`; indicator errors surface as
/// `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "random_walk_index")]
#[pyo3(signature = (high, low, close, length=14, kernel=None))]
pub fn random_walk_index_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_in = high.as_slice()?;
    let low_in = low.as_slice()?;
    let close_in = close.as_slice()?;
    let params = RandomWalkIndexParams {
        length: Some(length),
    };
    let input = RandomWalkIndexInput::from_slices(high_in, low_in, close_in, params);
    let kernel = validate_kernel(kernel, false)?;

    let computed = py.allow_threads(|| random_walk_index_with_kernel(&input, kernel));
    let out = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("high", out.high.into_pyarray(py))?;
    dict.set_item("low", out.low.into_pyarray(py))?;
    Ok(dict)
}
854
/// Python wrapper class `RandomWalkIndexStream` around the Rust stream.
#[cfg(feature = "python")]
#[pyclass(name = "RandomWalkIndexStream")]
pub struct RandomWalkIndexStreamPy {
    /// Wrapped streaming calculator.
    stream: RandomWalkIndexStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl RandomWalkIndexStreamPy {
    /// Constructs the stream; an invalid length is reported as `ValueError`.
    #[new]
    #[pyo3(signature = (length=14))]
    fn new(length: usize) -> PyResult<Self> {
        let params = RandomWalkIndexParams {
            length: Some(length),
        };
        match RandomWalkIndexStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    /// Feeds one bar and returns the `(high, low)` pair from the Rust stream.
    fn update(&mut self, high: f64, low: f64, close: f64) -> (f64, f64) {
        let stream = &mut self.stream;
        stream.update(high, low, close)
    }
}
878
/// Python entry point `random_walk_index_batch`.
///
/// Returns a dict with `"high"` and `"low"` reshaped to `(rows, cols)`,
/// `"lengths"` (one `u64` per row), `"rows"` and `"cols"`. Results are written
/// directly into NumPy-owned buffers; indicator errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "random_walk_index_batch")]
#[pyo3(signature = (high, low, close, length_range=(14,14,0), kernel=None))]
pub fn random_walk_index_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high_in = high.as_slice()?;
    let low_in = low.as_slice()?;
    let close_in = close.as_slice()?;
    let sweep = RandomWalkIndexBatchRange {
        length: length_range,
    };

    // The grid is expanded up front only to size the output arrays.
    let combos = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = close_in.len();
    let total = match rows.checked_mul(cols) {
        Some(total) => total,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // NOTE(review): the arrays start uninitialised and are expected to be
    // fully written by the batch routine before anything reads them.
    let out_high = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let out_low = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let high_slice = unsafe { out_high.as_slice_mut()? };
    let low_slice = unsafe { out_low.as_slice_mut()? };
    let kernel = validate_kernel(kernel, true)?;

    let filled = py.allow_threads(|| {
        let batch_kernel = if matches!(kernel, Kernel::Auto) {
            detect_best_batch_kernel()
        } else {
            kernel
        };
        random_walk_index_batch_inner_into(
            high_in,
            low_in,
            close_in,
            &sweep,
            batch_kernel.to_non_batch(),
            true,
            high_slice,
            low_slice,
        )
    });
    filled.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("high", out_high.reshape((rows, cols))?)?;
    dict.set_item("low", out_low.reshape((rows, cols))?)?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
942
/// Registers the two Python functions and the stream class on module `m`.
#[cfg(feature = "python")]
pub fn register_random_walk_index_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(random_walk_index_py, m)?;
    m.add_function(single)?;
    let batch = wrap_pyfunction!(random_walk_index_batch_py, m)?;
    m.add_function(batch)?;
    m.add_class::<RandomWalkIndexStreamPy>()?;
    Ok(())
}
950
/// Serialisable result handed to JavaScript by `random_walk_index_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexJsOutput {
    /// The "high" output line.
    pub high: Vec<f64>,
    /// The "low" output line.
    pub low: Vec<f64>,
}
957
/// JavaScript entry point computing the indicator with `Kernel::Auto`.
///
/// Returns the serialised `RandomWalkIndexJsOutput`; indicator and
/// serialisation errors are returned as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "random_walk_index_js")]
pub fn random_walk_index_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
) -> Result<JsValue, JsValue> {
    let params = RandomWalkIndexParams {
        length: Some(length),
    };
    let input = RandomWalkIndexInput::from_slices(high, low, close, params);
    let out = match random_walk_index_with_kernel(&input, Kernel::Auto) {
        Ok(out) => out,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };
    let payload = RandomWalkIndexJsOutput {
        high: out.high,
        low: out.low,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
982
/// Batch configuration received from JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexBatchConfig {
    /// Length sweep given as numbers; it is converted by `js_vec3_to_usize`,
    /// which requires exactly `[start, end, step]`.
    pub length_range: Vec<f64>,
}

/// Serialisable batch result handed to JavaScript.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct RandomWalkIndexBatchJsOutput {
    /// Flattened "high" line values.
    pub high: Vec<f64>,
    /// Flattened "low" line values.
    pub low: Vec<f64>,
    /// Length used for each row.
    pub lengths: Vec<usize>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
998
/// Converts a three-element JavaScript number array into a
/// `(start, end, step)` triple of `usize`.
///
/// Each element must be finite, non-negative and within `1e-9` of a whole
/// number; `name` is used only in the error messages.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
fn js_vec3_to_usize(name: &str, values: &[f64]) -> Result<(usize, usize, usize), JsValue> {
    if values.len() != 3 {
        return Err(JsValue::from_str(&format!(
            "Invalid config: {name} must have exactly 3 elements [start, end, step]"
        )));
    }

    let convert = |i: usize| -> Result<usize, JsValue> {
        let value = values[i];
        if !value.is_finite() || value < 0.0 {
            return Err(JsValue::from_str(&format!(
                "Invalid config: {name}[{i}] must be a finite non-negative whole number"
            )));
        }
        let rounded = value.round();
        if (value - rounded).abs() > 1e-9 {
            return Err(JsValue::from_str(&format!(
                "Invalid config: {name}[{i}] must be a whole number"
            )));
        }
        Ok(rounded as usize)
    };

    // Elements are converted in order, so the first offending index is reported.
    let start = convert(0)?;
    let end = convert(1)?;
    let step = convert(2)?;
    Ok((start, end, step))
}
1023
/// JavaScript entry point running a batch sweep with `Kernel::Auto`.
///
/// `config` must deserialise into `RandomWalkIndexBatchConfig`. Returns the
/// serialised `RandomWalkIndexBatchJsOutput`; all failures are returned as
/// string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "random_walk_index_batch_js")]
pub fn random_walk_index_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: RandomWalkIndexBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let length = js_vec3_to_usize("length_range", &config.length_range)?;
    let sweep = RandomWalkIndexBatchRange { length };

    let out = random_walk_index_batch_with_kernel(high, low, close, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
        .collect();
    let payload = RandomWalkIndexBatchJsOutput {
        high: out.high,
        low: out.low,
        lengths,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {e}")))
}
1053
/// Allocates room for `len` `f64` values and returns the raw pointer.
///
/// Ownership passes to the caller; the memory is released with
/// `random_walk_index_free` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_alloc(len: usize) -> *mut f64 {
    // Wrapping in `ManuallyDrop` keeps the allocation alive after return.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}

/// Releases memory obtained from `random_walk_index_alloc`.
///
/// A null `ptr` is ignored. `len` must be the value originally passed to the
/// allocator, since it is used as both length and capacity when rebuilding
/// the vector.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: relies on the caller passing a pointer/len pair produced by
    // `random_walk_index_alloc`.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1072
/// Raw-pointer wasm entry point computing the indicator with `Kernel::Auto`.
///
/// All five pointers must reference `len` `f64` values. In-place use is
/// supported: when an output pointer equals an input pointer, or both output
/// pointers are equal, the result is computed into owned vectors first and
/// then copied out. With equal output pointers the "low" line is copied last
/// and is what remains in the shared buffer. Partially overlapping buffers
/// are not detected and remain the caller's responsibility.
///
/// # Errors
/// A string `JsValue` for null pointers or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_high_ptr: *mut f64,
    out_low_ptr: *mut f64,
    len: usize,
    length: usize,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_high_ptr.is_null()
        || out_low_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Writing straight into a buffer that is also an input would alias `&`
    // and `&mut` slices and corrupt the `length`-bars-back history lookups.
    let out_high_addr = out_high_ptr as *const f64;
    let out_low_addr = out_low_ptr as *const f64;
    let aliased = out_high_addr == out_low_addr
        || [high_ptr, low_ptr, close_ptr]
            .iter()
            .any(|&src| src == out_high_addr || src == out_low_addr);

    // SAFETY: the caller guarantees each non-null pointer references `len`
    // valid `f64` values; mutable output slices are only created alongside
    // the input slices when no pointer coincides.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let input = RandomWalkIndexInput::from_slices(
            high,
            low,
            close,
            RandomWalkIndexParams {
                length: Some(length),
            },
        );
        if aliased {
            let out = random_walk_index_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input slices are not touched past this point.
            std::slice::from_raw_parts_mut(out_high_ptr, len).copy_from_slice(&out.high);
            std::slice::from_raw_parts_mut(out_low_ptr, len).copy_from_slice(&out.low);
            Ok(())
        } else {
            let out_high = std::slice::from_raw_parts_mut(out_high_ptr, len);
            let out_low = std::slice::from_raw_parts_mut(out_low_ptr, len);
            random_walk_index_into_slice(out_high, out_low, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))
        }
    }
}
1110
/// Raw-pointer wasm entry point running a sequential batch sweep with
/// `Kernel::Auto`.
///
/// The input pointers must reference `len` values each and the output
/// pointers `rows * len` values each, where `rows` is the number of lengths
/// in the sweep. Returns `rows` on success.
///
/// # Errors
/// A string `JsValue` for null pointers, an invalid sweep, a `rows * len`
/// overflow, or any indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn random_walk_index_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    out_high_ptr: *mut f64,
    out_low_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    let any_null = high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || out_high_ptr.is_null()
        || out_low_ptr.is_null();
    if any_null {
        return Err(JsValue::from_str(
            "null pointer passed to random_walk_index_batch_into",
        ));
    }

    // SAFETY: relies on the caller's guarantee about the buffer sizes
    // described above.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let sweep = RandomWalkIndexBatchRange {
            length: (length_start, length_end, length_step),
        };

        // The grid is expanded here only to learn the row count.
        let rows = match expand_grid(&sweep) {
            Ok(combos) => combos.len(),
            Err(e) => return Err(JsValue::from_str(&e.to_string())),
        };
        let total = match rows.checked_mul(len) {
            Some(total) => total,
            None => {
                return Err(JsValue::from_str(
                    "rows*cols overflow in random_walk_index_batch_into",
                ))
            }
        };

        let out_high = std::slice::from_raw_parts_mut(out_high_ptr, total);
        let out_low = std::slice::from_raw_parts_mut(out_low_ptr, total);
        match random_walk_index_batch_into_slice(
            out_high,
            out_low,
            high,
            low,
            close,
            &sweep,
            Kernel::Auto,
        ) {
            Ok(()) => Ok(rows),
            Err(e) => Err(JsValue::from_str(&e.to_string())),
        }
    }
}
1161
#[cfg(test)]
mod tests {
    use super::*;

    /// Window length shared by the equivalence tests.
    const TEST_LENGTH: usize = 14;

    /// Wraps borrowed high/low/close series and a window `length` into an indicator input.
    fn slices_input<'a>(
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        length: usize,
    ) -> RandomWalkIndexInput<'a> {
        RandomWalkIndexInput::from_slices(
            high,
            low,
            close,
            RandomWalkIndexParams {
                length: Some(length),
            },
        )
    }

    /// Reference output produced by calling the core routine directly on NaN-filled buffers,
    /// starting from the index reported by `first_valid_hlc`.
    fn manual_random_walk_index(
        high: &[f64],
        low: &[f64],
        close: &[f64],
        length: usize,
    ) -> (Vec<f64>, Vec<f64>) {
        let size = close.len();
        let mut rwi_high = vec![f64::NAN; size];
        let mut rwi_low = vec![f64::NAN; size];
        let start = first_valid_hlc(high, low, close).unwrap();
        compute_random_walk_index_into(
            high,
            low,
            close,
            length,
            start,
            &mut rwi_high,
            &mut rwi_low,
        );
        (rwi_high, rwi_low)
    }

    /// Builds a deterministic synthetic `(high, low, close)` fixture of `n` bars: a drifting
    /// sine wave for the close, with the high strictly above and the low strictly below it.
    fn sample_hlc(n: usize) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        for i in 0..n {
            let x = i as f64;
            let c = 100.0 + (x * 0.19).sin() * 2.0 + x * 0.03;
            high.push(c + 1.5 + (x * 0.11).cos().abs());
            low.push(c - 1.3 - (x * 0.07).sin().abs());
            close.push(c);
        }
        (high, low, close)
    }

    /// Asserts that two series have equal length and agree element-wise within `1e-12`;
    /// positions where both values are NaN count as equal.
    fn assert_close(lhs: &[f64], rhs: &[f64]) {
        assert_eq!(lhs.len(), rhs.len());
        let pairs = lhs.iter().copied().zip(rhs.iter().copied());
        for (idx, (a, b)) in pairs.enumerate() {
            let both_nan = a.is_nan() && b.is_nan();
            assert!(
                both_nan || (a - b).abs() <= 1e-12,
                "mismatch at {idx}: {a} vs {b}"
            );
        }
    }

    /// The public single-run API must agree with the directly invoked core routine.
    #[test]
    fn manual_reference_matches_api() {
        let (high, low, close) = sample_hlc(128);
        let got = random_walk_index(&slices_input(&high, &low, &close, TEST_LENGTH)).unwrap();
        let (want_high, want_low) = manual_random_walk_index(&high, &low, &close, TEST_LENGTH);
        assert_close(&got.high, &want_high);
        assert_close(&got.low, &want_low);
    }

    /// Feeding bars one at a time through the stream must reproduce the slice-based output.
    #[test]
    fn stream_matches_batch() {
        let (high, low, close) = sample_hlc(96);
        let expected = random_walk_index(&slices_input(&high, &low, &close, TEST_LENGTH)).unwrap();
        let mut stream = RandomWalkIndexStream::try_new(RandomWalkIndexParams {
            length: Some(TEST_LENGTH),
        })
        .unwrap();
        let (got_high, got_low): (Vec<f64>, Vec<f64>) = high
            .iter()
            .zip(low.iter())
            .zip(close.iter())
            .map(|((&h, &l), &c)| stream.update(h, l, c))
            .unzip();
        assert_close(&expected.high, &got_high);
        assert_close(&expected.low, &got_low);
    }

    /// The first row of a two-row batch sweep must equal a single run with the same length.
    #[test]
    fn batch_first_row_matches_single() {
        let (high, low, close) = sample_hlc(80);
        let sweep = RandomWalkIndexBatchRange {
            length: (14, 16, 2),
        };
        let batch =
            random_walk_index_batch_with_kernel(&high, &low, &close, &sweep, Kernel::Auto)
                .unwrap();
        let single = random_walk_index(&slices_input(&high, &low, &close, TEST_LENGTH)).unwrap();
        assert_eq!(batch.rows, 2);
        assert_close(&batch.high[..80], &single.high);
        assert_close(&batch.low[..80], &single.low);
    }

    /// Writing into caller-provided buffers must give the same values as the allocating API.
    #[test]
    fn into_slice_matches_single() {
        let (high, low, close) = sample_hlc(72);
        let input = slices_input(&high, &low, &close, TEST_LENGTH);
        let expected = random_walk_index(&input).unwrap();
        let bars = close.len();
        let mut buf_high = vec![0.0; bars];
        let mut buf_low = vec![0.0; bars];
        random_walk_index_into_slice(&mut buf_high, &mut buf_low, &input, Kernel::Auto).unwrap();
        assert_close(&expected.high, &buf_high);
        assert_close(&expected.low, &buf_low);
    }

    /// A zero window length must be reported as `InvalidLength`.
    #[test]
    fn invalid_length_is_rejected() {
        let (high, low, close) = sample_hlc(8);
        let outcome = random_walk_index(&slices_input(&high, &low, &close, 0));
        assert!(matches!(
            outcome,
            Err(RandomWalkIndexError::InvalidLength { .. })
        ));
    }
}