Skip to main content

vector_ta/indicators/
volume_weighted_rsi.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19    make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23#[cfg(not(target_arch = "wasm32"))]
24use rayon::prelude::*;
25use std::error::Error;
26use thiserror::Error;
27
28#[derive(Debug, Clone)]
29pub enum VolumeWeightedRsiData<'a> {
30    Candles {
31        candles: &'a Candles,
32        close_source: &'a str,
33    },
34    Slices {
35        close: &'a [f64],
36        volume: &'a [f64],
37    },
38}
39
40#[derive(Debug, Clone)]
41pub struct VolumeWeightedRsiOutput {
42    pub values: Vec<f64>,
43}
44
45#[derive(Debug, Clone)]
46#[cfg_attr(
47    all(target_arch = "wasm32", feature = "wasm"),
48    derive(Serialize, Deserialize)
49)]
50pub struct VolumeWeightedRsiParams {
51    pub period: Option<usize>,
52}
53
54impl Default for VolumeWeightedRsiParams {
55    fn default() -> Self {
56        Self { period: Some(14) }
57    }
58}
59
60#[derive(Debug, Clone)]
61pub struct VolumeWeightedRsiInput<'a> {
62    pub data: VolumeWeightedRsiData<'a>,
63    pub params: VolumeWeightedRsiParams,
64}
65
66impl<'a> VolumeWeightedRsiInput<'a> {
67    #[inline]
68    pub fn from_candles(
69        candles: &'a Candles,
70        close_source: &'a str,
71        params: VolumeWeightedRsiParams,
72    ) -> Self {
73        Self {
74            data: VolumeWeightedRsiData::Candles {
75                candles,
76                close_source,
77            },
78            params,
79        }
80    }
81
82    #[inline]
83    pub fn from_slices(
84        close: &'a [f64],
85        volume: &'a [f64],
86        params: VolumeWeightedRsiParams,
87    ) -> Self {
88        Self {
89            data: VolumeWeightedRsiData::Slices { close, volume },
90            params,
91        }
92    }
93
94    #[inline]
95    pub fn with_default_candles(candles: &'a Candles) -> Self {
96        Self::from_candles(candles, "close", VolumeWeightedRsiParams::default())
97    }
98
99    #[inline]
100    pub fn get_period(&self) -> usize {
101        self.params.period.unwrap_or(14)
102    }
103}
104
105#[derive(Copy, Clone, Debug)]
106pub struct VolumeWeightedRsiBuilder {
107    period: Option<usize>,
108    kernel: Kernel,
109}
110
111impl Default for VolumeWeightedRsiBuilder {
112    fn default() -> Self {
113        Self {
114            period: None,
115            kernel: Kernel::Auto,
116        }
117    }
118}
119
120impl VolumeWeightedRsiBuilder {
121    #[inline(always)]
122    pub fn new() -> Self {
123        Self::default()
124    }
125
126    #[inline(always)]
127    pub fn period(mut self, value: usize) -> Self {
128        self.period = Some(value);
129        self
130    }
131
132    #[inline(always)]
133    pub fn kernel(mut self, value: Kernel) -> Self {
134        self.kernel = value;
135        self
136    }
137
138    #[inline(always)]
139    pub fn apply(
140        self,
141        candles: &Candles,
142    ) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
143        let params = VolumeWeightedRsiParams {
144            period: self.period,
145        };
146        volume_weighted_rsi_with_kernel(
147            &VolumeWeightedRsiInput::from_candles(candles, "close", params),
148            self.kernel,
149        )
150    }
151
152    #[inline(always)]
153    pub fn apply_slices(
154        self,
155        close: &[f64],
156        volume: &[f64],
157    ) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
158        let params = VolumeWeightedRsiParams {
159            period: self.period,
160        };
161        volume_weighted_rsi_with_kernel(
162            &VolumeWeightedRsiInput::from_slices(close, volume, params),
163            self.kernel,
164        )
165    }
166
167    #[inline(always)]
168    pub fn into_stream(self) -> Result<VolumeWeightedRsiStream, VolumeWeightedRsiError> {
169        VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
170            period: self.period,
171        })
172    }
173}
174
175#[derive(Debug, Error)]
176pub enum VolumeWeightedRsiError {
177    #[error("volume_weighted_rsi: Input data slice is empty.")]
178    EmptyInputData,
179    #[error(
180        "volume_weighted_rsi: Input length mismatch: close = {close_len}, volume = {volume_len}"
181    )]
182    InputLengthMismatch { close_len: usize, volume_len: usize },
183    #[error("volume_weighted_rsi: All values are NaN.")]
184    AllValuesNaN,
185    #[error("volume_weighted_rsi: Invalid period: period = {period}, data length = {data_len}")]
186    InvalidPeriod { period: usize, data_len: usize },
187    #[error("volume_weighted_rsi: Not enough valid data: needed = {needed}, valid = {valid}")]
188    NotEnoughValidData { needed: usize, valid: usize },
189    #[error("volume_weighted_rsi: Output length mismatch: expected = {expected}, got = {got}")]
190    OutputLengthMismatch { expected: usize, got: usize },
191    #[error("volume_weighted_rsi: Invalid range: start={start}, end={end}, step={step}")]
192    InvalidRange {
193        start: usize,
194        end: usize,
195        step: usize,
196    },
197    #[error("volume_weighted_rsi: Invalid kernel for batch: {0:?}")]
198    InvalidKernelForBatch(Kernel),
199    #[error(
200        "volume_weighted_rsi: Output length mismatch: dst = {dst_len}, expected = {expected_len}"
201    )]
202    MismatchedOutputLen { dst_len: usize, expected_len: usize },
203    #[error("volume_weighted_rsi: Invalid input: {msg}")]
204    InvalidInput { msg: String },
205}
206
207#[derive(Debug, Clone)]
208pub struct VolumeWeightedRsiStream {
209    period: usize,
210    inv_period: f64,
211    beta: f64,
212    prev_close: f64,
213    has_prev: bool,
214    seeded: usize,
215    sum_up: f64,
216    sum_down: f64,
217    avg_up: f64,
218    avg_down: f64,
219}
220
221impl VolumeWeightedRsiStream {
222    #[inline(always)]
223    pub fn try_new(params: VolumeWeightedRsiParams) -> Result<Self, VolumeWeightedRsiError> {
224        let period = params.period.unwrap_or(14);
225        if period == 0 {
226            return Err(VolumeWeightedRsiError::InvalidPeriod {
227                period,
228                data_len: 0,
229            });
230        }
231        let inv_period = 1.0 / period as f64;
232        Ok(Self {
233            period,
234            inv_period,
235            beta: 1.0 - inv_period,
236            prev_close: f64::NAN,
237            has_prev: false,
238            seeded: 0,
239            sum_up: 0.0,
240            sum_down: 0.0,
241            avg_up: 0.0,
242            avg_down: 0.0,
243        })
244    }
245
246    #[inline(always)]
247    pub fn reset(&mut self) {
248        self.prev_close = f64::NAN;
249        self.has_prev = false;
250        self.seeded = 0;
251        self.sum_up = 0.0;
252        self.sum_down = 0.0;
253        self.avg_up = 0.0;
254        self.avg_down = 0.0;
255    }
256
257    #[inline(always)]
258    pub fn update(&mut self, close: f64, volume: f64) -> Option<f64> {
259        if !is_valid_pair(close, volume) {
260            self.reset();
261            return None;
262        }
263
264        let (up, down) = if self.has_prev {
265            if close > self.prev_close {
266                (volume, 0.0)
267            } else if close < self.prev_close {
268                (0.0, volume)
269            } else {
270                (0.0, 0.0)
271            }
272        } else {
273            (0.0, 0.0)
274        };
275
276        self.prev_close = close;
277        self.has_prev = true;
278
279        if self.seeded < self.period {
280            self.sum_up += up;
281            self.sum_down += down;
282            self.seeded += 1;
283            if self.seeded < self.period {
284                return None;
285            }
286            self.avg_up = self.sum_up * self.inv_period;
287            self.avg_down = self.sum_down * self.inv_period;
288            return Some(rsi_from_components(self.avg_up, self.avg_down));
289        }
290
291        self.avg_up = self.avg_up.mul_add(self.beta, self.inv_period * up);
292        self.avg_down = self.avg_down.mul_add(self.beta, self.inv_period * down);
293        Some(rsi_from_components(self.avg_up, self.avg_down))
294    }
295
296    #[inline(always)]
297    pub fn get_warmup_period(&self) -> usize {
298        self.period.saturating_sub(1)
299    }
300}
301
302#[inline(always)]
303fn is_valid_pair(close: f64, volume: f64) -> bool {
304    close.is_finite() && volume.is_finite()
305}
306
307#[inline(always)]
308fn rsi_from_components(avg_up: f64, avg_down: f64) -> f64 {
309    let denom = avg_up + avg_down;
310    if denom == 0.0 {
311        50.0
312    } else {
313        100.0 * avg_up / denom
314    }
315}
316
317#[inline(always)]
318fn longest_valid_pair_run(close: &[f64], volume: &[f64]) -> usize {
319    let mut best = 0usize;
320    let mut cur = 0usize;
321    for (&c, &v) in close.iter().zip(volume.iter()) {
322        if is_valid_pair(c, v) {
323            cur += 1;
324            if cur > best {
325                best = cur;
326            }
327        } else {
328            cur = 0;
329        }
330    }
331    best
332}
333
334#[inline(always)]
335fn input_slices<'a>(
336    input: &'a VolumeWeightedRsiInput<'a>,
337) -> Result<(&'a [f64], &'a [f64]), VolumeWeightedRsiError> {
338    match &input.data {
339        VolumeWeightedRsiData::Candles {
340            candles,
341            close_source,
342        } => Ok((
343            source_type(candles, close_source),
344            candles.volume.as_slice(),
345        )),
346        VolumeWeightedRsiData::Slices { close, volume } => Ok((*close, *volume)),
347    }
348}
349
350#[inline(always)]
351fn validate_common(
352    close: &[f64],
353    volume: &[f64],
354    period: usize,
355) -> Result<(), VolumeWeightedRsiError> {
356    if close.is_empty() || volume.is_empty() {
357        return Err(VolumeWeightedRsiError::EmptyInputData);
358    }
359    if close.len() != volume.len() {
360        return Err(VolumeWeightedRsiError::InputLengthMismatch {
361            close_len: close.len(),
362            volume_len: volume.len(),
363        });
364    }
365    if period == 0 || period > close.len() {
366        return Err(VolumeWeightedRsiError::InvalidPeriod {
367            period,
368            data_len: close.len(),
369        });
370    }
371
372    let max_run = longest_valid_pair_run(close, volume);
373    if max_run == 0 {
374        return Err(VolumeWeightedRsiError::AllValuesNaN);
375    }
376    if max_run < period {
377        return Err(VolumeWeightedRsiError::NotEnoughValidData {
378            needed: period,
379            valid: max_run,
380        });
381    }
382    Ok(())
383}
384
385#[inline(always)]
386fn compute_row(close: &[f64], volume: &[f64], period: usize, out: &mut [f64]) {
387    let inv_period = 1.0 / period as f64;
388    let beta = 1.0 - inv_period;
389
390    let len = close.len();
391    let mut i = 0usize;
392    while i < len {
393        while i < len && !is_valid_pair(close[i], volume[i]) {
394            out[i] = f64::NAN;
395            i += 1;
396        }
397        if i >= len {
398            break;
399        }
400
401        let seg_start = i;
402        i += 1;
403        while i < len && is_valid_pair(close[i], volume[i]) {
404            i += 1;
405        }
406        let seg_end = i;
407        let seg_len = seg_end - seg_start;
408
409        let warm_end = seg_start + period.saturating_sub(1);
410        let prefix_end = warm_end.min(seg_end);
411        for v in &mut out[seg_start..prefix_end] {
412            *v = f64::NAN;
413        }
414        if seg_len < period {
415            continue;
416        }
417
418        let mut sum_up = 0.0f64;
419        let mut sum_down = 0.0f64;
420        let mut prev_close = close[seg_start];
421        let seed_end = seg_start + period;
422        let mut j = seg_start;
423        while j < seed_end {
424            let c = close[j];
425            let vol = volume[j];
426            let (up, down) = if j == seg_start {
427                (0.0, 0.0)
428            } else if c > prev_close {
429                (vol, 0.0)
430            } else if c < prev_close {
431                (0.0, vol)
432            } else {
433                (0.0, 0.0)
434            };
435            sum_up += up;
436            sum_down += down;
437            prev_close = c;
438            j += 1;
439        }
440
441        let mut avg_up = sum_up * inv_period;
442        let mut avg_down = sum_down * inv_period;
443        out[seed_end - 1] = rsi_from_components(avg_up, avg_down);
444
445        let mut k = seed_end;
446        while k < seg_end {
447            let c = close[k];
448            let vol = volume[k];
449            let (up, down) = if c > prev_close {
450                (vol, 0.0)
451            } else if c < prev_close {
452                (0.0, vol)
453            } else {
454                (0.0, 0.0)
455            };
456            avg_up = avg_up.mul_add(beta, inv_period * up);
457            avg_down = avg_down.mul_add(beta, inv_period * down);
458            out[k] = rsi_from_components(avg_up, avg_down);
459            prev_close = c;
460            k += 1;
461        }
462    }
463}
464
465#[inline]
466pub fn volume_weighted_rsi(
467    input: &VolumeWeightedRsiInput,
468) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
469    volume_weighted_rsi_with_kernel(input, Kernel::Auto)
470}
471
472pub fn volume_weighted_rsi_with_kernel(
473    input: &VolumeWeightedRsiInput,
474    kernel: Kernel,
475) -> Result<VolumeWeightedRsiOutput, VolumeWeightedRsiError> {
476    let (close, volume) = input_slices(input)?;
477    let period = input.get_period();
478    validate_common(close, volume, period)?;
479
480    let _chosen = match kernel {
481        Kernel::Auto => detect_best_kernel(),
482        other => other,
483    };
484
485    let mut out = alloc_with_nan_prefix(close.len(), 0);
486    out.fill(f64::NAN);
487    compute_row(close, volume, period, &mut out);
488    Ok(VolumeWeightedRsiOutput { values: out })
489}
490
491pub fn volume_weighted_rsi_into_slice(
492    dst: &mut [f64],
493    input: &VolumeWeightedRsiInput,
494    kernel: Kernel,
495) -> Result<(), VolumeWeightedRsiError> {
496    let (close, volume) = input_slices(input)?;
497    let period = input.get_period();
498    validate_common(close, volume, period)?;
499
500    if dst.len() != close.len() {
501        return Err(VolumeWeightedRsiError::OutputLengthMismatch {
502            expected: close.len(),
503            got: dst.len(),
504        });
505    }
506
507    let _chosen = match kernel {
508        Kernel::Auto => detect_best_kernel(),
509        other => other,
510    };
511
512    dst.fill(f64::NAN);
513    compute_row(close, volume, period, dst);
514    Ok(())
515}
516
517#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
518pub fn volume_weighted_rsi_into(
519    input: &VolumeWeightedRsiInput,
520    out: &mut [f64],
521) -> Result<(), VolumeWeightedRsiError> {
522    volume_weighted_rsi_into_slice(out, input, Kernel::Auto)
523}
524
525#[derive(Debug, Clone, Copy)]
526pub struct VolumeWeightedRsiBatchRange {
527    pub period: (usize, usize, usize),
528}
529
530impl Default for VolumeWeightedRsiBatchRange {
531    fn default() -> Self {
532        Self {
533            period: (14, 14, 0),
534        }
535    }
536}
537
538#[derive(Debug, Clone)]
539pub struct VolumeWeightedRsiBatchOutput {
540    pub values: Vec<f64>,
541    pub combos: Vec<VolumeWeightedRsiParams>,
542    pub rows: usize,
543    pub cols: usize,
544}
545
546#[derive(Debug, Clone, Copy)]
547pub struct VolumeWeightedRsiBatchBuilder {
548    range: VolumeWeightedRsiBatchRange,
549    kernel: Kernel,
550}
551
552impl Default for VolumeWeightedRsiBatchBuilder {
553    fn default() -> Self {
554        Self {
555            range: VolumeWeightedRsiBatchRange::default(),
556            kernel: Kernel::Auto,
557        }
558    }
559}
560
561impl VolumeWeightedRsiBatchBuilder {
562    #[inline(always)]
563    pub fn new() -> Self {
564        Self::default()
565    }
566
567    #[inline(always)]
568    pub fn kernel(mut self, value: Kernel) -> Self {
569        self.kernel = value;
570        self
571    }
572
573    #[inline(always)]
574    pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
575        self.range.period = (start, end, step);
576        self
577    }
578
579    #[inline(always)]
580    pub fn period_static(mut self, value: usize) -> Self {
581        self.range.period = (value, value, 0);
582        self
583    }
584
585    #[inline(always)]
586    pub fn apply_slices(
587        self,
588        close: &[f64],
589        volume: &[f64],
590    ) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
591        volume_weighted_rsi_batch_with_kernel(close, volume, &self.range, self.kernel)
592    }
593
594    #[inline(always)]
595    pub fn apply_candles(
596        self,
597        candles: &Candles,
598    ) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
599        volume_weighted_rsi_batch_with_kernel(
600            candles.close.as_slice(),
601            candles.volume.as_slice(),
602            &self.range,
603            self.kernel,
604        )
605    }
606}
607
608#[inline(always)]
609fn expand_grid_checked(
610    range: &VolumeWeightedRsiBatchRange,
611) -> Result<Vec<VolumeWeightedRsiParams>, VolumeWeightedRsiError> {
612    let (start, end, step) = range.period;
613    if start == 0 || end == 0 {
614        return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
615    }
616    if step == 0 {
617        return Ok(vec![VolumeWeightedRsiParams {
618            period: Some(start),
619        }]);
620    }
621    if start > end {
622        return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
623    }
624
625    let mut out = Vec::new();
626    let mut cur = start;
627    loop {
628        out.push(VolumeWeightedRsiParams { period: Some(cur) });
629        if cur >= end {
630            break;
631        }
632        let next = cur.saturating_add(step);
633        if next <= cur {
634            return Err(VolumeWeightedRsiError::InvalidRange { start, end, step });
635        }
636        cur = next.min(end);
637        if cur == *out.last().and_then(|p| p.period.as_ref()).unwrap() {
638            break;
639        }
640    }
641    Ok(out)
642}
643
644#[inline(always)]
645pub fn expand_grid_volume_weighted_rsi(
646    range: &VolumeWeightedRsiBatchRange,
647) -> Vec<VolumeWeightedRsiParams> {
648    expand_grid_checked(range).unwrap_or_default()
649}
650
651pub fn volume_weighted_rsi_batch_with_kernel(
652    close: &[f64],
653    volume: &[f64],
654    sweep: &VolumeWeightedRsiBatchRange,
655    kernel: Kernel,
656) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
657    match kernel {
658        Kernel::Auto
659        | Kernel::Scalar
660        | Kernel::ScalarBatch
661        | Kernel::Avx2
662        | Kernel::Avx2Batch
663        | Kernel::Avx512
664        | Kernel::Avx512Batch => {}
665        other => return Err(VolumeWeightedRsiError::InvalidKernelForBatch(other)),
666    }
667
668    validate_common(close, volume, 1)?;
669    let combos = expand_grid_checked(sweep)?;
670    let max_period = combos
671        .iter()
672        .map(|params| params.period.unwrap_or(14))
673        .max()
674        .unwrap_or(0);
675    validate_common(close, volume, max_period)?;
676
677    let rows = combos.len();
678    let cols = close.len();
679    let mut values_mu = make_uninit_matrix(rows, cols);
680    let warmups: Vec<usize> = combos
681        .iter()
682        .map(|params| params.period.unwrap_or(14).saturating_sub(1))
683        .collect();
684    init_matrix_prefixes(&mut values_mu, cols, &warmups);
685    let mut values = unsafe {
686        Vec::from_raw_parts(
687            values_mu.as_mut_ptr() as *mut f64,
688            values_mu.len(),
689            values_mu.capacity(),
690        )
691    };
692    std::mem::forget(values_mu);
693
694    volume_weighted_rsi_batch_inner_into(close, volume, sweep, kernel, true, &mut values)?;
695
696    Ok(VolumeWeightedRsiBatchOutput {
697        values,
698        combos,
699        rows,
700        cols,
701    })
702}
703
704pub fn volume_weighted_rsi_batch_slice(
705    close: &[f64],
706    volume: &[f64],
707    sweep: &VolumeWeightedRsiBatchRange,
708    kernel: Kernel,
709) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
710    volume_weighted_rsi_batch_inner(close, volume, sweep, kernel, false)
711}
712
713pub fn volume_weighted_rsi_batch_par_slice(
714    close: &[f64],
715    volume: &[f64],
716    sweep: &VolumeWeightedRsiBatchRange,
717    kernel: Kernel,
718) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
719    volume_weighted_rsi_batch_inner(close, volume, sweep, kernel, true)
720}
721
722fn volume_weighted_rsi_batch_inner(
723    close: &[f64],
724    volume: &[f64],
725    sweep: &VolumeWeightedRsiBatchRange,
726    kernel: Kernel,
727    parallel: bool,
728) -> Result<VolumeWeightedRsiBatchOutput, VolumeWeightedRsiError> {
729    let combos = expand_grid_checked(sweep)?;
730    let rows = combos.len();
731    let cols = close.len();
732    let total = rows
733        .checked_mul(cols)
734        .ok_or_else(|| VolumeWeightedRsiError::InvalidInput {
735            msg: "volume_weighted_rsi: rows*cols overflow in batch".to_string(),
736        })?;
737
738    let mut values_mu = make_uninit_matrix(rows, cols);
739    let warmups: Vec<usize> = combos
740        .iter()
741        .map(|params| params.period.unwrap_or(14).saturating_sub(1))
742        .collect();
743    init_matrix_prefixes(&mut values_mu, cols, &warmups);
744    let mut values = unsafe {
745        Vec::from_raw_parts(
746            values_mu.as_mut_ptr() as *mut f64,
747            values_mu.len(),
748            values_mu.capacity(),
749        )
750    };
751    std::mem::forget(values_mu);
752
753    debug_assert_eq!(values.len(), total);
754
755    volume_weighted_rsi_batch_inner_into(close, volume, sweep, kernel, parallel, &mut values)?;
756
757    Ok(VolumeWeightedRsiBatchOutput {
758        values,
759        combos,
760        rows,
761        cols,
762    })
763}
764
765fn volume_weighted_rsi_batch_inner_into(
766    close: &[f64],
767    volume: &[f64],
768    sweep: &VolumeWeightedRsiBatchRange,
769    kernel: Kernel,
770    parallel: bool,
771    out: &mut [f64],
772) -> Result<Vec<VolumeWeightedRsiParams>, VolumeWeightedRsiError> {
773    match kernel {
774        Kernel::Auto
775        | Kernel::Scalar
776        | Kernel::ScalarBatch
777        | Kernel::Avx2
778        | Kernel::Avx2Batch
779        | Kernel::Avx512
780        | Kernel::Avx512Batch => {}
781        other => return Err(VolumeWeightedRsiError::InvalidKernelForBatch(other)),
782    }
783
784    let combos = expand_grid_checked(sweep)?;
785    let len = close.len();
786    if len == 0 || volume.is_empty() {
787        return Err(VolumeWeightedRsiError::EmptyInputData);
788    }
789    if len != volume.len() {
790        return Err(VolumeWeightedRsiError::InputLengthMismatch {
791            close_len: len,
792            volume_len: volume.len(),
793        });
794    }
795
796    let total =
797        combos
798            .len()
799            .checked_mul(len)
800            .ok_or_else(|| VolumeWeightedRsiError::InvalidInput {
801                msg: "volume_weighted_rsi: rows*cols overflow in batch_into".to_string(),
802            })?;
803    if out.len() != total {
804        return Err(VolumeWeightedRsiError::MismatchedOutputLen {
805            dst_len: out.len(),
806            expected_len: total,
807        });
808    }
809
810    let max_period = combos
811        .iter()
812        .map(|params| params.period.unwrap_or(14))
813        .max()
814        .unwrap_or(0);
815    validate_common(close, volume, max_period)?;
816
817    let _chosen = match kernel {
818        Kernel::Auto => detect_best_batch_kernel(),
819        other => other,
820    };
821
822    let worker = |row: usize, dst: &mut [f64]| {
823        dst.fill(f64::NAN);
824        let period = combos[row].period.unwrap_or(14);
825        compute_row(close, volume, period, dst);
826    };
827
828    #[cfg(not(target_arch = "wasm32"))]
829    if parallel {
830        out.par_chunks_mut(len)
831            .enumerate()
832            .for_each(|(row, dst)| worker(row, dst));
833    } else {
834        for (row, dst) in out.chunks_mut(len).enumerate() {
835            worker(row, dst);
836        }
837    }
838
839    #[cfg(target_arch = "wasm32")]
840    {
841        let _ = parallel;
842        for (row, dst) in out.chunks_mut(len).enumerate() {
843            worker(row, dst);
844        }
845    }
846
847    Ok(combos)
848}
849
850#[cfg(feature = "python")]
851#[pyfunction(name = "volume_weighted_rsi")]
852#[pyo3(signature = (close, volume, period=14, kernel=None))]
853pub fn volume_weighted_rsi_py<'py>(
854    py: Python<'py>,
855    close: PyReadonlyArray1<'py, f64>,
856    volume: PyReadonlyArray1<'py, f64>,
857    period: usize,
858    kernel: Option<&str>,
859) -> PyResult<Bound<'py, PyArray1<f64>>> {
860    let close = close.as_slice()?;
861    let volume = volume.as_slice()?;
862    let kern = validate_kernel(kernel, false)?;
863    let input = VolumeWeightedRsiInput::from_slices(
864        close,
865        volume,
866        VolumeWeightedRsiParams {
867            period: Some(period),
868        },
869    );
870    let out = py
871        .allow_threads(|| volume_weighted_rsi_with_kernel(&input, kern))
872        .map_err(|e| PyValueError::new_err(e.to_string()))?;
873    Ok(out.values.into_pyarray(py))
874}
875
876#[cfg(feature = "python")]
877#[pyclass(name = "VolumeWeightedRsiStream")]
878pub struct VolumeWeightedRsiStreamPy {
879    stream: VolumeWeightedRsiStream,
880}
881
882#[cfg(feature = "python")]
883#[pymethods]
884impl VolumeWeightedRsiStreamPy {
885    #[new]
886    fn new(period: usize) -> PyResult<Self> {
887        let stream = VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
888            period: Some(period),
889        })
890        .map_err(|e| PyValueError::new_err(e.to_string()))?;
891        Ok(Self { stream })
892    }
893
894    fn update(&mut self, close: f64, volume: f64) -> Option<f64> {
895        self.stream.update(close, volume)
896    }
897}
898
899#[cfg(feature = "python")]
900#[pyfunction(name = "volume_weighted_rsi_batch")]
901#[pyo3(signature = (close, volume, period_range=(14,14,0), kernel=None))]
902pub fn volume_weighted_rsi_batch_py<'py>(
903    py: Python<'py>,
904    close: PyReadonlyArray1<'py, f64>,
905    volume: PyReadonlyArray1<'py, f64>,
906    period_range: (usize, usize, usize),
907    kernel: Option<&str>,
908) -> PyResult<Bound<'py, PyDict>> {
909    let close = close.as_slice()?;
910    let volume = volume.as_slice()?;
911    let kern = validate_kernel(kernel, true)?;
912
913    let output = py
914        .allow_threads(|| {
915            volume_weighted_rsi_batch_with_kernel(
916                close,
917                volume,
918                &VolumeWeightedRsiBatchRange {
919                    period: period_range,
920                },
921                kern,
922            )
923        })
924        .map_err(|e| PyValueError::new_err(e.to_string()))?;
925
926    let rows = output.rows;
927    let cols = output.cols;
928    let dict = PyDict::new(py);
929    dict.set_item(
930        "values",
931        output.values.into_pyarray(py).reshape((rows, cols))?,
932    )?;
933    dict.set_item(
934        "periods",
935        output
936            .combos
937            .iter()
938            .map(|params| params.period.unwrap_or(14) as u64)
939            .collect::<Vec<_>>()
940            .into_pyarray(py),
941    )?;
942    dict.set_item("rows", rows)?;
943    dict.set_item("cols", cols)?;
944    Ok(dict)
945}
946
947#[cfg(feature = "python")]
948pub fn register_volume_weighted_rsi_module(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
949    m.add_function(wrap_pyfunction!(volume_weighted_rsi_py, m)?)?;
950    m.add_function(wrap_pyfunction!(volume_weighted_rsi_batch_py, m)?)?;
951    m.add_class::<VolumeWeightedRsiStreamPy>()?;
952    Ok(())
953}
954
955#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
956#[derive(Debug, Clone, Serialize, Deserialize)]
957pub struct VolumeWeightedRsiBatchConfig {
958    pub period_range: Vec<usize>,
959}
960
961#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
962#[wasm_bindgen(js_name = volume_weighted_rsi_js)]
963pub fn volume_weighted_rsi_js(
964    close: &[f64],
965    volume: &[f64],
966    period: usize,
967) -> Result<JsValue, JsValue> {
968    let input = VolumeWeightedRsiInput::from_slices(
969        close,
970        volume,
971        VolumeWeightedRsiParams {
972            period: Some(period),
973        },
974    );
975    let out = volume_weighted_rsi_with_kernel(&input, Kernel::Auto)
976        .map_err(|e| JsValue::from_str(&e.to_string()))?;
977    serde_wasm_bindgen::to_value(&out.values).map_err(|e| JsValue::from_str(&e.to_string()))
978}
979
980#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
981#[wasm_bindgen(js_name = volume_weighted_rsi_batch_js)]
982pub fn volume_weighted_rsi_batch_js(
983    close: &[f64],
984    volume: &[f64],
985    config: JsValue,
986) -> Result<JsValue, JsValue> {
987    let config: VolumeWeightedRsiBatchConfig = serde_wasm_bindgen::from_value(config)
988        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
989    if config.period_range.len() != 3 {
990        return Err(JsValue::from_str(
991            "Invalid config: period_range must have exactly 3 elements [start, end, step]",
992        ));
993    }
994
995    let out = volume_weighted_rsi_batch_with_kernel(
996        close,
997        volume,
998        &VolumeWeightedRsiBatchRange {
999            period: (
1000                config.period_range[0],
1001                config.period_range[1],
1002                config.period_range[2],
1003            ),
1004        },
1005        Kernel::Auto,
1006    )
1007    .map_err(|e| JsValue::from_str(&e.to_string()))?;
1008
1009    let obj = js_sys::Object::new();
1010    js_sys::Reflect::set(
1011        &obj,
1012        &JsValue::from_str("values"),
1013        &serde_wasm_bindgen::to_value(&out.values).unwrap(),
1014    )?;
1015    js_sys::Reflect::set(
1016        &obj,
1017        &JsValue::from_str("rows"),
1018        &JsValue::from_f64(out.rows as f64),
1019    )?;
1020    js_sys::Reflect::set(
1021        &obj,
1022        &JsValue::from_str("cols"),
1023        &JsValue::from_f64(out.cols as f64),
1024    )?;
1025    js_sys::Reflect::set(
1026        &obj,
1027        &JsValue::from_str("combos"),
1028        &serde_wasm_bindgen::to_value(&out.combos).unwrap(),
1029    )?;
1030    Ok(obj.into())
1031}
1032
1033#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1034#[wasm_bindgen]
1035pub fn volume_weighted_rsi_alloc(len: usize) -> *mut f64 {
1036    let mut vec = Vec::<f64>::with_capacity(len);
1037    let ptr = vec.as_mut_ptr();
1038    std::mem::forget(vec);
1039    ptr
1040}
1041
1042#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1043#[wasm_bindgen]
1044pub fn volume_weighted_rsi_free(ptr: *mut f64, len: usize) {
1045    if !ptr.is_null() {
1046        unsafe {
1047            let _ = Vec::from_raw_parts(ptr, len, len);
1048        }
1049    }
1050}
1051
1052#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1053#[wasm_bindgen]
1054pub fn volume_weighted_rsi_into(
1055    close_ptr: *const f64,
1056    volume_ptr: *const f64,
1057    out_ptr: *mut f64,
1058    len: usize,
1059    period: usize,
1060) -> Result<(), JsValue> {
1061    if close_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
1062        return Err(JsValue::from_str(
1063            "null pointer passed to volume_weighted_rsi_into",
1064        ));
1065    }
1066
1067    unsafe {
1068        let close = std::slice::from_raw_parts(close_ptr, len);
1069        let volume = std::slice::from_raw_parts(volume_ptr, len);
1070        let out = std::slice::from_raw_parts_mut(out_ptr, len);
1071        let input = VolumeWeightedRsiInput::from_slices(
1072            close,
1073            volume,
1074            VolumeWeightedRsiParams {
1075                period: Some(period),
1076            },
1077        );
1078        volume_weighted_rsi_into_slice(out, &input, Kernel::Auto)
1079            .map_err(|e| JsValue::from_str(&e.to_string()))
1080    }
1081}
1082
1083#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1084#[wasm_bindgen]
1085pub fn volume_weighted_rsi_batch_into(
1086    close_ptr: *const f64,
1087    volume_ptr: *const f64,
1088    out_ptr: *mut f64,
1089    len: usize,
1090    period_start: usize,
1091    period_end: usize,
1092    period_step: usize,
1093) -> Result<usize, JsValue> {
1094    if close_ptr.is_null() || volume_ptr.is_null() || out_ptr.is_null() {
1095        return Err(JsValue::from_str(
1096            "null pointer passed to volume_weighted_rsi_batch_into",
1097        ));
1098    }
1099
1100    let sweep = VolumeWeightedRsiBatchRange {
1101        period: (period_start, period_end, period_step),
1102    };
1103    let combos = expand_grid_checked(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
1104    let rows = combos.len();
1105    let total = rows
1106        .checked_mul(len)
1107        .ok_or_else(|| JsValue::from_str("rows*cols overflow in volume_weighted_rsi_batch_into"))?;
1108
1109    unsafe {
1110        let close = std::slice::from_raw_parts(close_ptr, len);
1111        let volume = std::slice::from_raw_parts(volume_ptr, len);
1112        let out = std::slice::from_raw_parts_mut(out_ptr, total);
1113        volume_weighted_rsi_batch_inner_into(close, volume, &sweep, Kernel::Auto, false, out)
1114            .map_err(|e| JsValue::from_str(&e.to_string()))?;
1115    }
1116
1117    Ok(rows)
1118}
1119
1120#[cfg(test)]
1121mod tests {
1122    use super::*;
1123    use crate::indicators::dispatch::{
1124        compute_cpu, IndicatorComputeRequest, IndicatorDataRef, ParamKV, ParamValue,
1125    };
1126
1127    fn sample_close_volume(len: usize) -> (Vec<f64>, Vec<f64>) {
1128        let close: Vec<f64> = (0..len)
1129            .map(|i| 100.0 + ((i as f64) * 0.13).sin() * 4.0 + (i as f64) * 0.02)
1130            .collect();
1131        let volume: Vec<f64> = (0..len)
1132            .map(|i| 1000.0 + ((i as f64) * 0.17).cos().abs() * 250.0 + (i % 11) as f64 * 7.0)
1133            .collect();
1134        (close, volume)
1135    }
1136
1137    fn naive_volume_weighted_rsi(close: &[f64], volume: &[f64], period: usize) -> Vec<f64> {
1138        let mut out = vec![f64::NAN; close.len()];
1139        compute_row(close, volume, period, &mut out);
1140        out
1141    }
1142
1143    #[test]
1144    fn volume_weighted_rsi_matches_naive() -> Result<(), Box<dyn Error>> {
1145        let (close, volume) = sample_close_volume(256);
1146        let input = VolumeWeightedRsiInput::from_slices(
1147            &close,
1148            &volume,
1149            VolumeWeightedRsiParams { period: Some(14) },
1150        );
1151        let out = volume_weighted_rsi(&input)?;
1152        let expected = naive_volume_weighted_rsi(&close, &volume, 14);
1153        for (a, b) in out.values.iter().zip(expected.iter()) {
1154            if a.is_nan() || b.is_nan() {
1155                assert!(a.is_nan() && b.is_nan());
1156            } else {
1157                assert!((a - b).abs() < 1e-12);
1158            }
1159        }
1160        Ok(())
1161    }
1162
1163    #[test]
1164    fn volume_weighted_rsi_into_matches_api() -> Result<(), Box<dyn Error>> {
1165        let (close, volume) = sample_close_volume(200);
1166        let input = VolumeWeightedRsiInput::from_slices(
1167            &close,
1168            &volume,
1169            VolumeWeightedRsiParams { period: Some(10) },
1170        );
1171        let base = volume_weighted_rsi(&input)?;
1172        let mut out = vec![0.0; close.len()];
1173        volume_weighted_rsi_into_slice(&mut out, &input, Kernel::Auto)?;
1174        for (a, b) in out.iter().zip(base.values.iter()) {
1175            if a.is_nan() || b.is_nan() {
1176                assert!(a.is_nan() && b.is_nan());
1177            } else {
1178                assert!((a - b).abs() < 1e-12);
1179            }
1180        }
1181        Ok(())
1182    }
1183
1184    #[test]
1185    fn volume_weighted_rsi_stream_matches_batch() -> Result<(), Box<dyn Error>> {
1186        let (close, volume) = sample_close_volume(220);
1187        let period = 12;
1188        let input = VolumeWeightedRsiInput::from_slices(
1189            &close,
1190            &volume,
1191            VolumeWeightedRsiParams {
1192                period: Some(period),
1193            },
1194        );
1195        let batch = volume_weighted_rsi(&input)?;
1196        let mut stream = VolumeWeightedRsiStream::try_new(VolumeWeightedRsiParams {
1197            period: Some(period),
1198        })?;
1199        let mut values = Vec::with_capacity(close.len());
1200        for (&c, &v) in close.iter().zip(volume.iter()) {
1201            values.push(stream.update(c, v).unwrap_or(f64::NAN));
1202        }
1203        for (a, b) in values.iter().zip(batch.values.iter()) {
1204            if a.is_nan() || b.is_nan() {
1205                assert!(a.is_nan() && b.is_nan());
1206            } else {
1207                assert!((a - b).abs() < 1e-12);
1208            }
1209        }
1210        Ok(())
1211    }
1212
1213    #[test]
1214    fn volume_weighted_rsi_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
1215        let (close, volume) = sample_close_volume(128);
1216        let single = volume_weighted_rsi(&VolumeWeightedRsiInput::from_slices(
1217            &close,
1218            &volume,
1219            VolumeWeightedRsiParams { period: Some(14) },
1220        ))?;
1221        let batch = volume_weighted_rsi_batch_with_kernel(
1222            &close,
1223            &volume,
1224            &VolumeWeightedRsiBatchRange {
1225                period: (14, 14, 0),
1226            },
1227            Kernel::Auto,
1228        )?;
1229        assert_eq!(batch.rows, 1);
1230        assert_eq!(batch.cols, close.len());
1231        for (a, b) in batch.values.iter().zip(single.values.iter()) {
1232            if a.is_nan() || b.is_nan() {
1233                assert!(a.is_nan() && b.is_nan());
1234            } else {
1235                assert!((a - b).abs() < 1e-12);
1236            }
1237        }
1238        Ok(())
1239    }
1240
1241    #[test]
1242    fn volume_weighted_rsi_rejects_invalid_params() {
1243        let (close, volume) = sample_close_volume(16);
1244        let err = volume_weighted_rsi(&VolumeWeightedRsiInput::from_slices(
1245            &close,
1246            &volume,
1247            VolumeWeightedRsiParams { period: Some(0) },
1248        ))
1249        .unwrap_err();
1250        assert!(matches!(err, VolumeWeightedRsiError::InvalidPeriod { .. }));
1251    }
1252
1253    #[test]
1254    fn volume_weighted_rsi_dispatch_compute_returns_value() {
1255        let (close, volume) = sample_close_volume(128);
1256        let params = [ParamKV {
1257            key: "period",
1258            value: ParamValue::Int(14),
1259        }];
1260        let out = compute_cpu(IndicatorComputeRequest {
1261            indicator_id: "volume_weighted_rsi",
1262            output_id: Some("value"),
1263            data: IndicatorDataRef::CloseVolume {
1264                close: &close,
1265                volume: &volume,
1266            },
1267            params: &params,
1268            kernel: Kernel::Auto,
1269        })
1270        .unwrap();
1271        assert_eq!(out.output_id, "value");
1272        assert_eq!(out.cols, close.len());
1273    }
1274}