1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::error::Error;
26use thiserror::Error;
27
28#[derive(Debug, Clone)]
29pub enum VolumeWeightedRsiData<'a> {
30 Candles {
31 candles: &'a Candles,
32 close_source: &'a str,
33 },
34 Slices {
35 close: &'a [f64],
36 volume: &'a [f64],
37 },
38}
39
/// Result of a single volume-weighted RSI run.
#[derive(Clone, Debug)]
pub struct VolumeWeightedRsiOutput {
    /// One value per input bar, index-aligned with the input; `NaN` where no
    /// RSI is defined (invalid bars and warm-up).
    pub values: Vec<f64>,
}
44
/// Tunable parameters of the volume-weighted RSI.
#[derive(Clone, Debug)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct VolumeWeightedRsiParams {
    /// Smoothing length. `None` is read as 14 by every consumer in this module.
    pub period: Option<usize>,
}

impl Default for VolumeWeightedRsiParams {
    /// A 14-bar period.
    fn default() -> Self {
        const DEFAULT_PERIOD: usize = 14;
        VolumeWeightedRsiParams {
            period: Some(DEFAULT_PERIOD),
        }
    }
}
59
60#[derive(Debug, Clone)]
61pub struct VolumeWeightedRsiInput<'a> {
62 pub data: VolumeWeightedRsiData<'a>,
63 pub params: VolumeWeightedRsiParams,
64}
65
66impl<'a> VolumeWeightedRsiInput<'a> {
67 #[inline]
68 pub fn from_candles(
69 candles: &'a Candles,
70 close_source: &'a str,
71 params: VolumeWeightedRsiParams,
72 ) -> Self {
73 Self {
74 data: VolumeWeightedRsiData::Candles {
75 candles,
76 close_source,
77 },
78 params,
79 }
80 }
81
82 #[inline]
83 pub fn from_slices(
84 close: &'a [f64],
85 volume: &'a [f64],
86 params: VolumeWeightedRsiParams,
87 ) -> Self {
88 Self {
89 data: VolumeWeightedRsiData::Slices { close, volume },
90 params,
91 }
92 }
93
94 #[inline]
95 pub fn with_default_candles(candles: &'a Candles) -> Self {
96 Self::from_candles(candles, "close", VolumeWeightedRsiParams::default())
97 }
98
99 #[inline]
100 pub fn get_period(&self) -> usize {
101 self.params.period.unwrap_or(14)
102 }
103}
104
105#[derive(Copy, Clone, Debug)]
106pub struct VolumeWeightedRsiBuilder {
107 period: Option<usize>,
108 kernel: Kernel,
109}
110
111impl Default for VolumeWeightedRsiBuilder {
112 fn default() -> Self {
113 Self {
114 period: None,
115 kernel: Kernel::Auto,
116 }
117 }
118}
119
120impl VolumeWeightedRsiBuilder {
121 #[inline(always)]
122 pub fn new() -> Self {
123 Self::default()
124 }
125
126 #[inline(always)]
127 pub fn period(mut self, value: usize) -> Self {
128 self.period = Some(value);
129 self
130 }
131
132 #[inline(always)]
133 pub fn kernel(mut self, value: Kernel) -> Self {
134 self.kernel = value;
135 self
136 }
137
138 #[inline(always)]
139 pub fn apply(
140 self,
141 candles: &Candles,
142 ) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
143 let params = VolumeWeightedRsiParams {
144 period: self.period,
145 };
146 volume_weighted_rsi_with_kernel(
147 &VolumeWeightedRsiInput::from_candles(candles, "close", params),
148 self.kernel,
149 )
150 }
151
152 #[inline(always)]
153 pub fn apply_slices(
154 self,
155 close: &[f64],
156 volume: &[f64],
157 ) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
158 let params = VolumeWeightedRsiParams {
159 period: self.period,
160 };
161 volume_weighted_rsi_with_kernel(
162 &VolumeWeightedRsiInput::from_slices(close, volume, params),
163 self.kernel,
164 )
165 }
166
167 #[inline(always)]
168 pub fn into_stream(self) -> Result<VolumeWeightedRsiStream, VolumeWeightedRsiError> {
169 VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
170 period: self.period,
171 })
172 }
173}
174
/// Errors returned by the volume-weighted RSI functions, streams and bindings.
#[derive(Debug, Error)]
pub enum VolumeWeightedRsiError {
    /// The close or the volume slice has no elements.
    #[error("volume_weighted_rsi: Input data slice is empty.")]
    EmptyInputData,
    /// Close and volume slices have different lengths.
    #[error(
        "volume_weighted_rsi: Input length mismatch: close = {close_len}, volume = {volume_len}"
    )]
    InputLengthMismatch { close_len: usize, volume_len: usize },
    /// No bar has both a finite close and a finite volume.
    #[error("volume_weighted_rsi: All values are NaN.")]
    AllValuesNaN,
    /// Period is zero or longer than the data. The stream constructor reports
    /// `data_len: 0` because it has no data yet.
    #[error("volume_weighted_rsi: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// The longest run of consecutive valid bars is shorter than the period.
    #[error("volume_weighted_rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// Single-series destination slice has the wrong length (`into_slice` path).
    #[error("volume_weighted_rsi: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// Malformed `(start, end, step)` period sweep.
    #[error("volume_weighted_rsi: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },
    /// Kernel not accepted by the batch entry points.
    #[error("volume_weighted_rsi: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
    /// Batch destination buffer is not `rows * cols` long.
    #[error(
        "volume_weighted_rsi: Output length mismatch: dst = {dst_len}, expected = {expected_len}"
    )]
    MismatchedOutputLen { dst_len: usize, expected_len: usize },
    /// Other invalid input, e.g. a `rows * cols` overflow in batch sizing.
    #[error("volume_weighted_rsi: Invalid input: {msg}")]
    InvalidInput { msg: String },
}
206
207#[derive(Debug, Clone)]
208pub struct VolumeWeightedRsiStream {
209 period: usize,
210 inv_period: f64,
211 beta: f64,
212 prev_close: f64,
213 has_prev: bool,
214 seeded: usize,
215 sum_up: f64,
216 sum_down: f64,
217 avg_up: f64,
218 avg_down: f64,
219}
220
221impl VolumeWeightedRsiStream {
222 #[inline(always)]
223 pub fn try_new(params: VolumeWeightedRsiParams) -> Result<Self, VolumeWeightedRsiError> {
224 let period = params.period.unwrap_or(14);
225 if period == 0 {
226 return Err(VolumeWeightedRsiError::InvalidPeriod {
227 period,
228 data_len: 0,
229 });
230 }
231 let inv_period = 1.0 / period as f64;
232 Ok(Self {
233 period,
234 inv_period,
235 beta: 1.0 - inv_period,
236 prev_close: f64::NAN,
237 has_prev: false,
238 seeded: 0,
239 sum_up: 0.0,
240 sum_down: 0.0,
241 avg_up: 0.0,
242 avg_down: 0.0,
243 })
244 }
245
246 #[inline(always)]
247 pub fn reset(&mut self) {
248 self.prev_close = f64::NAN;
249 self.has_prev = false;
250 self.seeded = 0;
251 self.sum_up = 0.0;
252 self.sum_down = 0.0;
253 self.avg_up = 0.0;
254 self.avg_down = 0.0;
255 }
256
257 #[inline(always)]
258 pub fn update(&mut self, close: f64, volume: f64) -> Option<f64> {
259 if !is_valid_pair(close, volume) {
260 self.reset();
261 return None;
262 }
263
264 let (up, down) = if self.has_prev {
265 if close > self.prev_close {
266 (volume, 0.0)
267 } else if close < self.prev_close {
268 (0.0, volume)
269 } else {
270 (0.0, 0.0)
271 }
272 } else {
273 (0.0, 0.0)
274 };
275
276 self.prev_close = close;
277 self.has_prev = true;
278
279 if self.seeded < self.period {
280 self.sum_up += up;
281 self.sum_down += down;
282 self.seeded += 1;
283 if self.seeded < self.period {
284 return None;
285 }
286 self.avg_up = self.sum_up * self.inv_period;
287 self.avg_down = self.sum_down * self.inv_period;
288 return Some(rsi_from_components(self.avg_up, self.avg_down));
289 }
290
291 self.avg_up = self.avg_up.mul_add(self.beta, self.inv_period * up);
292 self.avg_down = self.avg_down.mul_add(self.beta, self.inv_period * down);
293 Some(rsi_from_components(self.avg_up, self.avg_down))
294 }
295
296 #[inline(always)]
297 pub fn get_warmup_period(&self) -> usize {
298 self.period.saturating_sub(1)
299 }
300}
301
/// `true` when both the close and the volume of a bar are finite.
///
/// Bars failing this test split the series into independent runs and are
/// emitted as `NaN`.
#[inline(always)]
fn is_valid_pair(close: f64, volume: f64) -> bool {
    [close, volume].iter().all(|x| x.is_finite())
}

/// Turns smoothed up/down volume into an RSI reading.
///
/// When neither side carries any volume the ratio is undefined, so the
/// neutral midpoint `50.0` is reported instead of dividing by zero.
#[inline(always)]
fn rsi_from_components(avg_up: f64, avg_down: f64) -> f64 {
    match avg_up + avg_down {
        total if total == 0.0 => 50.0,
        total => 100.0 * avg_up / total,
    }
}

/// Length of the longest run of consecutive valid `(close, volume)` pairs.
/// Only the overlapping prefix of the two slices is examined.
#[inline(always)]
fn longest_valid_pair_run(close: &[f64], volume: &[f64]) -> usize {
    let (longest, _current) = close.iter().zip(volume).fold(
        (0usize, 0usize),
        |(longest, current), (&c, &v)| {
            if is_valid_pair(c, v) {
                let current = current + 1;
                (longest.max(current), current)
            } else {
                (longest, 0)
            }
        },
    );
    longest
}
333
334#[inline(always)]
335fn input_slices<'a>(
336 input: &'a VolumeWeightedRsiInput<'a>,
337) -> Result<(&'a [f64], &'a [f64]), VolumeWeightedRsiError> {
338 match &input.data {
339 VolumeWeightedRsiData::Candles {
340 candles,
341 close_source,
342 } => Ok((
343 source_type(candles, close_source),
344 candles.volume.as_slice(),
345 )),
346 VolumeWeightedRsiData::Slices { close, volume } => Ok((*close, *volume)),
347 }
348}
349
350#[inline(always)]
351fn validate_common(
352 close: &[f64],
353 volume: &[f64],
354 period: usize,
355) -> Result<(), VolumeWeightedRsiError> {
356 if close.is_empty() || volume.is_empty() {
357 return Err(VolumeWeightedRsiError::EmptyInputData);
358 }
359 if close.len() != volume.len() {
360 return Err(VolumeWeightedRsiError::InputLengthMismatch {
361 close_len: close.len(),
362 volume_len: volume.len(),
363 });
364 }
365 if period == 0 || period > close.len() {
366 return Err(VolumeWeightedRsiError::InvalidPeriod {
367 period,
368 data_len: close.len(),
369 });
370 }
371
372 let max_run = longest_valid_pair_run(close, volume);
373 if max_run == 0 {
374 return Err(VolumeWeightedRsiError::AllValuesNaN);
375 }
376 if max_run < period {
377 return Err(VolumeWeightedRsiError::NotEnoughValidData {
378 needed: period,
379 valid: max_run,
380 });
381 }
382 Ok(())
383}
384
/// Writes the volume-weighted RSI of `close`/`volume` for one `period` into `out`.
///
/// The series is split into maximal runs of bars where both close and volume
/// are finite (`is_valid_pair`), and each run is processed independently:
/// * invalid bars are written as `NaN`;
/// * the first `period - 1` bars of a run are written as `NaN` (warm-up);
/// * a bar credits its volume to "up" if its close rose versus the previous
///   bar, to "down" if it fell, and to neither otherwise. The first bar of a
///   run has no predecessor and credits nothing;
/// * the first value (bar `period - 1` of the run) uses the mean of the first
///   `period` credits. Later values use the recursive update
///   `avg = avg * (1 - 1/period) + credit / period`, evaluated with `mul_add`;
/// * RSI = `100 * avg_up / (avg_up + avg_down)`, or 50 when both are zero.
///
/// Runs shorter than `period` produce only `NaN`. Every element of
/// `out[..close.len()]` is written.
///
/// Preconditions, not checked here (callers validate first): `period > 0`,
/// `volume.len() >= close.len()` and `out.len() >= close.len()`.
#[inline(always)]
fn compute_row(close: &[f64], volume: &[f64], period: usize, out: &mut [f64]) {
    let inv_period = 1.0 / period as f64;
    let beta = 1.0 - inv_period;

    let len = close.len();
    let mut i = 0usize;
    while i < len {
        // Skip (and NaN-fill) invalid bars up to the start of the next run.
        while i < len && !is_valid_pair(close[i], volume[i]) {
            out[i] = f64::NAN;
            i += 1;
        }
        if i >= len {
            break;
        }

        // Find the end (exclusive) of the current run of valid bars.
        let seg_start = i;
        i += 1;
        while i < len && is_valid_pair(close[i], volume[i]) {
            i += 1;
        }
        let seg_end = i;
        let seg_len = seg_end - seg_start;

        // Warm-up prefix, clamped so a short run is NaN-filled entirely.
        let warm_end = seg_start + period.saturating_sub(1);
        let prefix_end = warm_end.min(seg_end);
        for v in &mut out[seg_start..prefix_end] {
            *v = f64::NAN;
        }
        if seg_len < period {
            continue;
        }

        // Seed: plain sums over the first `period` bars of the run.
        let mut sum_up = 0.0f64;
        let mut sum_down = 0.0f64;
        let mut prev_close = close[seg_start];
        let seed_end = seg_start + period;
        let mut j = seg_start;
        while j < seed_end {
            let c = close[j];
            let vol = volume[j];
            let (up, down) = if j == seg_start {
                // First bar of the run: no previous close to compare against.
                (0.0, 0.0)
            } else if c > prev_close {
                (vol, 0.0)
            } else if c < prev_close {
                (0.0, vol)
            } else {
                (0.0, 0.0)
            };
            sum_up += up;
            sum_down += down;
            prev_close = c;
            j += 1;
        }

        let mut avg_up = sum_up * inv_period;
        let mut avg_down = sum_down * inv_period;
        out[seed_end - 1] = rsi_from_components(avg_up, avg_down);

        // Recursive smoothing for the rest of the run.
        let mut k = seed_end;
        while k < seg_end {
            let c = close[k];
            let vol = volume[k];
            let (up, down) = if c > prev_close {
                (vol, 0.0)
            } else if c < prev_close {
                (0.0, vol)
            } else {
                (0.0, 0.0)
            };
            avg_up = avg_up.mul_add(beta, inv_period * up);
            avg_down = avg_down.mul_add(beta, inv_period * down);
            out[k] = rsi_from_components(avg_up, avg_down);
            prev_close = c;
            k += 1;
        }
    }
}
464
465#[inline]
466pub fn volume_weighted_rsi(
467 input: &VolumeWeightedRsiInput,
468) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
469 volume_weighted_rsi_with_kernel(input, Kernel::Auto)
470}
471
472pub fn volume_weighted_rsi_with_kernel(
473 input: &VolumeWeightedRsiInput,
474 kernel: Kernel,
475) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
476 let (close, volume) = input_slices(input)?;
477 let period = input.get_period();
478 validate_common(close, volume, period)?;
479
480 let _chosen = match kernel {
481 Kernel::Auto => detect_best_kernel(),
482 other => other,
483 };
484
485 let mut out = alloc_with_nan_prefix(close.len(), 0);
486 out.fill(f64::NAN);
487 compute_row(close, volume, period, &mut out);
488 Ok(VolumeWeightedRsiOutput { values: out })
489}
490
491pub fn volume_weighted_rsi_into_slice(
492 dst: &mut [f64],
493 input: &VolumeWeightedRsiInput,
494 kernel: Kernel,
495) -> Result<(), VolumeWeightedRsiError> {
496 let (close, volume) = input_slices(input)?;
497 let period = input.get_period();
498 validate_common(close, volume, period)?;
499
500 if dst.len() != close.len() {
501 return Err(VolumeWeightedRsiError::OutputLengthMismatch {
502 expected: close.len(),
503 got: dst.len(),
504 });
505 }
506
507 let _chosen = match kernel {
508 Kernel::Auto => detect_best_kernel(),
509 other => other,
510 };
511
512 dst.fill(f64::NAN);
513 compute_row(close, volume, period, dst);
514 Ok(())
515}
516
517#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
518pub fn volume_weighted_rsi_into(
519 input: &VolumeWeightedRsiInput,
520 out: &mut [f64],
521) -> Result<(), VolumeWeightedRsiError> {
522 volume_weighted_rsi_into_slice(out, input, Kernel::Auto)
523}
524
/// Inclusive `(start, end, step)` sweep of periods for batch evaluation.
///
/// A `step` of 0 evaluates `start` only.
#[derive(Clone, Copy, Debug)]
pub struct VolumeWeightedRsiBatchRange {
    pub period: (usize, usize, usize),
}

impl Default for VolumeWeightedRsiBatchRange {
    /// A single period of 14.
    fn default() -> Self {
        let single = 14;
        Self {
            period: (single, single, 0),
        }
    }
}
537
538#[derive(Debug, Clone)]
539pub struct VolumeWeightedRsiBatchOutput {
540 pub values: Vec<f64>,
541 pub combos: Vec<VolumeWeightedRsiParams>,
542 pub rows: usize,
543 pub cols: usize,
544}
545
546#[derive(Debug, Clone, Copy)]
547pub struct VolumeWeightedRsiBatchBuilder {
548 range: VolumeWeightedRsiBatchRange,
549 kernel: Kernel,
550}
551
552impl Default for VolumeWeightedRsiBatchBuilder {
553 fn default() -> Self {
554 Self {
555 range: VolumeWeightedRsiBatchRange::default(),
556 kernel: Kernel::Auto,
557 }
558 }
559}
560
561impl VolumeWeightedRsiBatchBuilder {
562 #[inline(always)]
563 pub fn new() -> Self {
564 Self::default()
565 }
566
567 #[inline(always)]
568 pub fn kernel(mut self, value: Kernel) -> Self {
569 self.kernel = value;
570 self
571 }
572
573 #[inline(always)]
574 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
575 self.range.period = (start, end, step);
576 self
577 }
578
579 #[inline(always)]
580 pub fn period_static(mut self, value: usize) -> Self {
581 self.range.period = (value, value, 0);
582 self
583 }
584
585 #[inline(always)]
586 pub fn apply_slices(
587 self,
588 close: &[f64],
589 volume: &[f64],
590 ) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
591 volume_weighted_rsi_batch_with_kernel(close, volume, &self.range, self.kernel)
592 }
593
594 #[inline(always)]
595 pub fn apply_candles(
596 self,
597 candles: &Candles,
598 ) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
599 volume_weighted_rsi_batch_with_kernel(
600 candles.close.as_slice(),
601 candles.volume.as_slice(),
602 &self.range,
603 self.kernel,
604 )
605 }
606}
607
608#[inline(always)]
609fn expand_grid_checked(
610 range: &VolumeWeightedRsiBatchRange,
611) -> Result<Vec<VolumeWeightedRsiParams>, VolumeWeightedRsiError> {
612 let (start, end, step) = range.period;
613 if start == 0 || end == 0 {
614 return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
615 }
616 if step == 0 {
617 return Ok(vec![VolumeWeightedRsiParams {
618 period: Some(start),
619 }]);
620 }
621 if start > end {
622 return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
623 }
624
625 let mut out = Vec::new();
626 let mut cur = start;
627 loop {
628 out.push(VolumeWeightedRsiParams { period: Some(cur) });
629 if cur >= end {
630 break;
631 }
632 let next = cur.saturating_add(step);
633 if next <= cur {
634 return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
635 }
636 cur = next.min(end);
637 if cur == *out.last().and_then(|p| p.period.as_ref()).unwrap() {
638 break;
639 }
640 }
641 Ok(out)
642}
643
644#[inline(always)]
645pub fn expand_grid_volume_weighted_rsi(
646 range: &VolumeWeightedRsiBatchRange,
647) -> Vec<VolumeWeightedRsiParams> {
648 expand_grid_checked(range).unwrap_or_default()
649}
650
651pub fn volume_weighted_rsi_batch_with_kernel(
652 close: &[f64],
653 volume: &[f64],
654 sweep: &VolumeWeightedRsiBatchRange,
655 kernel: Kernel,
656) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
657 match kernel {
658 Kernel::Auto
659 | Kernel::Scalar
660 | Kernel::ScalarBatch
661 | Kernel::Avx2
662 | Kernel::Avx2Batch
663 | Kernel::Avx512
664 | Kernel::Avx512Batch => {}
665 other => return Err(VolumeWeightedRsiError::InvalidKernelForBatch(other)),
666 }
667
668 validate_common(close, volume, 1)?;
669 let combos = expand_grid_checked(sweep)?;
670 let max_period = combos
671 .iter()
672 .map(|params| params.period.unwrap_or(14))
673 .max()
674 .unwrap_or(0);
675 validate_common(close, volume, max_period)?;
676
677 let rows = combos.len();
678 let cols = close.len();
679 let mut values_mu = make_uninit_matrix(rows, cols);
680 let warmups: Vec<usize> = combos
681 .iter()
682 .map(|params| params.period.unwrap_or(14).saturating_sub(1))
683 .collect();
684 init_matrix_prefixes(&mut values_mu, cols, &warmups);
685 let mut values = unsafe {
686 Vec::from_raw_parts(
687 values_mu.as_mut_ptr() as *mut f64,
688 values_mu.len(),
689 values_mu.capacity(),
690 )
691 };
692 std::mem::forget(values_mu);
693
694 volume_weighted_rsi_batch_inner_into(close, volume, sweep, kernel, true, &mut values)?;
695
696 Ok(VolumeWeightedRsiBatchOutput {
697 values,
698 combos,
699 rows,
700 cols,
701 })
702}
703
704pub fn volume_weighted_rsi_batch_slice(
705 close: &[f64],
706 volume: &[f64],
707 sweep: &VolumeWeightedRsiBatchRange,
708 kernel: Kernel,
709) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
710 volume_weighted_rsi_batch_inner(close, volume, sweep, kernel, false)
711}
712
713pub fn volume_weighted_rsi_batch_par_slice(
714 close: &[f64],
715 volume: &[f64],
716 sweep: &VolumeWeightedRsiBatchRange,
717 kernel: Kernel,
718) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
719 volume_weighted_rsi_batch_inner(close, volume, sweep, kernel, true)
720}
721
722fn volume_weighted_rsi_batch_inner(
723 close: &[f64],
724 volume: &[f64],
725 sweep: &VolumeWeightedRsiBatchRange,
726 kernel: Kernel,
727 parallel: bool,
728) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
729 let combos = expand_grid_checked(sweep)?;
730 let rows = combos.len();
731 let cols = close.len();
732 let total = rows
733 .checked_mul(cols)
734 .ok_or_else(|| VolumeWeightedRsiError::InvalidInput {
735 msg: "volume_weighted_rsi: rows*cols overflow in batch".to_string(),
736 })?;
737
738 let mut values_mu = make_uninit_matrix(rows, cols);
739 let warmups: Vec<usize> = combos
740 .iter()
741 .map(|params| params.period.unwrap_or(14).saturating_sub(1))
742 .collect();
743 init_matrix_prefixes(&mut values_mu, cols, &warmups);
744 let mut values = unsafe {
745 Vec::from_raw_parts(
746 values_mu.as_mut_ptr() as *mut f64,
747 values_mu.len(),
748 values_mu.capacity(),
749 )
750 };
751 std::mem::forget(values_mu);
752
753 debug_assert_eq!(values.len(), total);
754
755 volume_weighted_rsi_batch_inner_into(close, volume, sweep, kernel, parallel, &mut values)?;
756
757 Ok(VolumeWeightedRsiBatchOutput {
758 values,
759 combos,
760 rows,
761 cols,
762 })
763}
764
/// Fills the caller-provided row-major matrix `out` with one RSI row per
/// period in `sweep`, and returns the expanded parameter sets.
///
/// `out` must be exactly `combos.len() * close.len()` long. Each row is
/// NaN-filled and then written by `compute_row`. When `parallel` is set, rows
/// are distributed over rayon on non-wasm targets; on wasm32 the flag is
/// ignored and rows are computed sequentially.
///
/// # Errors
/// In order: `InvalidKernelForBatch`, `InvalidRange`, `EmptyInputData`,
/// `InputLengthMismatch`, `InvalidInput` (size overflow),
/// `MismatchedOutputLen`, then validation errors against the largest period.
fn volume_weighted_rsi_batch_inner_into(
    close: &[f64],
    volume: &[f64],
    sweep: &VolumeWeightedRsiBatchRange,
    kernel: Kernel,
    parallel: bool,
    out: &mut [f64],
) -> Result<Vec<VolumeWeightedRsiParams>, VolumeWeightedRsiError> {
    match kernel {
        Kernel::Auto
        | Kernel::Scalar
        | Kernel::ScalarBatch
        | Kernel::Avx2
        | Kernel::Avx2Batch
        | Kernel::Avx512
        | Kernel::Avx512Batch => {}
        other => return Err(VolumeWeightedRsiError::InvalidKernelForBatch(other)),
    }

    let combos = expand_grid_checked(sweep)?;
    let len = close.len();
    // Checked before sizing: `chunks_mut(0)` below would panic on empty input.
    if len == 0 || volume.is_empty() {
        return Err(VolumeWeightedRsiError::EmptyInputData);
    }
    if len != volume.len() {
        return Err(VolumeWeightedRsiError::InputLengthMismatch {
            close_len: len,
            volume_len: volume.len(),
        });
    }

    let total =
        combos
            .len()
            .checked_mul(len)
            .ok_or_else(|| VolumeWeightedRsiError::InvalidInput {
                msg: "volume_weighted_rsi: rows*cols overflow in batch_into".to_string(),
            })?;
    if out.len() != total {
        return Err(VolumeWeightedRsiError::MismatchedOutputLen {
            dst_len: out.len(),
            expected_len: total,
        });
    }

    // Validating against the largest period covers every row.
    let max_period = combos
        .iter()
        .map(|params| params.period.unwrap_or(14))
        .max()
        .unwrap_or(0);
    validate_common(close, volume, max_period)?;

    // Resolved for API parity; only the scalar row kernel exists.
    let _chosen = match kernel {
        Kernel::Auto => detect_best_batch_kernel(),
        other => other,
    };

    // Computes row `row` into its `len`-element slice of `out`.
    let worker = |row: usize, dst: &mut [f64]| {
        dst.fill(f64::NAN);
        let period = combos[row].period.unwrap_or(14);
        compute_row(close, volume, period, dst);
    };

    #[cfg(not(target_arch = "wasm32"))]
    if parallel {
        out.par_chunks_mut(len)
            .enumerate()
            .for_each(|(row, dst)| worker(row, dst));
    } else {
        for (row, dst) in out.chunks_mut(len).enumerate() {
            worker(row, dst);
        }
    }

    // wasm32: no rayon, so always sequential.
    #[cfg(target_arch = "wasm32")]
    {
        let _ = parallel;
        for (row, dst) in out.chunks_mut(len).enumerate() {
            worker(row, dst);
        }
    }

    Ok(combos)
}
849
// Python binding: volume_weighted_rsi(close, volume, period=14, kernel=None).
// Computes with the GIL released and maps any indicator error to ValueError.
#[cfg(feature = "python")]
#[pyfunction(name = "volume_weighted_rsi")]
#[pyo3(signature = (close, volume, period=14, kernel=None))]
pub fn volume_weighted_rsi_py<'py>(
    py: Python<'py>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let (close, volume) = (close.as_slice()?, volume.as_slice()?);
    let kern = validate_kernel(kernel, false)?;

    let params = VolumeWeightedRsiParams {
        period: Some(period),
    };
    let input = VolumeWeightedRsiInput::from_slices(close, volume, params);

    let result = py.allow_threads(|| volume_weighted_rsi_with_kernel(&input, kern));
    let out = result.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(out.values.into_pyarray(py))
}
875
// Python-facing wrapper around the incremental VolumeWeightedRsiStream.
#[cfg(feature = "python")]
#[pyclass(name = "VolumeWeightedRsiStream")]
pub struct VolumeWeightedRsiStreamPy {
    stream: VolumeWeightedRsiStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl VolumeWeightedRsiStreamPy {
    // Raises ValueError when `period` is 0.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = VolumeWeightedRsiParams {
            period: Some(period),
        };
        VolumeWeightedRsiStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Returns None during warm-up and for bars with a non-finite close or volume.
    fn update(&mut self, close: f64, volume: f64) -> Option<f64> {
        self.stream.update(close, volume)
    }
}
898
// Python binding: volume_weighted_rsi_batch(close, volume, period_range=(14, 14, 0), kernel=None).
// Returns a dict with a 2-D "values" array of shape (rows, cols) plus
// "periods", "rows" and "cols".
#[cfg(feature = "python")]
#[pyfunction(name = "volume_weighted_rsi_batch")]
#[pyo3(signature = (close, volume, period_range=(14,14,0), kernel=None))]
pub fn volume_weighted_rsi_batch_py<'py>(
    py: Python<'py>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let close = close.as_slice()?;
    let volume = volume.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = VolumeWeightedRsiBatchRange {
        period: period_range,
    };

    let output = py
        .allow_threads(|| volume_weighted_rsi_batch_with_kernel(close, volume, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;
    let VolumeWeightedRsiBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = output;

    let dict = PyDict::new(py);
    dict.set_item("values", values.into_pyarray(py).reshape((rows, cols))?)?;
    let periods: Vec<u64> = combos
        .iter()
        .map(|params| params.period.unwrap_or(14) as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
946
/// Adds this indicator's Python functions and stream class to module `m`.
#[cfg(feature = "python")]
pub fn register_volume_weighted_rsi_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    // Functions first, in the same order as they are defined above.
    m.add_function(wrap_pyfunction!(volume_weighted_rsi_py, m)?)?;
    m.add_function(wrap_pyfunction!(volume_weighted_rsi_batch_py, m)?)?;
    m.add_class::<VolumeWeightedRsiStreamPy>()
}
954
/// Configuration object accepted by `volume_weighted_rsi_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VolumeWeightedRsiBatchConfig {
    /// `[start, end, step]`; any other length is rejected by the binding.
    pub period_range: Vec<usize>,
}
960
/// JS binding: volume-weighted RSI of `close`/`volume`, serialized as an
/// array of numbers. Indicator and serialization errors come back as string
/// `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = volume_weighted_rsi_js)]
pub fn volume_weighted_rsi_js(
    close: &[f64],
    volume: &[f64],
    period: usize,
) -> Result<JsValue, JsValue> {
    let to_js = |msg: String| JsValue::from_str(&msg);
    let params = VolumeWeightedRsiParams {
        period: Some(period),
    };
    let input = VolumeWeightedRsiInput::from_slices(close, volume, params);
    let out = volume_weighted_rsi_with_kernel(&input, Kernel::Auto)
        .map_err(|e| to_js(e.to_string()))?;
    serde_wasm_bindgen::to_value(&out.values).map_err(|e| to_js(e.to_string()))
}
979
/// JS binding for the batch sweep.
///
/// `config` must deserialize into `VolumeWeightedRsiBatchConfig` with a
/// three-element `period_range` `[start, end, step]`. The result is an object
/// with `values` (row-major), `rows`, `cols` and `combos`. Every failure,
/// including serialization of the result, is returned as a string `JsValue`
/// rather than panicking.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = volume_weighted_rsi_batch_js)]
pub fn volume_weighted_rsi_batch_js(
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: VolumeWeightedRsiBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let (start, end, step) = match config.period_range.as_slice() {
        &[start, end, step] => (start, end, step),
        _ => {
            return Err(JsValue::from_str(
                "Invalid config: period_range must have exactly 3 elements [start, end, step]",
            ))
        }
    };

    let out = volume_weighted_rsi_batch_with_kernel(
        close,
        volume,
        &VolumeWeightedRsiBatchRange {
            period: (start, end, step),
        },
        Kernel::Auto,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let obj = js_sys::Object::new();
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("values"),
        &serde_wasm_bindgen::to_value(&out.values)
            .map_err(|e| JsValue::from_str(&e.to_string()))?,
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("rows"),
        &JsValue::from_f64(out.rows as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("cols"),
        &JsValue::from_f64(out.cols as f64),
    )?;
    js_sys::Reflect::set(
        &obj,
        &JsValue::from_str("combos"),
        &serde_wasm_bindgen::to_value(&out.combos)
            .map_err(|e| JsValue::from_str(&e.to_string()))?,
    )?;
    Ok(obj.into())
}
1032
/// Reserves room for `len` `f64`s in wasm memory and hands the uninitialised
/// buffer to the caller.
///
/// Release it with `volume_weighted_rsi_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn volume_weighted_rsi_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1041
/// Releases a buffer obtained from `volume_weighted_rsi_alloc(len)`.
///
/// A null pointer is ignored. Otherwise `ptr` must come from
/// `volume_weighted_rsi_alloc` called with the same `len`, and must not be
/// freed twice. The buffer is rebuilt with length 0, so only the allocation
/// is released and its contents need not have been initialised (`f64` has no
/// destructor to run).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn volume_weighted_rsi_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` was leaked from
    // `Vec::<f64>::with_capacity(len)`, which allocates exactly `len`
    // elements, so capacity `len` reproduces the original layout. Length 0
    // trivially satisfies "the first `length` elements are initialised".
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1051
/// Raw-pointer entry point for JS callers working directly in wasm memory.
///
/// Reads `len` closes from `close_ptr` and `len` volumes from `volume_ptr`,
/// and writes `len` RSI values to `out_ptr`. Each pointer must be valid for
/// `len` `f64`s. The output may overlap either input (e.g. computing in
/// place). In that case the result goes to a scratch buffer first, so shared
/// and mutable slices over the same memory never coexist.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn volume_weighted_rsi_into(
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if close_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to volume_weighted_rsi_into",
        ));
    }

    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let out_addr = out_ptr as usize;
    let overlaps_out = |input: *const f64| {
        let in_addr = input as usize;
        in_addr < out_addr.saturating_add(bytes) && out_addr < in_addr.saturating_add(bytes)
    };
    let params = VolumeWeightedRsiParams {
        period: Some(period),
    };

    if overlaps_out(close_ptr) || overlaps_out(volume_ptr) {
        // SAFETY: caller guarantees both inputs are valid for `len` reads.
        // These shared slices are not used after this block, before `out`
        // is borrowed mutably.
        let values = unsafe {
            let close = std::slice::from_raw_parts(close_ptr, len);
            let volume = std::slice::from_raw_parts(volume_ptr, len);
            let input = VolumeWeightedRsiInput::from_slices(close, volume, params);
            volume_weighted_rsi_with_kernel(&input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?
                .values
        };
        // SAFETY: caller guarantees `out_ptr` is valid for `len` writes, and
        // no other reference into that memory is live here.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&values);
        return Ok(());
    }

    // SAFETY: caller guarantees each pointer is valid for `len` elements, and
    // the output range was just checked to be disjoint from both inputs.
    unsafe {
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, len);
        let input = VolumeWeightedRsiInput::from_slices(close, volume, params);
        volume_weighted_rsi_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }
}
1082
/// Raw-pointer batch entry point for JS callers.
///
/// Reads `len` closes and volumes, and writes `rows * len` values row-major
/// to `out_ptr`, where `rows` is the number of periods in the sweep (also the
/// return value). The output may overlap either input. In that case the rows
/// are computed into a scratch buffer and then copied, so shared and mutable
/// slices over the same memory never coexist.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn volume_weighted_rsi_batch_into(
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if close_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to volume_weighted_rsi_batch_into",
        ));
    }

    let sweep = VolumeWeightedRsiBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow in volume_weighted_rsi_batch_into"))?;

    let elem = std::mem::size_of::<f64>();
    let in_bytes = len.saturating_mul(elem);
    let out_bytes = total.saturating_mul(elem);
    let out_addr = out_ptr as usize;
    let overlaps_out = |input: *const f64| {
        let in_addr = input as usize;
        in_addr < out_addr.saturating_add(out_bytes) && out_addr < in_addr.saturating_add(in_bytes)
    };

    if overlaps_out(close_ptr) || overlaps_out(volume_ptr) {
        let mut scratch = vec![f64::NAN; total];
        // SAFETY: caller guarantees both inputs are valid for `len` reads.
        // The shared slices end with this block, before `out` is created.
        unsafe {
            let close = std::slice::from_raw_parts(close_ptr, len);
            let volume = std::slice::from_raw_parts(volume_ptr, len);
            volume_weighted_rsi_batch_inner_into(
                close,
                volume,
                &sweep,
                Kernel::Auto,
                false,
                &mut scratch,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
        // SAFETY: caller guarantees `out_ptr` is valid for `rows * len`
        // writes, and no other reference into that memory is live here.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: caller guarantees the pointer extents, and the output range
        // was just checked to be disjoint from both inputs.
        unsafe {
            let close = std::slice::from_raw_parts(close_ptr, len);
            let volume = std::slice::from_raw_parts(volume_ptr, len);
            let out = std::slice::from_raw_parts_mut(out_ptr, total);
            volume_weighted_rsi_batch_inner_into(close, volume, &sweep, Kernel::Auto, false, out)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(rows)
}
1119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::dispatch::{
        compute_cpu, IndicatorComputeRequest, IndicatorDataRef, ParamKV, ParamValue,
    };

    /// Deterministic, fully finite close/volume series of length `len`.
    fn sample_close_volume(len: usize) -> (Vec<f64>, Vec<f64>) {
        let close: Vec<f64> = (0..len)
            .map(|i| 100.0 + ((i as f64) * 0.13).sin() * 4.0 + (i as f64) * 0.02)
            .collect();
        let volume: Vec<f64> = (0..len)
            .map(|i| 1000.0 + ((i as f64) * 0.17).cos().abs() * 250.0 + (i % 11) as f64 * 7.0)
            .collect();
        (close, volume)
    }

    /// Element-wise comparison: lengths must match, NaN only matches NaN,
    /// finite values must agree within `tol`.
    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, b)) in actual.iter().zip(expected.iter()).enumerate() {
            if a.is_nan() || b.is_nan() {
                assert!(a.is_nan() && b.is_nan(), "NaN mismatch at {i}: {a} vs {b}");
            } else {
                assert!((a - b).abs() <= tol, "value mismatch at {i}: {a} vs {b}");
            }
        }
    }

    /// Independent O(n^2) reference that shares no code with `compute_row`.
    ///
    /// Every output is recomputed from the start of its run of valid bars
    /// with plain arithmetic (no `mul_add`), so results agree with the
    /// kernel only up to rounding.
    fn naive_volume_weighted_rsi(close: &[f64], volume: &[f64], period: usize) -> Vec<f64> {
        let valid = |i: usize| close[i].is_finite() && volume[i].is_finite();
        let p = period as f64;
        (0..close.len())
            .map(|t| {
                if !valid(t) {
                    return f64::NAN;
                }
                let mut start = t;
                while start > 0 && valid(start - 1) {
                    start -= 1;
                }
                if t + 1 - start < period {
                    return f64::NAN;
                }
                // Volume credited to the up/down side at bar `i` of this run.
                let flow = |i: usize| -> (f64, f64) {
                    if i == start || close[i] == close[i - 1] {
                        (0.0, 0.0)
                    } else if close[i] > close[i - 1] {
                        (volume[i], 0.0)
                    } else {
                        (0.0, volume[i])
                    }
                };
                let mut up = 0.0;
                let mut down = 0.0;
                for i in start..start + period {
                    let (u, d) = flow(i);
                    up += u;
                    down += d;
                }
                up /= p;
                down /= p;
                for i in start + period..=t {
                    let (u, d) = flow(i);
                    up = up * (1.0 - 1.0 / p) + u / p;
                    down = down * (1.0 - 1.0 / p) + d / p;
                }
                if up + down == 0.0 {
                    50.0
                } else {
                    100.0 * up / (up + down)
                }
            })
            .collect()
    }

    #[test]
    fn volume_weighted_rsi_matches_naive() -> Result<(), Box<dyn Error>> {
        let (close, volume) = sample_close_volume(256);
        let input = VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams { period: Some(14) },
        );
        let out = volume_weighted_rsi(&input)?;
        let expected = naive_volume_weighted_rsi(&close, &volume, 14);
        assert!(out.values[..13].iter().all(|v| v.is_nan()));
        assert!(out.values[13..].iter().all(|v| v.is_finite()));
        assert_close(&out.values, &expected, 1e-9);
        Ok(())
    }

    #[test]
    fn volume_weighted_rsi_restarts_after_invalid_bar() -> Result<(), Box<dyn Error>> {
        let (mut close, volume) = sample_close_volume(80);
        close[30] = f64::NAN;
        let period = 8;
        let input = VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams {
                period: Some(period),
            },
        );
        let out = volume_weighted_rsi(&input)?;

        // The invalid bar and the warm-up of the run after it are NaN.
        assert!(out.values[30..31 + period - 1].iter().all(|v| v.is_nan()));
        assert!(out.values[31 + period - 1].is_finite());
        assert_close(
            &out.values,
            &naive_volume_weighted_rsi(&close, &volume, period),
            1e-9,
        );

        let mut stream = VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
            period: Some(period),
        })?;
        let streamed: Vec<f64> = close
            .iter()
            .zip(volume.iter())
            .map(|(&c, &v)| stream.update(c, v).unwrap_or(f64::NAN))
            .collect();
        assert_close(&streamed, &out.values, 1e-12);
        Ok(())
    }

    #[test]
    fn volume_weighted_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (close, volume) = sample_close_volume(200);
        let input = VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams { period: Some(10) },
        );
        let base = volume_weighted_rsi(&input)?;
        let mut out = vec![0.0; close.len()];
        volume_weighted_rsi_into_slice(&mut out, &input, Kernel::Auto)?;
        assert_close(&out, &base.values, 1e-12);
        Ok(())
    }

    #[test]
    fn volume_weighted_rsi_stream_matches_batch() -> Result<(), Box<dyn Error>> {
        let (close, volume) = sample_close_volume(220);
        let period = 12;
        let input = VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams {
                period: Some(period),
            },
        );
        let batch = volume_weighted_rsi(&input)?;
        let mut stream = VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
            period: Some(period),
        })?;
        let mut values = Vec::with_capacity(close.len());
        for (&c, &v) in close.iter().zip(volume.iter()) {
            values.push(stream.update(c, v).unwrap_or(f64::NAN));
        }
        assert_close(&values, &batch.values, 1e-12);
        Ok(())
    }

    #[test]
    fn volume_weighted_rsi_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (close, volume) = sample_close_volume(128);
        let single = volume_weighted_rsi(&VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams { period: Some(14) },
        ))?;
        let batch = volume_weighted_rsi_batch_with_kernel(
            &close,
            &volume,
            &VolumeWeightedRsiBatchRange {
                period: (14, 14, 0),
            },
            Kernel::Auto,
        )?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        assert_close(&batch.values, &single.values, 1e-12);
        Ok(())
    }

    #[test]
    fn volume_weighted_rsi_rejects_invalid_params() {
        let (close, volume) = sample_close_volume(16);
        let err = volume_weighted_rsi(&VolumeWeightedRsiInput::from_slices(
            &close,
            &volume,
            VolumeWeightedRsiParams { period: Some(0) },
        ))
        .unwrap_err();
        assert!(matches!(err, VolumeWeightedRsiError::InvalidPeriod { .. }));
    }

    #[test]
    fn volume_weighted_rsi_dispatch_compute_returns_value() {
        let (close, volume) = sample_close_volume(128);
        let params = [ParamKV {
            key: "period",
            value: ParamValue::Int(14),
        }];
        let out = compute_cpu(IndicatorComputeRequest {
            indicator_id: "volume_weighted_rsi",
            output_id: Some("value"),
            data: IndicatorDataRef::CloseVolume {
                close: &close,
                volume: &volume,
            },
            params: &params,
            kernel: Kernel::Auto,
        })
        .unwrap();
        assert_eq!(out.output_id, "value");
        assert_eq!(out.cols, close.len());
    }
}