// vector_ta/indicators/demand_index.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::Candles;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::mem::ManuallyDrop;
28use thiserror::Error;
29
30#[derive(Debug, Clone)]
31pub enum DemandIndexData<'a> {
32    Candles {
33        candles: &'a Candles,
34    },
35    Slices {
36        high: &'a [f64],
37        low: &'a [f64],
38        close: &'a [f64],
39        volume: &'a [f64],
40    },
41}
42
/// Result of a single Demand Index run; both vectors have one entry per input bar.
#[derive(Debug, Clone)]
pub struct DemandIndexOutput {
    /// Demand Index line; NaN wherever no value is available.
    pub demand_index: Vec<f64>,
    /// Simple moving average of the Demand Index over `len_di_ma` bars.
    pub signal: Vec<f64>,
}

/// Tunable parameters of the Demand Index.
///
/// Every field is optional; `None` falls back to the value used by [`Default`].
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DemandIndexParams {
    /// Window of the volume moving average.
    pub len_bs: Option<usize>,
    /// Window of the buying/selling pressure moving averages.
    pub len_bs_ma: Option<usize>,
    /// Window of the signal-line SMA.
    pub len_di_ma: Option<usize>,
    /// Moving-average kind: "ema", "sma", "wma", or "rma"/"wilders" (case-insensitive).
    pub ma_type: Option<String>,
}

impl Default for DemandIndexParams {
    /// All three windows at 19 bars with an EMA.
    fn default() -> Self {
        const DEFAULT_LEN: usize = 19;
        Self {
            len_bs: Some(DEFAULT_LEN),
            len_bs_ma: Some(DEFAULT_LEN),
            len_di_ma: Some(DEFAULT_LEN),
            ma_type: Some(String::from("ema")),
        }
    }
}
71
72#[derive(Debug, Clone)]
73pub struct DemandIndexInput<'a> {
74    pub data: DemandIndexData<'a>,
75    pub params: DemandIndexParams,
76}
77
78impl<'a> DemandIndexInput<'a> {
79    #[inline]
80    pub fn from_candles(candles: &'a Candles, params: DemandIndexParams) -> Self {
81        Self {
82            data: DemandIndexData::Candles { candles },
83            params,
84        }
85    }
86
87    #[inline]
88    pub fn from_slices(
89        high: &'a [f64],
90        low: &'a [f64],
91        close: &'a [f64],
92        volume: &'a [f64],
93        params: DemandIndexParams,
94    ) -> Self {
95        Self {
96            data: DemandIndexData::Slices {
97                high,
98                low,
99                close,
100                volume,
101            },
102            params,
103        }
104    }
105
106    #[inline]
107    pub fn with_default_candles(candles: &'a Candles) -> Self {
108        Self::from_candles(candles, DemandIndexParams::default())
109    }
110
111    #[inline]
112    pub fn get_len_bs(&self) -> usize {
113        self.params.len_bs.unwrap_or(19)
114    }
115
116    #[inline]
117    pub fn get_len_bs_ma(&self) -> usize {
118        self.params.len_bs_ma.unwrap_or(19)
119    }
120
121    #[inline]
122    pub fn get_len_di_ma(&self) -> usize {
123        self.params.len_di_ma.unwrap_or(19)
124    }
125
126    #[inline]
127    pub fn get_ma_type(&self) -> &str {
128        self.params.ma_type.as_deref().unwrap_or("ema")
129    }
130}
131
132#[derive(Copy, Clone, Debug)]
133pub struct DemandIndexBuilder {
134    len_bs: Option<usize>,
135    len_bs_ma: Option<usize>,
136    len_di_ma: Option<usize>,
137    ma_type: Option<MaType>,
138    kernel: Kernel,
139}
140
141impl Default for DemandIndexBuilder {
142    fn default() -> Self {
143        Self {
144            len_bs: None,
145            len_bs_ma: None,
146            len_di_ma: None,
147            ma_type: None,
148            kernel: Kernel::Auto,
149        }
150    }
151}
152
153impl DemandIndexBuilder {
154    #[inline]
155    pub fn new() -> Self {
156        Self::default()
157    }
158
159    #[inline]
160    pub fn len_bs(mut self, len_bs: usize) -> Self {
161        self.len_bs = Some(len_bs);
162        self
163    }
164
165    #[inline]
166    pub fn len_bs_ma(mut self, len_bs_ma: usize) -> Self {
167        self.len_bs_ma = Some(len_bs_ma);
168        self
169    }
170
171    #[inline]
172    pub fn len_di_ma(mut self, len_di_ma: usize) -> Self {
173        self.len_di_ma = Some(len_di_ma);
174        self
175    }
176
177    #[inline]
178    pub fn ma_type(mut self, ma_type: &str) -> Result<Self, DemandIndexError> {
179        self.ma_type = Some(MaType::parse(ma_type)?);
180        Ok(self)
181    }
182
183    #[inline]
184    pub fn kernel(mut self, kernel: Kernel) -> Self {
185        self.kernel = kernel;
186        self
187    }
188
189    #[inline]
190    pub fn apply(self, candles: &Candles) -> Result<DemandIndexOutput, DemandIndexError> {
191        let input = DemandIndexInput::from_candles(
192            candles,
193            DemandIndexParams {
194                len_bs: self.len_bs,
195                len_bs_ma: self.len_bs_ma,
196                len_di_ma: self.len_di_ma,
197                ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
198            },
199        );
200        demand_index_with_kernel(&input, self.kernel)
201    }
202
203    #[inline]
204    pub fn apply_slices(
205        self,
206        high: &[f64],
207        low: &[f64],
208        close: &[f64],
209        volume: &[f64],
210    ) -> Result<DemandIndexOutput, DemandIndexError> {
211        let input = DemandIndexInput::from_slices(
212            high,
213            low,
214            close,
215            volume,
216            DemandIndexParams {
217                len_bs: self.len_bs,
218                len_bs_ma: self.len_bs_ma,
219                len_di_ma: self.len_di_ma,
220                ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
221            },
222        );
223        demand_index_with_kernel(&input, self.kernel)
224    }
225
226    #[inline]
227    pub fn into_stream(self) -> Result<DemandIndexStream, DemandIndexError> {
228        DemandIndexStream::try_new(DemandIndexParams {
229            len_bs: self.len_bs,
230            len_bs_ma: self.len_bs_ma,
231            len_di_ma: self.len_di_ma,
232            ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
233        })
234    }
235}
236
237#[derive(Debug, Error)]
238pub enum DemandIndexError {
239    #[error("demand_index: Input data slice is empty.")]
240    EmptyInputData,
241    #[error("demand_index: All values are NaN.")]
242    AllValuesNaN,
243    #[error("demand_index: Invalid len_bs: len_bs = {len_bs}, data length = {data_len}")]
244    InvalidLenBs { len_bs: usize, data_len: usize },
245    #[error("demand_index: Invalid len_bs_ma: len_bs_ma = {len_bs_ma}, data length = {data_len}")]
246    InvalidLenBsMa { len_bs_ma: usize, data_len: usize },
247    #[error("demand_index: Invalid len_di_ma: len_di_ma = {len_di_ma}, data length = {data_len}")]
248    InvalidLenDiMa { len_di_ma: usize, data_len: usize },
249    #[error("demand_index: Invalid ma_type: {ma_type}")]
250    InvalidMaType { ma_type: String },
251    #[error(
252        "demand_index: Inconsistent slice lengths: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
253    )]
254    InconsistentSliceLengths {
255        high_len: usize,
256        low_len: usize,
257        close_len: usize,
258        volume_len: usize,
259    },
260    #[error("demand_index: Not enough valid data: needed = {needed}, valid = {valid}")]
261    NotEnoughValidData { needed: usize, valid: usize },
262    #[error(
263        "demand_index: Output length mismatch: expected = {expected}, demand_index = {demand_index_got}, signal = {signal_got}"
264    )]
265    OutputLengthMismatch {
266        expected: usize,
267        demand_index_got: usize,
268        signal_got: usize,
269    },
270    #[error("demand_index: Invalid range: start={start}, end={end}, step={step}")]
271    InvalidRange {
272        start: String,
273        end: String,
274        step: String,
275    },
276    #[error("demand_index: Invalid kernel for batch: {0:?}")]
277    InvalidKernelForBatch(Kernel),
278}
279
280#[derive(Copy, Clone, Debug, PartialEq, Eq)]
281enum MaType {
282    Ema,
283    Sma,
284    Wma,
285    Rma,
286}
287
288impl MaType {
289    #[inline]
290    fn parse(value: &str) -> Result<Self, DemandIndexError> {
291        if value.eq_ignore_ascii_case("ema") {
292            return Ok(Self::Ema);
293        }
294        if value.eq_ignore_ascii_case("sma") {
295            return Ok(Self::Sma);
296        }
297        if value.eq_ignore_ascii_case("wma") {
298            return Ok(Self::Wma);
299        }
300        if value.eq_ignore_ascii_case("rma") || value.eq_ignore_ascii_case("wilders") {
301            return Ok(Self::Rma);
302        }
303        Err(DemandIndexError::InvalidMaType {
304            ma_type: value.to_string(),
305        })
306    }
307
308    #[inline]
309    fn as_str(self) -> &'static str {
310        match self {
311            Self::Ema => "ema",
312            Self::Sma => "sma",
313            Self::Wma => "wma",
314            Self::Rma => "rma",
315        }
316    }
317}
318
/// Fixed-window simple moving average that tolerates gaps.
///
/// A missing or non-finite sample still occupies a window slot but is flagged invalid; the
/// average is reported as NaN until the window again holds `period` valid samples.
#[derive(Debug, Clone)]
struct SmaState {
    period: usize,
    // Ring buffer of samples (0.0 in invalid slots) with matching validity flags.
    values: Vec<f64>,
    valid: Vec<u8>,
    // Next slot to write; once the window is full this is also the oldest slot.
    idx: usize,
    // Samples seen so far, capped at `period`.
    count: usize,
    valid_count: usize,
    // Running sum of the valid samples currently inside the window.
    sum: f64,
}

impl SmaState {
    #[inline]
    fn new(period: usize) -> Self {
        Self {
            period,
            values: vec![0.0; period],
            valid: vec![0u8; period],
            idx: 0,
            count: 0,
            valid_count: 0,
            sum: 0.0,
        }
    }

    /// Pushes one sample. Returns `None` until `period` samples have been seen, then
    /// `Some(mean)`, or `Some(NaN)` while any sample in the window is invalid.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        let slot = self.idx;
        if self.count < self.period {
            self.count += 1;
        } else if self.valid[slot] != 0 {
            // Evict the oldest sample before overwriting its slot.
            self.valid_count = self.valid_count.saturating_sub(1);
            self.sum -= self.values[slot];
        }

        let sample = value.filter(|v| v.is_finite());
        self.values[slot] = sample.unwrap_or(0.0);
        self.valid[slot] = u8::from(sample.is_some());
        if let Some(v) = sample {
            self.valid_count += 1;
            self.sum += v;
        }

        self.idx = if slot + 1 == self.period { 0 } else { slot + 1 };

        if self.count < self.period {
            return None;
        }
        Some(if self.valid_count == self.period {
            self.sum / self.period as f64
        } else {
            f64::NAN
        })
    }
}
383
/// Fixed-window linearly weighted moving average that tolerates gaps.
///
/// The oldest sample in the window has weight 1 and the newest weight `period`. Missing or
/// non-finite samples occupy a slot and make the result NaN until they leave the window.
#[derive(Debug, Clone)]
struct WmaState {
    period: usize,
    // Sum of the weights 1..=period.
    denom: f64,
    values: Vec<f64>,
    valid: Vec<u8>,
    idx: usize,
    count: usize,
    valid_count: usize,
}

impl WmaState {
    #[inline]
    fn new(period: usize) -> Self {
        let weight_total = period * (period + 1) / 2;
        Self {
            period,
            denom: weight_total as f64,
            values: vec![0.0; period],
            valid: vec![0u8; period],
            idx: 0,
            count: 0,
            valid_count: 0,
        }
    }

    /// Pushes one sample. Returns `None` until `period` samples have been seen, then
    /// `Some(wma)`, or `Some(NaN)` while any sample in the window is invalid.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        let slot = self.idx;
        if self.count < self.period {
            self.count += 1;
        } else if self.valid[slot] != 0 {
            self.valid_count = self.valid_count.saturating_sub(1);
        }

        let sample = value.filter(|v| v.is_finite());
        self.values[slot] = sample.unwrap_or(0.0);
        self.valid[slot] = u8::from(sample.is_some());
        if sample.is_some() {
            self.valid_count += 1;
        }

        self.idx = if slot + 1 == self.period { 0 } else { slot + 1 };

        if self.count < self.period {
            return None;
        }
        if self.valid_count != self.period {
            return Some(f64::NAN);
        }

        // After advancing, `idx` is the oldest slot: walk oldest -> newest with weights 1..=period.
        let (newest, oldest) = self.values.split_at(self.idx);
        let weighted = oldest
            .iter()
            .chain(newest)
            .zip((1..=self.period).map(|w| w as f64))
            .fold(0.0, |acc, (v, w)| acc + v * w);
        Some(weighted / self.denom)
    }
}
458
/// Exponential smoothing state shared by the EMA and RMA kinds (they differ only in `alpha`).
#[derive(Debug, Clone)]
struct ExpState {
    alpha: f64,
    value: f64,
    initialized: bool,
}

impl ExpState {
    /// Unseeded smoother with the given factor.
    #[inline]
    fn with_alpha(alpha: f64) -> Self {
        Self {
            alpha,
            value: f64::NAN,
            initialized: false,
        }
    }

    /// EMA smoothing: `alpha = 2 / (period + 1)`.
    #[inline]
    fn ema(period: usize) -> Self {
        Self::with_alpha(2.0 / (period as f64 + 1.0))
    }

    /// Wilder (RMA) smoothing: `alpha = 1 / period`.
    #[inline]
    fn rma(period: usize) -> Self {
        Self::with_alpha(1.0 / period as f64)
    }

    /// Folds in one sample. The first finite sample seeds the average directly; a missing or
    /// non-finite sample resets the state and yields `Some(NaN)`. Never returns `None`.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        match value.filter(|v| v.is_finite()) {
            Some(v) => {
                self.value = if self.initialized {
                    self.value + self.alpha * (v - self.value)
                } else {
                    v
                };
                self.initialized = true;
                Some(self.value)
            }
            None => {
                self.value = f64::NAN;
                self.initialized = false;
                Some(f64::NAN)
            }
        }
    }
}
505
506#[derive(Debug, Clone)]
507enum MaState {
508    Sma(SmaState),
509    Wma(WmaState),
510    Ema(ExpState),
511    Rma(ExpState),
512}
513
514impl MaState {
515    #[inline]
516    fn new(kind: MaType, period: usize) -> Self {
517        match kind {
518            MaType::Sma => Self::Sma(SmaState::new(period)),
519            MaType::Wma => Self::Wma(WmaState::new(period)),
520            MaType::Ema => Self::Ema(ExpState::ema(period)),
521            MaType::Rma => Self::Rma(ExpState::rma(period)),
522        }
523    }
524
525    #[inline]
526    fn update(&mut self, value: Option<f64>) -> Option<f64> {
527        match self {
528            Self::Sma(state) => state.update(value),
529            Self::Wma(state) => state.update(value),
530            Self::Ema(state) => state.update(value),
531            Self::Rma(state) => state.update(value),
532        }
533    }
534}
535
536#[derive(Debug, Clone)]
537pub struct DemandIndexStream {
538    len_bs: usize,
539    len_bs_ma: usize,
540    len_di_ma: usize,
541    ma_type: MaType,
542    h0: f64,
543    l0: f64,
544    has_h0_l0: bool,
545    prev_p: f64,
546    has_prev_p: bool,
547    volume_avg: MaState,
548    bp_avg: MaState,
549    sp_avg: MaState,
550    signal_avg: SmaState,
551}
552
553impl DemandIndexStream {
554    pub fn try_new(params: DemandIndexParams) -> Result<Self, DemandIndexError> {
555        let len_bs = params.len_bs.unwrap_or(19);
556        let len_bs_ma = params.len_bs_ma.unwrap_or(19);
557        let len_di_ma = params.len_di_ma.unwrap_or(19);
558        let ma_type = MaType::parse(params.ma_type.as_deref().unwrap_or("ema"))?;
559        validate_lengths(len_bs, len_bs_ma, len_di_ma, 0)?;
560
561        Ok(Self {
562            len_bs,
563            len_bs_ma,
564            len_di_ma,
565            ma_type,
566            h0: f64::NAN,
567            l0: f64::NAN,
568            has_h0_l0: false,
569            prev_p: f64::NAN,
570            has_prev_p: false,
571            volume_avg: MaState::new(ma_type, len_bs),
572            bp_avg: MaState::new(ma_type, len_bs_ma),
573            sp_avg: MaState::new(ma_type, len_bs_ma),
574            signal_avg: SmaState::new(len_di_ma),
575        })
576    }
577
578    #[inline]
579    pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
580        if !valid_ohlcv_bar(high, low, close, volume) {
581            self.volume_avg.update(None);
582            let di = demand_index_finalize(None, None, &mut self.bp_avg, &mut self.sp_avg);
583            let signal = update_signal(&mut self.signal_avg, di);
584            self.has_prev_p = false;
585            return if di.is_none() && signal.is_none() {
586                None
587            } else {
588                Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
589            };
590        }
591
592        if !self.has_h0_l0 {
593            self.h0 = high;
594            self.l0 = low;
595            self.has_h0_l0 = true;
596        }
597
598        let volume_avg = self.volume_avg.update(Some(volume));
599        let p = high + low + 2.0 * close;
600
601        if !self.has_prev_p {
602            let di = demand_index_finalize(None, None, &mut self.bp_avg, &mut self.sp_avg);
603            let signal = update_signal(&mut self.signal_avg, di);
604            self.prev_p = p;
605            self.has_prev_p = true;
606            return if di.is_none() && signal.is_none() {
607                None
608            } else {
609                Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
610            };
611        }
612
613        let bp_sp = compute_bp_sp(volume_avg, volume, p, self.prev_p, self.h0, self.l0);
614        let di = demand_index_finalize(
615            bp_sp.map(|x| x.0),
616            bp_sp.map(|x| x.1),
617            &mut self.bp_avg,
618            &mut self.sp_avg,
619        );
620        let signal = update_signal(&mut self.signal_avg, di);
621
622        self.prev_p = p;
623        self.has_prev_p = true;
624        Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
625    }
626
627    #[inline]
628    pub fn get_warmup_period(&self) -> usize {
629        signal_warmup(self.ma_type, self.len_bs, self.len_bs_ma, self.len_di_ma)
630    }
631}
632
633#[inline]
634pub fn demand_index(input: &DemandIndexInput) -> Result<DemandIndexOutput, DemandIndexError> {
635    demand_index_with_kernel(input, Kernel::Auto)
636}
637
638#[inline]
639fn validate_lengths(
640    len_bs: usize,
641    len_bs_ma: usize,
642    len_di_ma: usize,
643    data_len: usize,
644) -> Result<(), DemandIndexError> {
645    if len_bs == 0 {
646        return Err(DemandIndexError::InvalidLenBs { len_bs, data_len });
647    }
648    if len_bs_ma == 0 {
649        return Err(DemandIndexError::InvalidLenBsMa {
650            len_bs_ma,
651            data_len,
652        });
653    }
654    if len_di_ma == 0 {
655        return Err(DemandIndexError::InvalidLenDiMa {
656            len_di_ma,
657            data_len,
658        });
659    }
660    Ok(())
661}
662
/// True when every field of the bar is finite (neither NaN nor infinite).
#[inline]
fn valid_ohlcv_bar(high: f64, low: f64, close: f64, volume: f64) -> bool {
    [high, low, close, volume].iter().all(|x| x.is_finite())
}
667
668#[inline]
669fn extract_input<'a>(
670    input: &'a DemandIndexInput<'a>,
671) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), DemandIndexError> {
672    let (high, low, close, volume) = match &input.data {
673        DemandIndexData::Candles { candles } => (
674            candles.high.as_slice(),
675            candles.low.as_slice(),
676            candles.close.as_slice(),
677            candles.volume.as_slice(),
678        ),
679        DemandIndexData::Slices {
680            high,
681            low,
682            close,
683            volume,
684        } => (*high, *low, *close, *volume),
685    };
686
687    if high.is_empty() || low.is_empty() || close.is_empty() || volume.is_empty() {
688        return Err(DemandIndexError::EmptyInputData);
689    }
690    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
691        return Err(DemandIndexError::InconsistentSliceLengths {
692            high_len: high.len(),
693            low_len: low.len(),
694            close_len: close.len(),
695            volume_len: volume.len(),
696        });
697    }
698    Ok((high, low, close, volume))
699}
700
701#[inline]
702fn first_valid_ohlcv(high: &[f64], low: &[f64], close: &[f64], volume: &[f64]) -> Option<usize> {
703    let mut i = 0usize;
704    while i < high.len() {
705        if valid_ohlcv_bar(high[i], low[i], close[i], volume[i]) {
706            return Some(i);
707        }
708        i += 1;
709    }
710    None
711}
712
713#[inline]
714fn count_valid_ohlcv(high: &[f64], low: &[f64], close: &[f64], volume: &[f64]) -> usize {
715    let mut count = 0usize;
716    for i in 0..high.len() {
717        if valid_ohlcv_bar(high[i], low[i], close[i], volume[i]) {
718            count += 1;
719        }
720    }
721    count
722}
723
724#[inline]
725fn di_warmup(ma_type: MaType, len_bs: usize, len_bs_ma: usize) -> usize {
726    match ma_type {
727        MaType::Ema | MaType::Rma => 1,
728        MaType::Sma | MaType::Wma => len_bs.saturating_sub(1).max(1) + len_bs_ma.saturating_sub(1),
729    }
730}
731
732#[inline]
733fn signal_warmup(ma_type: MaType, len_bs: usize, len_bs_ma: usize, len_di_ma: usize) -> usize {
734    di_warmup(ma_type, len_bs, len_bs_ma) + len_di_ma.saturating_sub(1)
735}
736
737#[inline]
738fn needed_valid_bars(ma_type: MaType, len_bs: usize, len_bs_ma: usize, len_di_ma: usize) -> usize {
739    signal_warmup(ma_type, len_bs, len_bs_ma, len_di_ma) + 1
740}
741
/// Denominator of the price-change exponent: the first bar's high-low range, floored at
/// `1e-12` so a zero-range first bar does not divide by zero.
#[inline]
fn normalize_h0_l0(h0: f64, l0: f64) -> f64 {
    let range = (h0 - l0).abs();
    range.max(1e-12)
}

/// Buying and selling pressure `(bp, sp)` for one bar.
///
/// Both start from the volume ratio `volume / volume_avg` (1.0 when the average is zero). On
/// a falling weighted price the buying side is damped by `exp(...)`; on a rising one the
/// selling side is damped; on an unchanged price both equal the ratio. Returns `None` when
/// the average is missing or non-finite, a price is non-finite, the ratio is non-finite, or
/// the price used as divisor is zero.
#[inline]
fn compute_bp_sp(
    volume_avg: Option<f64>,
    volume: f64,
    p: f64,
    prev_p: f64,
    h0: f64,
    l0: f64,
) -> Option<(f64, f64)> {
    const K: f64 = 0.375;

    let v_avg = volume_avg.filter(|avg| avg.is_finite())?;
    if !(p.is_finite() && prev_p.is_finite()) {
        return None;
    }
    let v_ratio = if v_avg == 0.0 { 1.0 } else { volume / v_avg };
    if !v_ratio.is_finite() {
        return None;
    }

    if p == prev_p {
        return Some((v_ratio, v_ratio));
    }

    let scale = K * (p + prev_p) / normalize_h0_l0(h0, l0);
    if p < prev_p {
        if p == 0.0 {
            return None;
        }
        let damping = (scale * ((prev_p - p) / p)).exp();
        return Some((v_ratio / damping, v_ratio));
    }

    if prev_p == 0.0 {
        return None;
    }
    let damping = (scale * ((p - prev_p) / prev_p)).exp();
    Some((v_ratio, v_ratio / damping))
}
785
/// Combines averaged buying (`bp_ma`) and selling (`sp_ma`) pressure into a Demand Index
/// value in `[-100, 100]` for non-negative inputs.
///
/// Positive when buying dominates (`100 * (1 - sp/bp)`), negative when selling dominates
/// (`100 * (bp/sp - 1)`), zero when equal. Returns `None` if either input is missing or
/// non-finite.
#[inline]
fn finalize_di(bp_ma: Option<f64>, sp_ma: Option<f64>) -> Option<f64> {
    let (bp, sp) = bp_ma.zip(sp_ma)?;
    if !(bp.is_finite() && sp.is_finite()) {
        return None;
    }

    if bp > sp {
        Some(if bp == 0.0 { 100.0 } else { 100.0 * (1.0 - sp / bp) })
    } else if bp < sp {
        Some(if sp == 0.0 { -100.0 } else { 100.0 * (bp / sp - 1.0) })
    } else {
        Some(0.0)
    }
}
808
809#[inline]
810fn demand_index_finalize(
811    bp: Option<f64>,
812    sp: Option<f64>,
813    bp_avg: &mut MaState,
814    sp_avg: &mut MaState,
815) -> Option<f64> {
816    let bp_ma = bp_avg.update(bp.filter(|v| v.is_finite()));
817    let sp_ma = sp_avg.update(sp.filter(|v| v.is_finite()));
818    finalize_di(bp_ma, sp_ma)
819}
820
821#[inline]
822fn update_signal(signal_avg: &mut SmaState, di: Option<f64>) -> Option<f64> {
823    signal_avg.update(di.filter(|v| v.is_finite()))
824}
825
826#[inline]
827fn demand_index_compute_into(
828    high: &[f64],
829    low: &[f64],
830    close: &[f64],
831    volume: &[f64],
832    len_bs: usize,
833    len_bs_ma: usize,
834    len_di_ma: usize,
835    ma_type: MaType,
836    demand_index_out: &mut [f64],
837    signal_out: &mut [f64],
838) {
839    let mut volume_avg = MaState::new(ma_type, len_bs);
840    let mut bp_avg = MaState::new(ma_type, len_bs_ma);
841    let mut sp_avg = MaState::new(ma_type, len_bs_ma);
842    let mut signal_avg = SmaState::new(len_di_ma);
843    let mut h0 = f64::NAN;
844    let mut l0 = f64::NAN;
845    let mut has_h0_l0 = false;
846    let mut prev_p = f64::NAN;
847    let mut has_prev_p = false;
848
849    for i in 0..high.len() {
850        let h = high[i];
851        let l = low[i];
852        let c = close[i];
853        let v = volume[i];
854
855        if !valid_ohlcv_bar(h, l, c, v) {
856            volume_avg.update(None);
857            let di = demand_index_finalize(None, None, &mut bp_avg, &mut sp_avg);
858            let signal = update_signal(&mut signal_avg, di);
859            demand_index_out[i] = di.unwrap_or(f64::NAN);
860            signal_out[i] = signal.unwrap_or(f64::NAN);
861            has_prev_p = false;
862            continue;
863        }
864
865        if !has_h0_l0 {
866            h0 = h;
867            l0 = l;
868            has_h0_l0 = true;
869        }
870
871        let volume_avg_now = volume_avg.update(Some(v));
872        let p = h + l + 2.0 * c;
873
874        if !has_prev_p {
875            let di = demand_index_finalize(None, None, &mut bp_avg, &mut sp_avg);
876            let signal = update_signal(&mut signal_avg, di);
877            demand_index_out[i] = di.unwrap_or(f64::NAN);
878            signal_out[i] = signal.unwrap_or(f64::NAN);
879            prev_p = p;
880            has_prev_p = true;
881            continue;
882        }
883
884        let bp_sp = compute_bp_sp(volume_avg_now, v, p, prev_p, h0, l0);
885        let di = demand_index_finalize(
886            bp_sp.map(|x| x.0),
887            bp_sp.map(|x| x.1),
888            &mut bp_avg,
889            &mut sp_avg,
890        );
891        let signal = update_signal(&mut signal_avg, di);
892        demand_index_out[i] = di.unwrap_or(f64::NAN);
893        signal_out[i] = signal.unwrap_or(f64::NAN);
894        prev_p = p;
895        has_prev_p = true;
896    }
897}
898
899#[inline]
900pub fn demand_index_with_kernel(
901    input: &DemandIndexInput,
902    kernel: Kernel,
903) -> Result<DemandIndexOutput, DemandIndexError> {
904    let (high, low, close, volume) = extract_input(input)?;
905    let len = high.len();
906    let len_bs = input.get_len_bs();
907    let len_bs_ma = input.get_len_bs_ma();
908    let len_di_ma = input.get_len_di_ma();
909    validate_lengths(len_bs, len_bs_ma, len_di_ma, len)?;
910    let ma_type = MaType::parse(input.get_ma_type())?;
911
912    if len_bs > len {
913        return Err(DemandIndexError::InvalidLenBs {
914            len_bs,
915            data_len: len,
916        });
917    }
918    if len_bs_ma > len {
919        return Err(DemandIndexError::InvalidLenBsMa {
920            len_bs_ma,
921            data_len: len,
922        });
923    }
924    if len_di_ma > len {
925        return Err(DemandIndexError::InvalidLenDiMa {
926            len_di_ma,
927            data_len: len,
928        });
929    }
930
931    let first =
932        first_valid_ohlcv(high, low, close, volume).ok_or(DemandIndexError::AllValuesNaN)?;
933    let valid = count_valid_ohlcv(high, low, close, volume);
934    let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
935    if valid < needed {
936        return Err(DemandIndexError::NotEnoughValidData { needed, valid });
937    }
938
939    let _ = match kernel {
940        Kernel::Auto => detect_best_kernel(),
941        other => other.to_non_batch(),
942    };
943
944    let mut demand_index_out = alloc_with_nan_prefix(len, first);
945    let mut signal_out = alloc_with_nan_prefix(len, first);
946    demand_index_compute_into(
947        high,
948        low,
949        close,
950        volume,
951        len_bs,
952        len_bs_ma,
953        len_di_ma,
954        ma_type,
955        &mut demand_index_out,
956        &mut signal_out,
957    );
958
959    Ok(DemandIndexOutput {
960        demand_index: demand_index_out,
961        signal: signal_out,
962    })
963}
964
965#[inline]
966pub fn demand_index_into_slices(
967    demand_index_out: &mut [f64],
968    signal_out: &mut [f64],
969    input: &DemandIndexInput,
970    kernel: Kernel,
971) -> Result<(), DemandIndexError> {
972    let (high, low, close, volume) = extract_input(input)?;
973    let len = high.len();
974    if demand_index_out.len() != len || signal_out.len() != len {
975        return Err(DemandIndexError::OutputLengthMismatch {
976            expected: len,
977            demand_index_got: demand_index_out.len(),
978            signal_got: signal_out.len(),
979        });
980    }
981
982    let len_bs = input.get_len_bs();
983    let len_bs_ma = input.get_len_bs_ma();
984    let len_di_ma = input.get_len_di_ma();
985    validate_lengths(len_bs, len_bs_ma, len_di_ma, len)?;
986    let ma_type = MaType::parse(input.get_ma_type())?;
987    let valid = count_valid_ohlcv(high, low, close, volume);
988    let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
989    if valid < needed {
990        return Err(DemandIndexError::NotEnoughValidData { needed, valid });
991    }
992
993    let _ = match kernel {
994        Kernel::Auto => detect_best_kernel(),
995        other => other.to_non_batch(),
996    };
997
998    demand_index_compute_into(
999        high,
1000        low,
1001        close,
1002        volume,
1003        len_bs,
1004        len_bs_ma,
1005        len_di_ma,
1006        ma_type,
1007        demand_index_out,
1008        signal_out,
1009    );
1010    Ok(())
1011}
1012
1013#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1014#[inline]
1015pub fn demand_index_into(
1016    input: &DemandIndexInput,
1017    demand_index_out: &mut [f64],
1018    signal_out: &mut [f64],
1019) -> Result<(), DemandIndexError> {
1020    demand_index_into_slices(demand_index_out, signal_out, input, Kernel::Auto)
1021}
1022
/// Parameter sweep for batch runs.
///
/// Each length axis is a `(start, end, step)` triple; the default uses `(19, 19, 0)` on every
/// axis (NOTE(review): presumably a single value with step 0 — confirm against the batch
/// range expansion) with an EMA.
#[derive(Clone, Debug)]
pub struct DemandIndexBatchRange {
    pub len_bs: (usize, usize, usize),
    pub len_bs_ma: (usize, usize, usize),
    pub len_di_ma: (usize, usize, usize),
    pub ma_type: Option<String>,
}

impl Default for DemandIndexBatchRange {
    fn default() -> Self {
        let single = (19, 19, 0);
        Self {
            len_bs: single,
            len_bs_ma: single,
            len_di_ma: single,
            ma_type: Some(String::from("ema")),
        }
    }
}
1041
1042#[derive(Clone, Debug, Default)]
1043pub struct DemandIndexBatchBuilder {
1044    range: DemandIndexBatchRange,
1045    kernel: Kernel,
1046}
1047
1048impl DemandIndexBatchBuilder {
1049    #[inline]
1050    pub fn new() -> Self {
1051        Self::default()
1052    }
1053
1054    #[inline]
1055    pub fn kernel(mut self, kernel: Kernel) -> Self {
1056        self.kernel = kernel;
1057        self
1058    }
1059
1060    #[inline]
1061    pub fn len_bs_range(mut self, start: usize, end: usize, step: usize) -> Self {
1062        self.range.len_bs = (start, end, step);
1063        self
1064    }
1065
1066    #[inline]
1067    pub fn len_bs_ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1068        self.range.len_bs_ma = (start, end, step);
1069        self
1070    }
1071
1072    #[inline]
1073    pub fn len_di_ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1074        self.range.len_di_ma = (start, end, step);
1075        self
1076    }
1077
1078    #[inline]
1079    pub fn ma_type(mut self, ma_type: &str) -> Result<Self, DemandIndexError> {
1080        self.range.ma_type = Some(MaType::parse(ma_type)?.as_str().to_string());
1081        Ok(self)
1082    }
1083
1084    #[inline]
1085    pub fn apply_slices(
1086        self,
1087        high: &[f64],
1088        low: &[f64],
1089        close: &[f64],
1090        volume: &[f64],
1091    ) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1092        demand_index_batch_with_kernel(high, low, close, volume, &self.range, self.kernel)
1093    }
1094
1095    #[inline]
1096    pub fn apply_candles(
1097        self,
1098        candles: &Candles,
1099    ) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1100        self.apply_slices(&candles.high, &candles.low, &candles.close, &candles.volume)
1101    }
1102}
1103
/// Result of a batch sweep.
///
/// `demand_index` and `signal` are row-major `rows x cols` matrices. Row `r` holds the
/// series computed for `combos[r]`, and `cols` equals the input length.
#[derive(Clone, Debug)]
pub struct DemandIndexBatchOutput {
    /// Demand Index values, `rows * cols` long.
    pub demand_index: Vec<f64>,
    /// Signal-line values, `rows * cols` long.
    pub signal: Vec<f64>,
    /// Parameter combination of each row, in row order.
    pub combos: Vec<DemandIndexParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of bars per row (the input length).
    pub cols: usize,
}
1112
1113impl DemandIndexBatchOutput {
1114    pub fn row_for_params(&self, params: &DemandIndexParams) -> Option<usize> {
1115        self.combos.iter().position(|combo| {
1116            combo.len_bs.unwrap_or(19) == params.len_bs.unwrap_or(19)
1117                && combo.len_bs_ma.unwrap_or(19) == params.len_bs_ma.unwrap_or(19)
1118                && combo.len_di_ma.unwrap_or(19) == params.len_di_ma.unwrap_or(19)
1119                && combo
1120                    .ma_type
1121                    .as_deref()
1122                    .unwrap_or("ema")
1123                    .eq_ignore_ascii_case(params.ma_type.as_deref().unwrap_or("ema"))
1124        })
1125    }
1126}
1127
1128#[inline]
1129fn expand_axis_usize(
1130    (start, end, step): (usize, usize, usize),
1131) -> Result<Vec<usize>, DemandIndexError> {
1132    if step == 0 || start == end {
1133        return Ok(vec![start]);
1134    }
1135
1136    let mut out = Vec::new();
1137    if start < end {
1138        let mut x = start;
1139        while x <= end {
1140            out.push(x);
1141            let next = x.saturating_add(step);
1142            if next == x {
1143                break;
1144            }
1145            x = next;
1146        }
1147    } else {
1148        let mut x = start;
1149        loop {
1150            out.push(x);
1151            if x == end {
1152                break;
1153            }
1154            let next = x.saturating_sub(step);
1155            if next == x || next < end {
1156                break;
1157            }
1158            x = next;
1159        }
1160    }
1161
1162    if out.is_empty() {
1163        return Err(DemandIndexError::InvalidRange {
1164            start: start.to_string(),
1165            end: end.to_string(),
1166            step: step.to_string(),
1167        });
1168    }
1169    Ok(out)
1170}
1171
1172#[inline]
1173fn expand_grid_demand_index(
1174    range: &DemandIndexBatchRange,
1175) -> Result<Vec<DemandIndexParams>, DemandIndexError> {
1176    let len_bs_values = expand_axis_usize(range.len_bs)?;
1177    let len_bs_ma_values = expand_axis_usize(range.len_bs_ma)?;
1178    let len_di_ma_values = expand_axis_usize(range.len_di_ma)?;
1179    let ma_type = MaType::parse(range.ma_type.as_deref().unwrap_or("ema"))?;
1180
1181    let mut out =
1182        Vec::with_capacity(len_bs_values.len() * len_bs_ma_values.len() * len_di_ma_values.len());
1183    for &len_bs in &len_bs_values {
1184        for &len_bs_ma in &len_bs_ma_values {
1185            for &len_di_ma in &len_di_ma_values {
1186                validate_lengths(len_bs, len_bs_ma, len_di_ma, 0)?;
1187                out.push(DemandIndexParams {
1188                    len_bs: Some(len_bs),
1189                    len_bs_ma: Some(len_bs_ma),
1190                    len_di_ma: Some(len_di_ma),
1191                    ma_type: Some(ma_type.as_str().to_string()),
1192                });
1193            }
1194        }
1195    }
1196    Ok(out)
1197}
1198
1199#[inline]
1200pub fn demand_index_batch_with_kernel(
1201    high: &[f64],
1202    low: &[f64],
1203    close: &[f64],
1204    volume: &[f64],
1205    sweep: &DemandIndexBatchRange,
1206    kernel: Kernel,
1207) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1208    let batch_kernel = match kernel {
1209        Kernel::Auto => detect_best_batch_kernel(),
1210        other if other.is_batch() => other,
1211        other => return Err(DemandIndexError::InvalidKernelForBatch(other)),
1212    };
1213    demand_index_batch_par_slice(high, low, close, volume, sweep, batch_kernel.to_non_batch())
1214}
1215
1216#[inline]
1217pub fn demand_index_batch_slice(
1218    high: &[f64],
1219    low: &[f64],
1220    close: &[f64],
1221    volume: &[f64],
1222    sweep: &DemandIndexBatchRange,
1223    kernel: Kernel,
1224) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1225    demand_index_batch_inner(high, low, close, volume, sweep, kernel, false)
1226}
1227
1228#[inline]
1229pub fn demand_index_batch_par_slice(
1230    high: &[f64],
1231    low: &[f64],
1232    close: &[f64],
1233    volume: &[f64],
1234    sweep: &DemandIndexBatchRange,
1235    kernel: Kernel,
1236) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1237    demand_index_batch_inner(high, low, close, volume, sweep, kernel, true)
1238}
1239
1240fn demand_index_batch_inner(
1241    high: &[f64],
1242    low: &[f64],
1243    close: &[f64],
1244    volume: &[f64],
1245    sweep: &DemandIndexBatchRange,
1246    kernel: Kernel,
1247    parallel: bool,
1248) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1249    if high.is_empty() || low.is_empty() || close.is_empty() || volume.is_empty() {
1250        return Err(DemandIndexError::EmptyInputData);
1251    }
1252    if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
1253        return Err(DemandIndexError::InconsistentSliceLengths {
1254            high_len: high.len(),
1255            low_len: low.len(),
1256            close_len: close.len(),
1257            volume_len: volume.len(),
1258        });
1259    }
1260
1261    let combos = expand_grid_demand_index(sweep)?;
1262    let rows = combos.len();
1263    let cols = high.len();
1264    let first =
1265        first_valid_ohlcv(high, low, close, volume).ok_or(DemandIndexError::AllValuesNaN)?;
1266    let valid = count_valid_ohlcv(high, low, close, volume);
1267
1268    let mut di_mu = make_uninit_matrix(rows, cols);
1269    let mut signal_mu = make_uninit_matrix(rows, cols);
1270    let mut di_warms = Vec::with_capacity(rows);
1271    let mut signal_warms = Vec::with_capacity(rows);
1272
1273    for combo in &combos {
1274        let len_bs = combo.len_bs.unwrap_or(19);
1275        let len_bs_ma = combo.len_bs_ma.unwrap_or(19);
1276        let len_di_ma = combo.len_di_ma.unwrap_or(19);
1277        let ma_type = MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))?;
1278        let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
1279        if valid < needed {
1280            return Err(DemandIndexError::NotEnoughValidData { needed, valid });
1281        }
1282        di_warms.push((first + di_warmup(ma_type, len_bs, len_bs_ma)).min(cols));
1283        signal_warms.push((first + signal_warmup(ma_type, len_bs, len_bs_ma, len_di_ma)).min(cols));
1284    }
1285
1286    init_matrix_prefixes(&mut di_mu, cols, &di_warms);
1287    init_matrix_prefixes(&mut signal_mu, cols, &signal_warms);
1288
1289    let mut di_guard = ManuallyDrop::new(di_mu);
1290    let mut signal_guard = ManuallyDrop::new(signal_mu);
1291    let di_out = unsafe {
1292        std::slice::from_raw_parts_mut(di_guard.as_mut_ptr() as *mut f64, di_guard.len())
1293    };
1294    let signal_out = unsafe {
1295        std::slice::from_raw_parts_mut(signal_guard.as_mut_ptr() as *mut f64, signal_guard.len())
1296    };
1297
1298    if parallel {
1299        #[cfg(not(target_arch = "wasm32"))]
1300        {
1301            di_out
1302                .par_chunks_mut(cols)
1303                .zip(signal_out.par_chunks_mut(cols))
1304                .zip(combos.par_iter())
1305                .for_each(|((dst_di, dst_signal), combo)| {
1306                    let ma_type = MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))
1307                        .unwrap_or(MaType::Ema);
1308                    demand_index_compute_into(
1309                        high,
1310                        low,
1311                        close,
1312                        volume,
1313                        combo.len_bs.unwrap_or(19),
1314                        combo.len_bs_ma.unwrap_or(19),
1315                        combo.len_di_ma.unwrap_or(19),
1316                        ma_type,
1317                        dst_di,
1318                        dst_signal,
1319                    );
1320                });
1321        }
1322    } else {
1323        let _ = kernel;
1324        for (row, combo) in combos.iter().enumerate() {
1325            let start = row * cols;
1326            let end = start + cols;
1327            demand_index_compute_into(
1328                high,
1329                low,
1330                close,
1331                volume,
1332                combo.len_bs.unwrap_or(19),
1333                combo.len_bs_ma.unwrap_or(19),
1334                combo.len_di_ma.unwrap_or(19),
1335                MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))?,
1336                &mut di_out[start..end],
1337                &mut signal_out[start..end],
1338            );
1339        }
1340    }
1341
1342    let demand_index = unsafe {
1343        Vec::from_raw_parts(
1344            di_guard.as_mut_ptr() as *mut f64,
1345            di_guard.len(),
1346            di_guard.capacity(),
1347        )
1348    };
1349    let signal = unsafe {
1350        Vec::from_raw_parts(
1351            signal_guard.as_mut_ptr() as *mut f64,
1352            signal_guard.len(),
1353            signal_guard.capacity(),
1354        )
1355    };
1356    core::mem::forget(di_guard);
1357    core::mem::forget(signal_guard);
1358
1359    Ok(DemandIndexBatchOutput {
1360        demand_index,
1361        signal,
1362        combos,
1363        rows,
1364        cols,
1365    })
1366}
1367
1368pub fn demand_index_batch_inner_into(
1369    high: &[f64],
1370    low: &[f64],
1371    close: &[f64],
1372    volume: &[f64],
1373    sweep: &DemandIndexBatchRange,
1374    kernel: Kernel,
1375    demand_index_out: &mut [f64],
1376    signal_out: &mut [f64],
1377) -> Result<Vec<DemandIndexParams>, DemandIndexError> {
1378    let out = demand_index_batch_inner(high, low, close, volume, sweep, kernel, false)?;
1379    let total = out.rows * out.cols;
1380    if demand_index_out.len() != total || signal_out.len() != total {
1381        return Err(DemandIndexError::OutputLengthMismatch {
1382            expected: total,
1383            demand_index_got: demand_index_out.len(),
1384            signal_got: signal_out.len(),
1385        });
1386    }
1387    demand_index_out.copy_from_slice(&out.demand_index);
1388    signal_out.copy_from_slice(&out.signal);
1389    Ok(out.combos)
1390}
1391
/// Python binding: computes the Demand Index and its signal line.
///
/// Returns `(demand_index, signal)` as NumPy arrays. Invalid inputs or parameters raise
/// `ValueError`. The computation runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "demand_index")]
#[pyo3(signature = (high, low, close, volume, len_bs=None, len_bs_ma=None, len_di_ma=None, ma_type=None, kernel=None))]
pub fn demand_index_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    len_bs: Option<usize>,
    len_bs_ma: Option<usize>,
    len_di_ma: Option<usize>,
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (high, low, close, volume) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, false)?;

    let params = DemandIndexParams {
        len_bs,
        len_bs_ma,
        len_di_ma,
        ma_type: ma_type.map(str::to_string),
    };
    let input = DemandIndexInput::from_slices(high, low, close, volume, params);

    let computed = py.allow_threads(|| demand_index_with_kernel(&input, kern));
    let DemandIndexOutput {
        demand_index: di_values,
        signal: signal_values,
    } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((di_values.into_pyarray(py), signal_values.into_pyarray(py)))
}
1432
// Python-facing wrapper that exposes the Rust `DemandIndexStream` as the
// `DemandIndexStream` class.
#[cfg(feature = "python")]
#[pyclass(name = "DemandIndexStream")]
pub struct DemandIndexStreamPy {
    // Underlying incremental calculator; all Python methods delegate to it.
    stream: DemandIndexStream,
}
1438
#[cfg(feature = "python")]
#[pymethods]
impl DemandIndexStreamPy {
    // Builds the stream. Parameters rejected by `DemandIndexStream::try_new` surface as
    // `ValueError`.
    #[new]
    #[pyo3(signature = (len_bs=19, len_bs_ma=19, len_di_ma=19, ma_type="ema"))]
    fn new(len_bs: usize, len_bs_ma: usize, len_di_ma: usize, ma_type: &str) -> PyResult<Self> {
        let params = DemandIndexParams {
            len_bs: Some(len_bs),
            len_bs_ma: Some(len_bs_ma),
            len_di_ma: Some(len_di_ma),
            ma_type: Some(ma_type.to_string()),
        };
        DemandIndexStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Feeds one bar. Returns `(demand_index, signal)` once the stream produces values,
    // otherwise `None`.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
        let inner = &mut self.stream;
        inner.update(high, low, close, volume)
    }

    // Read-only property forwarding `DemandIndexStream::get_warmup_period`.
    #[getter]
    fn warmup_period(&self) -> usize {
        let Self { stream } = self;
        stream.get_warmup_period()
    }
}
1464
/// Python binding for the batch sweep.
///
/// Returns a dict with:
/// - `demand_index` and `signal` as `(rows, cols)` arrays;
/// - the per-row `len_bs`, `len_bs_ma` and `len_di_ma` arrays and the `ma_types` list;
/// - `rows` and `cols`.
///
/// Errors raise `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "demand_index_batch")]
#[pyo3(signature = (high, low, close, volume, len_bs_range, len_bs_ma_range, len_di_ma_range, ma_type=None, kernel=None))]
pub fn demand_index_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    len_bs_range: (usize, usize, usize),
    len_bs_ma_range: (usize, usize, usize),
    len_di_ma_range: (usize, usize, usize),
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let (high, low, close, volume) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, true)?;
    let sweep = DemandIndexBatchRange {
        len_bs: len_bs_range,
        len_bs_ma: len_bs_ma_range,
        len_di_ma: len_di_ma_range,
        ma_type: Some(ma_type.unwrap_or("ema").to_string()),
    };

    let DemandIndexBatchOutput {
        demand_index,
        signal,
        combos,
        rows,
        cols,
    } = py
        .allow_threads(|| demand_index_batch_with_kernel(high, low, close, volume, &sweep, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Per-row parameter columns, aligned with the matrix rows.
    let mut len_bs_col: Vec<u64> = Vec::with_capacity(combos.len());
    let mut len_bs_ma_col: Vec<u64> = Vec::with_capacity(combos.len());
    let mut len_di_ma_col: Vec<u64> = Vec::with_capacity(combos.len());
    let mut ma_type_col: Vec<String> = Vec::with_capacity(combos.len());
    for combo in &combos {
        len_bs_col.push(combo.len_bs.unwrap_or(19) as u64);
        len_bs_ma_col.push(combo.len_bs_ma.unwrap_or(19) as u64);
        len_di_ma_col.push(combo.len_di_ma.unwrap_or(19) as u64);
        ma_type_col.push(combo.ma_type.as_deref().unwrap_or("ema").to_string());
    }

    let dict = PyDict::new(py);
    dict.set_item(
        "demand_index",
        demand_index.into_pyarray(py).reshape((rows, cols))?,
    )?;
    dict.set_item("signal", signal.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("len_bs", len_bs_col.into_pyarray(py))?;
    dict.set_item("len_bs_ma", len_bs_ma_col.into_pyarray(py))?;
    dict.set_item("len_di_ma", len_di_ma_col.into_pyarray(py))?;
    dict.set_item("ma_types", ma_type_col)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1541
/// Registers the Demand Index functions and the streaming class on `module`.
#[cfg(feature = "python")]
pub fn register_demand_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let functions = [
        wrap_pyfunction!(demand_index_py, module)?,
        wrap_pyfunction!(demand_index_batch_py, module)?,
    ];
    for function in functions {
        module.add_function(function)?;
    }
    module.add_class::<DemandIndexStreamPy>()?;
    Ok(())
}
1549
/// WASM binding: computes the Demand Index and returns a JS object with `demand_index` and
/// `signal` properties, each a `Float64Array`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "demand_index_js")]
pub fn demand_index_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    len_bs: usize,
    len_bs_ma: usize,
    len_di_ma: usize,
    ma_type: &str,
) -> Result<js_sys::Object, JsValue> {
    let params = DemandIndexParams {
        len_bs: Some(len_bs),
        len_bs_ma: Some(len_bs_ma),
        len_di_ma: Some(len_di_ma),
        ma_type: Some(ma_type.to_string()),
    };
    let input = DemandIndexInput::from_slices(high, low, close, volume, params);
    let DemandIndexOutput {
        demand_index: di_values,
        signal: signal_values,
    } = demand_index(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    let attach = |key: &str, values: &[f64]| -> Result<bool, JsValue> {
        let array = js_sys::Float64Array::new_with_length(values.len() as u32);
        array.copy_from(values);
        js_sys::Reflect::set(&result, &JsValue::from_str(key), &array)
    };
    attach("demand_index", &di_values)?;
    attach("signal", &signal_values)?;
    Ok(result)
}
1584
/// Allocates an uninitialised buffer with capacity for `len` f64 values and hands ownership
/// of it to the caller. Release it with `demand_index_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_alloc(len: usize) -> *mut f64 {
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1593
/// Releases a buffer previously returned by `demand_index_alloc(len)`.
///
/// `len` must be the same value that was passed to `demand_index_alloc`. A null pointer is
/// ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `(ptr, capacity = len)` describes the allocation made by
    // `Vec::<f64>::with_capacity(len)`. The length is 0 because the contents may never have
    // been initialised, and `Vec::from_raw_parts` requires the first `length` elements to be
    // initialised. `f64` has no destructor, so only the allocation itself is freed.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1603
/// WASM binding: computes the Demand Index over raw JS-owned buffers and writes both output
/// series.
///
/// All six pointers must reference `len` f64 values: inputs readable, outputs writable. The
/// two output buffers must not overlap each other; that is reported as an error.
///
/// An output may overlap an input (e.g. in-place use). In that case the result is computed
/// into temporary NaN-initialised buffers and copied out after the inputs have been read, so
/// no mutable view ever aliases a live input view.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    demand_index_ptr: *mut f64,
    signal_ptr: *mut f64,
    len: usize,
    len_bs: usize,
    len_bs_ma: usize,
    len_di_ma: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || demand_index_ptr.is_null()
        || signal_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte span of one `len`-element buffer, used to detect overlapping address ranges.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let overlaps =
        |a: usize, b: usize| span != 0 && a < b.saturating_add(span) && b < a.saturating_add(span);
    let di_addr = demand_index_ptr as usize;
    let signal_addr = signal_ptr as usize;
    if overlaps(di_addr, signal_addr) {
        return Err(JsValue::from_str(
            "demand_index_ptr and signal_ptr must not overlap",
        ));
    }
    let output_aliases_input = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| overlaps(p as usize, di_addr) || overlaps(p as usize, signal_addr));

    let params = DemandIndexParams {
        len_bs: Some(len_bs),
        len_bs_ma: Some(len_bs_ma),
        len_di_ma: Some(len_di_ma),
        ma_type: Some(ma_type.to_string()),
    };

    // SAFETY: the caller guarantees every pointer is valid for `len` f64 values. Mutable
    // slices are only created when they overlap neither each other (checked above) nor any
    // input. Otherwise the outputs are written through raw pointers after the inputs'
    // last use.
    unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        let input = DemandIndexInput::from_slices(high, low, close, volume, params);

        if output_aliases_input {
            let mut di_tmp = vec![f64::NAN; len];
            let mut signal_tmp = vec![f64::NAN; len];
            demand_index_into_slices(&mut di_tmp, &mut signal_tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The temporaries are fresh heap buffers, so they never overlap the destinations.
            std::ptr::copy_nonoverlapping(di_tmp.as_ptr(), demand_index_ptr, len);
            std::ptr::copy_nonoverlapping(signal_tmp.as_ptr(), signal_ptr, len);
        } else {
            let di_out = std::slice::from_raw_parts_mut(demand_index_ptr, len);
            let signal_out = std::slice::from_raw_parts_mut(signal_ptr, len);
            demand_index_into_slices(di_out, signal_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1653
/// Serde-deserialised configuration accepted by `demand_index_batch_js`.
///
/// Each `*_range` is an inclusive `(start, end, step)` axis. When `ma_type` is absent,
/// `"ema"` is used.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemandIndexBatchConfig {
    pub len_bs_range: (usize, usize, usize),
    pub len_bs_ma_range: (usize, usize, usize),
    pub len_di_ma_range: (usize, usize, usize),
    pub ma_type: Option<String>,
}
1662
/// Serde-serialised result returned by `demand_index_batch_js`.
///
/// `demand_index` and `signal` are flattened row-major `rows x cols` matrices. The
/// `len_bs`, `len_bs_ma`, `len_di_ma` and `ma_types` vectors repeat each row's parameters as
/// flat columns, aligned with `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemandIndexBatchJsOutput {
    pub demand_index: Vec<f64>,
    pub signal: Vec<f64>,
    pub combos: Vec<DemandIndexParams>,
    pub len_bs: Vec<usize>,
    pub len_bs_ma: Vec<usize>,
    pub len_di_ma: Vec<usize>,
    pub ma_types: Vec<String>,
    pub rows: usize,
    pub cols: usize,
}
1676
/// WASM binding for the batch sweep.
///
/// `config` must deserialise into `DemandIndexBatchConfig`. The sweep is computed serially
/// and returned as a serialised `DemandIndexBatchJsOutput`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "demand_index_batch_js")]
pub fn demand_index_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let DemandIndexBatchConfig {
        len_bs_range,
        len_bs_ma_range,
        len_di_ma_range,
        ma_type,
    } = serde_wasm_bindgen::from_value::<DemandIndexBatchConfig>(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let sweep = DemandIndexBatchRange {
        len_bs: len_bs_range,
        len_bs_ma: len_bs_ma_range,
        len_di_ma: len_di_ma_range,
        ma_type: Some(ma_type.unwrap_or_else(|| "ema".to_string())),
    };

    let DemandIndexBatchOutput {
        demand_index,
        signal,
        combos,
        rows,
        cols,
    } = demand_index_batch_inner(
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_kernel(),
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let mut len_bs = Vec::with_capacity(combos.len());
    let mut len_bs_ma = Vec::with_capacity(combos.len());
    let mut len_di_ma = Vec::with_capacity(combos.len());
    let mut ma_types = Vec::with_capacity(combos.len());
    for combo in &combos {
        len_bs.push(combo.len_bs.unwrap_or(19));
        len_bs_ma.push(combo.len_bs_ma.unwrap_or(19));
        len_di_ma.push(combo.len_di_ma.unwrap_or(19));
        ma_types.push(combo.ma_type.as_deref().unwrap_or("ema").to_string());
    }

    let payload = DemandIndexBatchJsOutput {
        demand_index,
        signal,
        combos,
        len_bs,
        len_bs_ma,
        len_di_ma,
        ma_types,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1729
/// WASM binding: runs the batch sweep over raw JS-owned buffers and writes the Demand Index
/// matrix (row-major, `rows * len` values) to `out_ptr`. Returns the number of rows.
///
/// Only the demand-index matrix is written; the signal matrix is discarded. `out_ptr` may
/// overlap an input buffer: the computation finishes on freshly allocated memory before the
/// destination is touched.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    len_bs_start: usize,
    len_bs_end: usize,
    len_bs_step: usize,
    len_bs_ma_start: usize,
    len_bs_ma_end: usize,
    len_bs_ma_step: usize,
    len_di_ma_start: usize,
    len_di_ma_end: usize,
    len_di_ma_step: usize,
    ma_type: &str,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DemandIndexBatchRange {
        len_bs: (len_bs_start, len_bs_end, len_bs_step),
        len_bs_ma: (len_bs_ma_start, len_bs_ma_end, len_bs_ma_step),
        len_di_ma: (len_di_ma_start, len_di_ma_end, len_di_ma_step),
        ma_type: Some(ma_type.to_string()),
    };
    let combos = expand_grid_demand_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the caller guarantees each input pointer is valid for `len` reads. The input
    // slices are confined to this block, so they are dead before `out_ptr` is written.
    let batch = unsafe {
        let high = std::slice::from_raw_parts(high_ptr, len);
        let low = std::slice::from_raw_parts(low_ptr, len);
        let close = std::slice::from_raw_parts(close_ptr, len);
        let volume = std::slice::from_raw_parts(volume_ptr, len);
        demand_index_batch_inner(
            high,
            low,
            close,
            volume,
            &sweep,
            detect_best_kernel(),
            false,
        )
    }
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Same sweep and column count that `total` was derived from.
    assert_eq!(batch.demand_index.len(), total);
    // SAFETY: the caller guarantees `out_ptr` is valid for `total` writes. The source is a
    // freshly allocated Vec, so the two regions cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(batch.demand_index.as_ptr(), out_ptr, total);
    }
    Ok(rows)
}
1791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// Loads the reference OHLCV fixture shared by the data-driven tests.
    fn load_ohlcv() -> Result<(Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>), Box<dyn Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        Ok((candles.high, candles.low, candles.close, candles.volume))
    }

    /// Asserts two series are equal within 1e-10, treating NaN == NaN.
    fn assert_series_eq(left: &[f64], right: &[f64]) {
        assert_eq!(left.len(), right.len());
        for (idx, (&a, &b)) in left.iter().zip(right.iter()).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() <= 1e-10,
                "mismatch at {}: left={}, right={}",
                idx,
                a,
                b
            );
        }
    }

    #[test]
    fn demand_index_output_contract() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams::default(),
        );
        let out = demand_index_with_kernel(&input, Kernel::Scalar)?;
        assert_eq!(out.demand_index.len(), close.len());
        assert_eq!(out.signal.len(), close.len());
        assert!(out.demand_index.iter().any(|v| v.is_finite()));
        assert!(out.signal.iter().any(|v| v.is_finite()));
        for &value in out.demand_index.iter().filter(|v| v.is_finite()).take(64) {
            assert!((-100.0..=100.0).contains(&value));
        }
        Ok(())
    }

    #[test]
    fn demand_index_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams::default(),
        );
        let expected = demand_index(&input)?;
        let mut di = vec![f64::NAN; close.len()];
        let mut signal = vec![f64::NAN; close.len()];
        demand_index_into_slices(&mut di, &mut signal, &input, Kernel::Auto)?;
        assert_series_eq(&di, &expected.demand_index);
        assert_series_eq(&signal, &expected.signal);
        Ok(())
    }

    #[test]
    fn demand_index_kernel_parity() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams::default(),
        );
        let auto = demand_index_with_kernel(&input, Kernel::Auto)?;
        let scalar = demand_index_with_kernel(&input, Kernel::Scalar)?;
        assert_series_eq(&auto.demand_index, &scalar.demand_index);
        assert_series_eq(&auto.signal, &scalar.signal);
        Ok(())
    }

    #[test]
    fn demand_index_invalid_len_bs() {
        let high = [1.0, 2.0, 3.0];
        let low = [0.5, 1.5, 2.5];
        let close = [0.75, 1.75, 2.75];
        let volume = [10.0, 11.0, 12.0];
        let input = DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams {
                len_bs: Some(0),
                len_bs_ma: Some(19),
                len_di_ma: Some(19),
                ma_type: Some("ema".to_string()),
            },
        );
        assert!(matches!(
            demand_index(&input),
            Err(DemandIndexError::InvalidLenBs { .. })
        ));
    }

    #[test]
    fn demand_index_stream_matches_batch() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let params = DemandIndexParams {
            len_bs: Some(10),
            len_bs_ma: Some(10),
            len_di_ma: Some(8),
            ma_type: Some("ema".to_string()),
        };
        let batch = demand_index(&DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            params.clone(),
        ))?;
        let mut stream = DemandIndexStream::try_new(params)?;
        let mut di = Vec::with_capacity(close.len());
        let mut signal = Vec::with_capacity(close.len());
        for i in 0..close.len() {
            if let Some((di_value, signal_value)) =
                stream.update(high[i], low[i], close[i], volume[i])
            {
                di.push(di_value);
                signal.push(signal_value);
            } else {
                di.push(f64::NAN);
                signal.push(f64::NAN);
            }
        }
        assert_series_eq(&di, &batch.demand_index);
        assert_series_eq(&signal, &batch.signal);
        Ok(())
    }

    #[test]
    fn demand_index_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange::default();
        let batch =
            demand_index_batch_with_kernel(&high, &low, &close, &volume, &sweep, Kernel::Auto)?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        let single = demand_index(&DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams::default(),
        ))?;
        assert_series_eq(&batch.demand_index[..close.len()], &single.demand_index);
        assert_series_eq(&batch.signal[..close.len()], &single.signal);
        Ok(())
    }

    #[test]
    fn demand_index_invalid_window_recovers() -> Result<(), Box<dyn Error>> {
        let (mut high, mut low, mut close, mut volume) = load_ohlcv()?;
        high[80] = f64::NAN;
        low[80] = f64::NAN;
        close[80] = f64::NAN;
        volume[80] = f64::NAN;

        let out = demand_index(&DemandIndexInput::from_slices(
            &high,
            &low,
            &close,
            &volume,
            DemandIndexParams::default(),
        ))?;
        assert!(out.demand_index[80].is_nan());
        assert!(out.signal[80].is_nan());
        assert!(out.demand_index.iter().skip(140).any(|v| v.is_finite()));
        assert!(out.signal.iter().skip(160).any(|v| v.is_finite()));
        Ok(())
    }

    /// The serial and parallel batch paths must produce identical matrices.
    #[test]
    fn demand_index_batch_serial_matches_parallel() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange {
            len_bs: (10, 14, 2),
            len_bs_ma: (10, 10, 0),
            len_di_ma: (8, 9, 1),
            ma_type: Some("ema".to_string()),
        };
        let serial =
            demand_index_batch_slice(&high, &low, &close, &volume, &sweep, Kernel::Scalar)?;
        let parallel =
            demand_index_batch_par_slice(&high, &low, &close, &volume, &sweep, Kernel::Scalar)?;
        assert_eq!(serial.rows, 6);
        assert_eq!(parallel.rows, serial.rows);
        assert_eq!(parallel.cols, close.len());
        assert_series_eq(&serial.demand_index, &parallel.demand_index);
        assert_series_eq(&serial.signal, &parallel.signal);
        Ok(())
    }

    /// Every row of a multi-row sweep must equal the single-combination computation.
    #[test]
    fn demand_index_batch_rows_match_single_runs() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange {
            len_bs: (10, 12, 2),
            len_bs_ma: (10, 10, 0),
            len_di_ma: (8, 8, 0),
            ma_type: Some("ema".to_string()),
        };
        let batch =
            demand_index_batch_with_kernel(&high, &low, &close, &volume, &sweep, Kernel::Auto)?;
        assert_eq!(batch.rows, 2);
        for len_bs in [10usize, 12] {
            let params = DemandIndexParams {
                len_bs: Some(len_bs),
                len_bs_ma: Some(10),
                len_di_ma: Some(8),
                ma_type: Some("ema".to_string()),
            };
            let row = batch
                .row_for_params(&params)
                .expect("combination is part of the sweep");
            let single =
                demand_index(&DemandIndexInput::from_slices(&high, &low, &close, &volume, params))?;
            let span = row * batch.cols..(row + 1) * batch.cols;
            assert_series_eq(&batch.demand_index[span.clone()], &single.demand_index);
            assert_series_eq(&batch.signal[span], &single.signal);
        }
        Ok(())
    }

    /// Wrongly sized destination buffers must be rejected, not panic.
    #[test]
    fn demand_index_batch_into_rejects_bad_output_len() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange::default();
        let mut di = vec![0.0; close.len() - 1];
        let mut signal = vec![0.0; close.len()];
        let result = demand_index_batch_inner_into(
            &high,
            &low,
            &close,
            &volume,
            &sweep,
            Kernel::Scalar,
            &mut di,
            &mut signal,
        );
        assert!(matches!(
            result,
            Err(DemandIndexError::OutputLengthMismatch { .. })
        ));
        Ok(())
    }

    /// Pins the inclusive/step semantics of sweep axes, including descending ranges.
    #[test]
    fn demand_index_expand_axis_semantics() -> Result<(), Box<dyn Error>> {
        assert_eq!(expand_axis_usize((5, 5, 3))?, vec![5]);
        assert_eq!(expand_axis_usize((4, 9, 0))?, vec![4]);
        assert_eq!(expand_axis_usize((2, 8, 3))?, vec![2, 5, 8]);
        assert_eq!(expand_axis_usize((10, 5, 3))?, vec![10, 7]);
        Ok(())
    }
}