1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::Candles;
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::mem::ManuallyDrop;
28use thiserror::Error;
29
30#[derive(Debug, Clone)]
31pub enum DemandIndexData<'a> {
32 Candles {
33 candles: &'a Candles,
34 },
35 Slices {
36 high: &'a [f64],
37 low: &'a [f64],
38 close: &'a [f64],
39 volume: &'a [f64],
40 },
41}
42
43#[derive(Debug, Clone)]
44pub struct DemandIndexOutput {
45 pub demand_index: Vec<f64>,
46 pub signal: Vec<f64>,
47}
48
/// Tunable parameters of the Demand Index; `None` means "use the default".
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DemandIndexParams {
    /// Period of the volume average (default 19).
    pub len_bs: Option<usize>,
    /// Period of the buying/selling pressure averages (default 19).
    pub len_bs_ma: Option<usize>,
    /// Period of the SMA signal line (default 19).
    pub len_di_ma: Option<usize>,
    /// Moving-average kind: "ema", "sma", "wma", "rma"/"wilders" (default "ema").
    pub ma_type: Option<String>,
}

impl Default for DemandIndexParams {
    fn default() -> Self {
        let period = Some(19);
        Self {
            len_bs: period,
            len_bs_ma: period,
            len_di_ma: period,
            ma_type: Some(String::from("ema")),
        }
    }
}
71
72#[derive(Debug, Clone)]
73pub struct DemandIndexInput<'a> {
74 pub data: DemandIndexData<'a>,
75 pub params: DemandIndexParams,
76}
77
78impl<'a> DemandIndexInput<'a> {
79 #[inline]
80 pub fn from_candles(candles: &'a Candles, params: DemandIndexParams) -> Self {
81 Self {
82 data: DemandIndexData::Candles { candles },
83 params,
84 }
85 }
86
87 #[inline]
88 pub fn from_slices(
89 high: &'a [f64],
90 low: &'a [f64],
91 close: &'a [f64],
92 volume: &'a [f64],
93 params: DemandIndexParams,
94 ) -> Self {
95 Self {
96 data: DemandIndexData::Slices {
97 high,
98 low,
99 close,
100 volume,
101 },
102 params,
103 }
104 }
105
106 #[inline]
107 pub fn with_default_candles(candles: &'a Candles) -> Self {
108 Self::from_candles(candles, DemandIndexParams::default())
109 }
110
111 #[inline]
112 pub fn get_len_bs(&self) -> usize {
113 self.params.len_bs.unwrap_or(19)
114 }
115
116 #[inline]
117 pub fn get_len_bs_ma(&self) -> usize {
118 self.params.len_bs_ma.unwrap_or(19)
119 }
120
121 #[inline]
122 pub fn get_len_di_ma(&self) -> usize {
123 self.params.len_di_ma.unwrap_or(19)
124 }
125
126 #[inline]
127 pub fn get_ma_type(&self) -> &str {
128 self.params.ma_type.as_deref().unwrap_or("ema")
129 }
130}
131
132#[derive(Copy, Clone, Debug)]
133pub struct DemandIndexBuilder {
134 len_bs: Option<usize>,
135 len_bs_ma: Option<usize>,
136 len_di_ma: Option<usize>,
137 ma_type: Option<MaType>,
138 kernel: Kernel,
139}
140
141impl Default for DemandIndexBuilder {
142 fn default() -> Self {
143 Self {
144 len_bs: None,
145 len_bs_ma: None,
146 len_di_ma: None,
147 ma_type: None,
148 kernel: Kernel::Auto,
149 }
150 }
151}
152
153impl DemandIndexBuilder {
154 #[inline]
155 pub fn new() -> Self {
156 Self::default()
157 }
158
159 #[inline]
160 pub fn len_bs(mut self, len_bs: usize) -> Self {
161 self.len_bs = Some(len_bs);
162 self
163 }
164
165 #[inline]
166 pub fn len_bs_ma(mut self, len_bs_ma: usize) -> Self {
167 self.len_bs_ma = Some(len_bs_ma);
168 self
169 }
170
171 #[inline]
172 pub fn len_di_ma(mut self, len_di_ma: usize) -> Self {
173 self.len_di_ma = Some(len_di_ma);
174 self
175 }
176
177 #[inline]
178 pub fn ma_type(mut self, ma_type: &str) -> Result<Self, DemandIndexError> {
179 self.ma_type = Some(MaType::parse(ma_type)?);
180 Ok(self)
181 }
182
183 #[inline]
184 pub fn kernel(mut self, kernel: Kernel) -> Self {
185 self.kernel = kernel;
186 self
187 }
188
189 #[inline]
190 pub fn apply(self, candles: &Candles) -> Result<DemandIndexOutput, DemandIndexError> {
191 let input = DemandIndexInput::from_candles(
192 candles,
193 DemandIndexParams {
194 len_bs: self.len_bs,
195 len_bs_ma: self.len_bs_ma,
196 len_di_ma: self.len_di_ma,
197 ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
198 },
199 );
200 demand_index_with_kernel(&input, self.kernel)
201 }
202
203 #[inline]
204 pub fn apply_slices(
205 self,
206 high: &[f64],
207 low: &[f64],
208 close: &[f64],
209 volume: &[f64],
210 ) -> Result<DemandIndexOutput, DemandIndexError> {
211 let input = DemandIndexInput::from_slices(
212 high,
213 low,
214 close,
215 volume,
216 DemandIndexParams {
217 len_bs: self.len_bs,
218 len_bs_ma: self.len_bs_ma,
219 len_di_ma: self.len_di_ma,
220 ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
221 },
222 );
223 demand_index_with_kernel(&input, self.kernel)
224 }
225
226 #[inline]
227 pub fn into_stream(self) -> Result<DemandIndexStream, DemandIndexError> {
228 DemandIndexStream::try_new(DemandIndexParams {
229 len_bs: self.len_bs,
230 len_bs_ma: self.len_bs_ma,
231 len_di_ma: self.len_di_ma,
232 ma_type: Some(self.ma_type.unwrap_or(MaType::Ema).as_str().to_string()),
233 })
234 }
235}
236
237#[derive(Debug, Error)]
238pub enum DemandIndexError {
239 #[error("demand_index: Input data slice is empty.")]
240 EmptyInputData,
241 #[error("demand_index: All values are NaN.")]
242 AllValuesNaN,
243 #[error("demand_index: Invalid len_bs: len_bs = {len_bs}, data length = {data_len}")]
244 InvalidLenBs { len_bs: usize, data_len: usize },
245 #[error("demand_index: Invalid len_bs_ma: len_bs_ma = {len_bs_ma}, data length = {data_len}")]
246 InvalidLenBsMa { len_bs_ma: usize, data_len: usize },
247 #[error("demand_index: Invalid len_di_ma: len_di_ma = {len_di_ma}, data length = {data_len}")]
248 InvalidLenDiMa { len_di_ma: usize, data_len: usize },
249 #[error("demand_index: Invalid ma_type: {ma_type}")]
250 InvalidMaType { ma_type: String },
251 #[error(
252 "demand_index: Inconsistent slice lengths: high={high_len}, low={low_len}, close={close_len}, volume={volume_len}"
253 )]
254 InconsistentSliceLengths {
255 high_len: usize,
256 low_len: usize,
257 close_len: usize,
258 volume_len: usize,
259 },
260 #[error("demand_index: Not enough valid data: needed = {needed}, valid = {valid}")]
261 NotEnoughValidData { needed: usize, valid: usize },
262 #[error(
263 "demand_index: Output length mismatch: expected = {expected}, demand_index = {demand_index_got}, signal = {signal_got}"
264 )]
265 OutputLengthMismatch {
266 expected: usize,
267 demand_index_got: usize,
268 signal_got: usize,
269 },
270 #[error("demand_index: Invalid range: start={start}, end={end}, step={step}")]
271 InvalidRange {
272 start: String,
273 end: String,
274 step: String,
275 },
276 #[error("demand_index: Invalid kernel for batch: {0:?}")]
277 InvalidKernelForBatch(Kernel),
278}
279
280#[derive(Copy, Clone, Debug, PartialEq, Eq)]
281enum MaType {
282 Ema,
283 Sma,
284 Wma,
285 Rma,
286}
287
288impl MaType {
289 #[inline]
290 fn parse(value: &str) -> Result<Self, DemandIndexError> {
291 if value.eq_ignore_ascii_case("ema") {
292 return Ok(Self::Ema);
293 }
294 if value.eq_ignore_ascii_case("sma") {
295 return Ok(Self::Sma);
296 }
297 if value.eq_ignore_ascii_case("wma") {
298 return Ok(Self::Wma);
299 }
300 if value.eq_ignore_ascii_case("rma") || value.eq_ignore_ascii_case("wilders") {
301 return Ok(Self::Rma);
302 }
303 Err(DemandIndexError::InvalidMaType {
304 ma_type: value.to_string(),
305 })
306 }
307
308 #[inline]
309 fn as_str(self) -> &'static str {
310 match self {
311 Self::Ema => "ema",
312 Self::Sma => "sma",
313 Self::Wma => "wma",
314 Self::Rma => "rma",
315 }
316 }
317}
318
/// Rolling simple moving average over a ring buffer of `period` samples.
///
/// Non-finite or missing samples occupy a slot but are marked invalid; the
/// average is only reported when every slot in the window is valid.
#[derive(Debug, Clone)]
struct SmaState {
    period: usize,
    /// Ring buffer of samples (0.0 for invalid slots).
    values: Vec<f64>,
    /// 1 where the matching slot holds a finite sample.
    valid: Vec<u8>,
    /// Next slot to overwrite.
    idx: usize,
    /// Samples seen, capped at `period`.
    count: usize,
    valid_count: usize,
    /// Running sum of the valid slots.
    sum: f64,
}

impl SmaState {
    #[inline]
    fn new(period: usize) -> Self {
        Self {
            period,
            values: vec![0.0; period],
            valid: vec![0u8; period],
            idx: 0,
            count: 0,
            valid_count: 0,
            sum: 0.0,
        }
    }

    /// Pushes one sample.
    ///
    /// Returns `None` until `period` samples have been seen, then the mean of
    /// the window, or `Some(NaN)` if any slot in the window is invalid.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        let slot = self.idx;
        let window_full = self.count >= self.period;
        if window_full {
            // Evict the sample that is about to be overwritten.
            if self.valid[slot] != 0 {
                self.valid_count = self.valid_count.saturating_sub(1);
                self.sum -= self.values[slot];
            }
        } else {
            self.count += 1;
        }

        let incoming = value.filter(|v| v.is_finite());
        self.values[slot] = incoming.unwrap_or(0.0);
        self.valid[slot] = u8::from(incoming.is_some());
        if let Some(v) = incoming {
            self.valid_count += 1;
            self.sum += v;
        }

        self.idx = (slot + 1) % self.period;

        if self.count < self.period {
            return None;
        }
        Some(if self.valid_count == self.period {
            self.sum / self.period as f64
        } else {
            f64::NAN
        })
    }
}
383
/// Rolling linearly weighted moving average over `period` samples.
///
/// The oldest sample in the window has weight 1 and the newest has weight
/// `period`. The result is `NaN` unless every sample in the window is finite.
#[derive(Debug, Clone)]
struct WmaState {
    period: usize,
    /// Sum of the weights, `1 + 2 + ... + period`.
    denom: f64,
    /// Ring buffer of samples (0.0 for invalid slots).
    values: Vec<f64>,
    /// 1 where the matching slot holds a finite sample.
    valid: Vec<u8>,
    /// Next slot to overwrite; after an update it points at the oldest sample.
    idx: usize,
    /// Samples seen, capped at `period`.
    count: usize,
    valid_count: usize,
}

impl WmaState {
    #[inline]
    fn new(period: usize) -> Self {
        // Computed in f64: `period * (period + 1)` in usize overflows a 32-bit
        // target (e.g. wasm32) once `period >= 65_536`, which panics in debug
        // builds and silently wraps in release. The f64 form is exact for any
        // period below 2^26.
        let n = period as f64;
        Self {
            period,
            denom: n * (n + 1.0) / 2.0,
            values: vec![0.0; period],
            valid: vec![0u8; period],
            idx: 0,
            count: 0,
            valid_count: 0,
        }
    }

    /// Pushes one sample.
    ///
    /// Returns `None` until `period` samples have been seen, `Some(NaN)` while
    /// the window contains an invalid sample, otherwise the weighted mean.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        let slot = self.idx;
        if self.count < self.period {
            self.count += 1;
        } else if self.valid[slot] != 0 {
            // Evict the sample that is about to be overwritten.
            self.valid_count = self.valid_count.saturating_sub(1);
        }

        match value {
            Some(v) if v.is_finite() => {
                self.values[slot] = v;
                self.valid[slot] = 1;
                self.valid_count += 1;
            }
            _ => {
                self.values[slot] = 0.0;
                self.valid[slot] = 0;
            }
        }

        self.idx = if slot + 1 == self.period { 0 } else { slot + 1 };

        if self.count < self.period {
            return None;
        }
        if self.valid_count != self.period {
            return Some(f64::NAN);
        }

        // `self.idx` is the oldest slot: walk oldest -> newest with weights 1..=period.
        let (newer, older) = self.values.split_at(self.idx);
        let weighted = older
            .iter()
            .chain(newer)
            .zip(1usize..)
            .fold(0.0, |acc, (&x, w)| acc + x * w as f64);
        Some(weighted / self.denom)
    }
}
458
/// Exponential smoothing state shared by the EMA and RMA variants.
///
/// Seeds with the first finite sample. Any missing or non-finite sample
/// resets the state, so smoothing restarts from the next valid sample.
#[derive(Debug, Clone)]
struct ExpState {
    alpha: f64,
    value: f64,
    initialized: bool,
}

impl ExpState {
    #[inline]
    fn with_alpha(alpha: f64) -> Self {
        Self {
            alpha,
            value: f64::NAN,
            initialized: false,
        }
    }

    /// EMA smoothing factor `2 / (period + 1)`.
    #[inline]
    fn ema(period: usize) -> Self {
        Self::with_alpha(2.0 / (period as f64 + 1.0))
    }

    /// Wilder smoothing factor `1 / period`.
    #[inline]
    fn rma(period: usize) -> Self {
        Self::with_alpha(1.0 / period as f64)
    }

    /// Pushes one sample; always returns `Some`, holding `NaN` after a reset.
    #[inline]
    fn update(&mut self, value: Option<f64>) -> Option<f64> {
        match value.filter(|v| v.is_finite()) {
            None => {
                self.value = f64::NAN;
                self.initialized = false;
                Some(f64::NAN)
            }
            Some(v) if !self.initialized => {
                self.value = v;
                self.initialized = true;
                Some(v)
            }
            Some(v) => {
                self.value += self.alpha * (v - self.value);
                Some(self.value)
            }
        }
    }
}
505
506#[derive(Debug, Clone)]
507enum MaState {
508 Sma(SmaState),
509 Wma(WmaState),
510 Ema(ExpState),
511 Rma(ExpState),
512}
513
514impl MaState {
515 #[inline]
516 fn new(kind: MaType, period: usize) -> Self {
517 match kind {
518 MaType::Sma => Self::Sma(SmaState::new(period)),
519 MaType::Wma => Self::Wma(WmaState::new(period)),
520 MaType::Ema => Self::Ema(ExpState::ema(period)),
521 MaType::Rma => Self::Rma(ExpState::rma(period)),
522 }
523 }
524
525 #[inline]
526 fn update(&mut self, value: Option<f64>) -> Option<f64> {
527 match self {
528 Self::Sma(state) => state.update(value),
529 Self::Wma(state) => state.update(value),
530 Self::Ema(state) => state.update(value),
531 Self::Rma(state) => state.update(value),
532 }
533 }
534}
535
536#[derive(Debug, Clone)]
537pub struct DemandIndexStream {
538 len_bs: usize,
539 len_bs_ma: usize,
540 len_di_ma: usize,
541 ma_type: MaType,
542 h0: f64,
543 l0: f64,
544 has_h0_l0: bool,
545 prev_p: f64,
546 has_prev_p: bool,
547 volume_avg: MaState,
548 bp_avg: MaState,
549 sp_avg: MaState,
550 signal_avg: SmaState,
551}
552
553impl DemandIndexStream {
554 pub fn try_new(params: DemandIndexParams) -> Result<Self, DemandIndexError> {
555 let len_bs = params.len_bs.unwrap_or(19);
556 let len_bs_ma = params.len_bs_ma.unwrap_or(19);
557 let len_di_ma = params.len_di_ma.unwrap_or(19);
558 let ma_type = MaType::parse(params.ma_type.as_deref().unwrap_or("ema"))?;
559 validate_lengths(len_bs, len_bs_ma, len_di_ma, 0)?;
560
561 Ok(Self {
562 len_bs,
563 len_bs_ma,
564 len_di_ma,
565 ma_type,
566 h0: f64::NAN,
567 l0: f64::NAN,
568 has_h0_l0: false,
569 prev_p: f64::NAN,
570 has_prev_p: false,
571 volume_avg: MaState::new(ma_type, len_bs),
572 bp_avg: MaState::new(ma_type, len_bs_ma),
573 sp_avg: MaState::new(ma_type, len_bs_ma),
574 signal_avg: SmaState::new(len_di_ma),
575 })
576 }
577
578 #[inline]
579 pub fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
580 if !valid_ohlcv_bar(high, low, close, volume) {
581 self.volume_avg.update(None);
582 let di = demand_index_finalize(None, None, &mut self.bp_avg, &mut self.sp_avg);
583 let signal = update_signal(&mut self.signal_avg, di);
584 self.has_prev_p = false;
585 return if di.is_none() && signal.is_none() {
586 None
587 } else {
588 Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
589 };
590 }
591
592 if !self.has_h0_l0 {
593 self.h0 = high;
594 self.l0 = low;
595 self.has_h0_l0 = true;
596 }
597
598 let volume_avg = self.volume_avg.update(Some(volume));
599 let p = high + low + 2.0 * close;
600
601 if !self.has_prev_p {
602 let di = demand_index_finalize(None, None, &mut self.bp_avg, &mut self.sp_avg);
603 let signal = update_signal(&mut self.signal_avg, di);
604 self.prev_p = p;
605 self.has_prev_p = true;
606 return if di.is_none() && signal.is_none() {
607 None
608 } else {
609 Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
610 };
611 }
612
613 let bp_sp = compute_bp_sp(volume_avg, volume, p, self.prev_p, self.h0, self.l0);
614 let di = demand_index_finalize(
615 bp_sp.map(|x| x.0),
616 bp_sp.map(|x| x.1),
617 &mut self.bp_avg,
618 &mut self.sp_avg,
619 );
620 let signal = update_signal(&mut self.signal_avg, di);
621
622 self.prev_p = p;
623 self.has_prev_p = true;
624 Some((di.unwrap_or(f64::NAN), signal.unwrap_or(f64::NAN)))
625 }
626
627 #[inline]
628 pub fn get_warmup_period(&self) -> usize {
629 signal_warmup(self.ma_type, self.len_bs, self.len_bs_ma, self.len_di_ma)
630 }
631}
632
633#[inline]
634pub fn demand_index(input: &DemandIndexInput) -> Result<DemandIndexOutput, DemandIndexError> {
635 demand_index_with_kernel(input, Kernel::Auto)
636}
637
638#[inline]
639fn validate_lengths(
640 len_bs: usize,
641 len_bs_ma: usize,
642 len_di_ma: usize,
643 data_len: usize,
644) -> Result<(), DemandIndexError> {
645 if len_bs == 0 {
646 return Err(DemandIndexError::InvalidLenBs { len_bs, data_len });
647 }
648 if len_bs_ma == 0 {
649 return Err(DemandIndexError::InvalidLenBsMa {
650 len_bs_ma,
651 data_len,
652 });
653 }
654 if len_di_ma == 0 {
655 return Err(DemandIndexError::InvalidLenDiMa {
656 len_di_ma,
657 data_len,
658 });
659 }
660 Ok(())
661}
662
/// A bar is usable only when high, low, close and volume are all finite.
#[inline]
fn valid_ohlcv_bar(high: f64, low: f64, close: f64, volume: f64) -> bool {
    [high, low, close, volume].iter().all(|field| field.is_finite())
}
667
668#[inline]
669fn extract_input<'a>(
670 input: &'a DemandIndexInput<'a>,
671) -> Result<(&'a [f64], &'a [f64], &'a [f64], &'a [f64]), DemandIndexError> {
672 let (high, low, close, volume) = match &input.data {
673 DemandIndexData::Candles { candles } => (
674 candles.high.as_slice(),
675 candles.low.as_slice(),
676 candles.close.as_slice(),
677 candles.volume.as_slice(),
678 ),
679 DemandIndexData::Slices {
680 high,
681 low,
682 close,
683 volume,
684 } => (*high, *low, *close, *volume),
685 };
686
687 if high.is_empty() || low.is_empty() || close.is_empty() || volume.is_empty() {
688 return Err(DemandIndexError::EmptyInputData);
689 }
690 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
691 return Err(DemandIndexError::InconsistentSliceLengths {
692 high_len: high.len(),
693 low_len: low.len(),
694 close_len: close.len(),
695 volume_len: volume.len(),
696 });
697 }
698 Ok((high, low, close, volume))
699}
700
701#[inline]
702fn first_valid_ohlcv(high: &[f64], low: &[f64], close: &[f64], volume: &[f64]) -> Option<usize> {
703 let mut i = 0usize;
704 while i < high.len() {
705 if valid_ohlcv_bar(high[i], low[i], close[i], volume[i]) {
706 return Some(i);
707 }
708 i += 1;
709 }
710 None
711}
712
713#[inline]
714fn count_valid_ohlcv(high: &[f64], low: &[f64], close: &[f64], volume: &[f64]) -> usize {
715 let mut count = 0usize;
716 for i in 0..high.len() {
717 if valid_ohlcv_bar(high[i], low[i], close[i], volume[i]) {
718 count += 1;
719 }
720 }
721 count
722}
723
724#[inline]
725fn di_warmup(ma_type: MaType, len_bs: usize, len_bs_ma: usize) -> usize {
726 match ma_type {
727 MaType::Ema | MaType::Rma => 1,
728 MaType::Sma | MaType::Wma => len_bs.saturating_sub(1).max(1) + len_bs_ma.saturating_sub(1),
729 }
730}
731
732#[inline]
733fn signal_warmup(ma_type: MaType, len_bs: usize, len_bs_ma: usize, len_di_ma: usize) -> usize {
734 di_warmup(ma_type, len_bs, len_bs_ma) + len_di_ma.saturating_sub(1)
735}
736
737#[inline]
738fn needed_valid_bars(ma_type: MaType, len_bs: usize, len_bs_ma: usize, len_di_ma: usize) -> usize {
739 signal_warmup(ma_type, len_bs, len_bs_ma, len_di_ma) + 1
740}
741
/// Absolute reference range `|h0 - l0|`, floored at `1e-12` so it can be used as a divisor.
///
/// A `NaN` span also yields `1e-12` (`f64::max` ignores `NaN`).
#[inline]
fn normalize_h0_l0(h0: f64, l0: f64) -> f64 {
    let span = (h0 - l0).abs();
    span.max(1e-12)
}
746
747#[inline]
748fn compute_bp_sp(
749 volume_avg: Option<f64>,
750 volume: f64,
751 p: f64,
752 prev_p: f64,
753 h0: f64,
754 l0: f64,
755) -> Option<(f64, f64)> {
756 let v_avg = volume_avg?;
757 if !v_avg.is_finite() || !p.is_finite() || !prev_p.is_finite() {
758 return None;
759 }
760
761 let v_ratio = if v_avg == 0.0 { 1.0 } else { volume / v_avg };
762 if !v_ratio.is_finite() {
763 return None;
764 }
765
766 let denom = normalize_h0_l0(h0, l0);
767 let k = 0.375;
768
769 if p < prev_p {
770 if p == 0.0 {
771 return None;
772 }
773 let exponent = (k * (p + prev_p) / denom) * ((prev_p - p) / p);
774 return Some((v_ratio / exponent.exp(), v_ratio));
775 }
776 if p > prev_p {
777 if prev_p == 0.0 {
778 return None;
779 }
780 let exponent = (k * (p + prev_p) / denom) * ((p - prev_p) / prev_p);
781 return Some((v_ratio, v_ratio / exponent.exp()));
782 }
783 Some((v_ratio, v_ratio))
784}
785
/// Turns averaged pressures into the Demand Index, bounded to `[-100, 100]`
/// when both averages are positive.
///
/// Returns `None` if either average is missing or non-finite.
#[inline]
fn finalize_di(bp_ma: Option<f64>, sp_ma: Option<f64>) -> Option<f64> {
    let (bp, sp) = (bp_ma?, sp_ma?);
    if !(bp.is_finite() && sp.is_finite()) {
        return None;
    }

    Some(if bp > sp {
        if bp == 0.0 {
            100.0
        } else {
            100.0 * (1.0 - sp / bp)
        }
    } else if bp < sp {
        if sp == 0.0 {
            -100.0
        } else {
            100.0 * (bp / sp - 1.0)
        }
    } else {
        0.0
    })
}
808
809#[inline]
810fn demand_index_finalize(
811 bp: Option<f64>,
812 sp: Option<f64>,
813 bp_avg: &mut MaState,
814 sp_avg: &mut MaState,
815) -> Option<f64> {
816 let bp_ma = bp_avg.update(bp.filter(|v| v.is_finite()));
817 let sp_ma = sp_avg.update(sp.filter(|v| v.is_finite()));
818 finalize_di(bp_ma, sp_ma)
819}
820
821#[inline]
822fn update_signal(signal_avg: &mut SmaState, di: Option<f64>) -> Option<f64> {
823 signal_avg.update(di.filter(|v| v.is_finite()))
824}
825
826#[inline]
827fn demand_index_compute_into(
828 high: &[f64],
829 low: &[f64],
830 close: &[f64],
831 volume: &[f64],
832 len_bs: usize,
833 len_bs_ma: usize,
834 len_di_ma: usize,
835 ma_type: MaType,
836 demand_index_out: &mut [f64],
837 signal_out: &mut [f64],
838) {
839 let mut volume_avg = MaState::new(ma_type, len_bs);
840 let mut bp_avg = MaState::new(ma_type, len_bs_ma);
841 let mut sp_avg = MaState::new(ma_type, len_bs_ma);
842 let mut signal_avg = SmaState::new(len_di_ma);
843 let mut h0 = f64::NAN;
844 let mut l0 = f64::NAN;
845 let mut has_h0_l0 = false;
846 let mut prev_p = f64::NAN;
847 let mut has_prev_p = false;
848
849 for i in 0..high.len() {
850 let h = high[i];
851 let l = low[i];
852 let c = close[i];
853 let v = volume[i];
854
855 if !valid_ohlcv_bar(h, l, c, v) {
856 volume_avg.update(None);
857 let di = demand_index_finalize(None, None, &mut bp_avg, &mut sp_avg);
858 let signal = update_signal(&mut signal_avg, di);
859 demand_index_out[i] = di.unwrap_or(f64::NAN);
860 signal_out[i] = signal.unwrap_or(f64::NAN);
861 has_prev_p = false;
862 continue;
863 }
864
865 if !has_h0_l0 {
866 h0 = h;
867 l0 = l;
868 has_h0_l0 = true;
869 }
870
871 let volume_avg_now = volume_avg.update(Some(v));
872 let p = h + l + 2.0 * c;
873
874 if !has_prev_p {
875 let di = demand_index_finalize(None, None, &mut bp_avg, &mut sp_avg);
876 let signal = update_signal(&mut signal_avg, di);
877 demand_index_out[i] = di.unwrap_or(f64::NAN);
878 signal_out[i] = signal.unwrap_or(f64::NAN);
879 prev_p = p;
880 has_prev_p = true;
881 continue;
882 }
883
884 let bp_sp = compute_bp_sp(volume_avg_now, v, p, prev_p, h0, l0);
885 let di = demand_index_finalize(
886 bp_sp.map(|x| x.0),
887 bp_sp.map(|x| x.1),
888 &mut bp_avg,
889 &mut sp_avg,
890 );
891 let signal = update_signal(&mut signal_avg, di);
892 demand_index_out[i] = di.unwrap_or(f64::NAN);
893 signal_out[i] = signal.unwrap_or(f64::NAN);
894 prev_p = p;
895 has_prev_p = true;
896 }
897}
898
899#[inline]
900pub fn demand_index_with_kernel(
901 input: &DemandIndexInput,
902 kernel: Kernel,
903) -> Result<DemandIndexOutput, DemandIndexError> {
904 let (high, low, close, volume) = extract_input(input)?;
905 let len = high.len();
906 let len_bs = input.get_len_bs();
907 let len_bs_ma = input.get_len_bs_ma();
908 let len_di_ma = input.get_len_di_ma();
909 validate_lengths(len_bs, len_bs_ma, len_di_ma, len)?;
910 let ma_type = MaType::parse(input.get_ma_type())?;
911
912 if len_bs > len {
913 return Err(DemandIndexError::InvalidLenBs {
914 len_bs,
915 data_len: len,
916 });
917 }
918 if len_bs_ma > len {
919 return Err(DemandIndexError::InvalidLenBsMa {
920 len_bs_ma,
921 data_len: len,
922 });
923 }
924 if len_di_ma > len {
925 return Err(DemandIndexError::InvalidLenDiMa {
926 len_di_ma,
927 data_len: len,
928 });
929 }
930
931 let first =
932 first_valid_ohlcv(high, low, close, volume).ok_or(DemandIndexError::AllValuesNaN)?;
933 let valid = count_valid_ohlcv(high, low, close, volume);
934 let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
935 if valid < needed {
936 return Err(DemandIndexError::NotEnoughValidData { needed, valid });
937 }
938
939 let _ = match kernel {
940 Kernel::Auto => detect_best_kernel(),
941 other => other.to_non_batch(),
942 };
943
944 let mut demand_index_out = alloc_with_nan_prefix(len, first);
945 let mut signal_out = alloc_with_nan_prefix(len, first);
946 demand_index_compute_into(
947 high,
948 low,
949 close,
950 volume,
951 len_bs,
952 len_bs_ma,
953 len_di_ma,
954 ma_type,
955 &mut demand_index_out,
956 &mut signal_out,
957 );
958
959 Ok(DemandIndexOutput {
960 demand_index: demand_index_out,
961 signal: signal_out,
962 })
963}
964
965#[inline]
966pub fn demand_index_into_slices(
967 demand_index_out: &mut [f64],
968 signal_out: &mut [f64],
969 input: &DemandIndexInput,
970 kernel: Kernel,
971) -> Result<(), DemandIndexError> {
972 let (high, low, close, volume) = extract_input(input)?;
973 let len = high.len();
974 if demand_index_out.len() != len || signal_out.len() != len {
975 return Err(DemandIndexError::OutputLengthMismatch {
976 expected: len,
977 demand_index_got: demand_index_out.len(),
978 signal_got: signal_out.len(),
979 });
980 }
981
982 let len_bs = input.get_len_bs();
983 let len_bs_ma = input.get_len_bs_ma();
984 let len_di_ma = input.get_len_di_ma();
985 validate_lengths(len_bs, len_bs_ma, len_di_ma, len)?;
986 let ma_type = MaType::parse(input.get_ma_type())?;
987 let valid = count_valid_ohlcv(high, low, close, volume);
988 let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
989 if valid < needed {
990 return Err(DemandIndexError::NotEnoughValidData { needed, valid });
991 }
992
993 let _ = match kernel {
994 Kernel::Auto => detect_best_kernel(),
995 other => other.to_non_batch(),
996 };
997
998 demand_index_compute_into(
999 high,
1000 low,
1001 close,
1002 volume,
1003 len_bs,
1004 len_bs_ma,
1005 len_di_ma,
1006 ma_type,
1007 demand_index_out,
1008 signal_out,
1009 );
1010 Ok(())
1011}
1012
1013#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1014#[inline]
1015pub fn demand_index_into(
1016 input: &DemandIndexInput,
1017 demand_index_out: &mut [f64],
1018 signal_out: &mut [f64],
1019) -> Result<(), DemandIndexError> {
1020 demand_index_into_slices(demand_index_out, signal_out, input, Kernel::Auto)
1021}
1022
/// Parameter sweep for batch runs.
///
/// Each axis is an inclusive `(start, end, step)`. A `step` of 0 or
/// `start == end` means a single value, and `start > end` sweeps downward.
#[derive(Clone, Debug)]
pub struct DemandIndexBatchRange {
    pub len_bs: (usize, usize, usize),
    pub len_bs_ma: (usize, usize, usize),
    pub len_di_ma: (usize, usize, usize),
    /// One MA type shared by every combination (default `"ema"`).
    pub ma_type: Option<String>,
}

impl Default for DemandIndexBatchRange {
    fn default() -> Self {
        let single = (19, 19, 0);
        Self {
            len_bs: single,
            len_bs_ma: single,
            len_di_ma: single,
            ma_type: Some(String::from("ema")),
        }
    }
}
1041
1042#[derive(Clone, Debug, Default)]
1043pub struct DemandIndexBatchBuilder {
1044 range: DemandIndexBatchRange,
1045 kernel: Kernel,
1046}
1047
1048impl DemandIndexBatchBuilder {
1049 #[inline]
1050 pub fn new() -> Self {
1051 Self::default()
1052 }
1053
1054 #[inline]
1055 pub fn kernel(mut self, kernel: Kernel) -> Self {
1056 self.kernel = kernel;
1057 self
1058 }
1059
1060 #[inline]
1061 pub fn len_bs_range(mut self, start: usize, end: usize, step: usize) -> Self {
1062 self.range.len_bs = (start, end, step);
1063 self
1064 }
1065
1066 #[inline]
1067 pub fn len_bs_ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1068 self.range.len_bs_ma = (start, end, step);
1069 self
1070 }
1071
1072 #[inline]
1073 pub fn len_di_ma_range(mut self, start: usize, end: usize, step: usize) -> Self {
1074 self.range.len_di_ma = (start, end, step);
1075 self
1076 }
1077
1078 #[inline]
1079 pub fn ma_type(mut self, ma_type: &str) -> Result<Self, DemandIndexError> {
1080 self.range.ma_type = Some(MaType::parse(ma_type)?.as_str().to_string());
1081 Ok(self)
1082 }
1083
1084 #[inline]
1085 pub fn apply_slices(
1086 self,
1087 high: &[f64],
1088 low: &[f64],
1089 close: &[f64],
1090 volume: &[f64],
1091 ) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1092 demand_index_batch_with_kernel(high, low, close, volume, &self.range, self.kernel)
1093 }
1094
1095 #[inline]
1096 pub fn apply_candles(
1097 self,
1098 candles: &Candles,
1099 ) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1100 self.apply_slices(&candles.high, &candles.low, &candles.close, &candles.volume)
1101 }
1102}
1103
1104#[derive(Clone, Debug)]
1105pub struct DemandIndexBatchOutput {
1106 pub demand_index: Vec<f64>,
1107 pub signal: Vec<f64>,
1108 pub combos: Vec<DemandIndexParams>,
1109 pub rows: usize,
1110 pub cols: usize,
1111}
1112
1113impl DemandIndexBatchOutput {
1114 pub fn row_for_params(&self, params: &DemandIndexParams) -> Option<usize> {
1115 self.combos.iter().position(|combo| {
1116 combo.len_bs.unwrap_or(19) == params.len_bs.unwrap_or(19)
1117 && combo.len_bs_ma.unwrap_or(19) == params.len_bs_ma.unwrap_or(19)
1118 && combo.len_di_ma.unwrap_or(19) == params.len_di_ma.unwrap_or(19)
1119 && combo
1120 .ma_type
1121 .as_deref()
1122 .unwrap_or("ema")
1123 .eq_ignore_ascii_case(params.ma_type.as_deref().unwrap_or("ema"))
1124 })
1125 }
1126}
1127
1128#[inline]
1129fn expand_axis_usize(
1130 (start, end, step): (usize, usize, usize),
1131) -> Result<Vec<usize>, DemandIndexError> {
1132 if step == 0 || start == end {
1133 return Ok(vec![start]);
1134 }
1135
1136 let mut out = Vec::new();
1137 if start < end {
1138 let mut x = start;
1139 while x <= end {
1140 out.push(x);
1141 let next = x.saturating_add(step);
1142 if next == x {
1143 break;
1144 }
1145 x = next;
1146 }
1147 } else {
1148 let mut x = start;
1149 loop {
1150 out.push(x);
1151 if x == end {
1152 break;
1153 }
1154 let next = x.saturating_sub(step);
1155 if next == x || next < end {
1156 break;
1157 }
1158 x = next;
1159 }
1160 }
1161
1162 if out.is_empty() {
1163 return Err(DemandIndexError::InvalidRange {
1164 start: start.to_string(),
1165 end: end.to_string(),
1166 step: step.to_string(),
1167 });
1168 }
1169 Ok(out)
1170}
1171
1172#[inline]
1173fn expand_grid_demand_index(
1174 range: &DemandIndexBatchRange,
1175) -> Result<Vec<DemandIndexParams>, DemandIndexError> {
1176 let len_bs_values = expand_axis_usize(range.len_bs)?;
1177 let len_bs_ma_values = expand_axis_usize(range.len_bs_ma)?;
1178 let len_di_ma_values = expand_axis_usize(range.len_di_ma)?;
1179 let ma_type = MaType::parse(range.ma_type.as_deref().unwrap_or("ema"))?;
1180
1181 let mut out =
1182 Vec::with_capacity(len_bs_values.len() * len_bs_ma_values.len() * len_di_ma_values.len());
1183 for &len_bs in &len_bs_values {
1184 for &len_bs_ma in &len_bs_ma_values {
1185 for &len_di_ma in &len_di_ma_values {
1186 validate_lengths(len_bs, len_bs_ma, len_di_ma, 0)?;
1187 out.push(DemandIndexParams {
1188 len_bs: Some(len_bs),
1189 len_bs_ma: Some(len_bs_ma),
1190 len_di_ma: Some(len_di_ma),
1191 ma_type: Some(ma_type.as_str().to_string()),
1192 });
1193 }
1194 }
1195 }
1196 Ok(out)
1197}
1198
1199#[inline]
1200pub fn demand_index_batch_with_kernel(
1201 high: &[f64],
1202 low: &[f64],
1203 close: &[f64],
1204 volume: &[f64],
1205 sweep: &DemandIndexBatchRange,
1206 kernel: Kernel,
1207) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1208 let batch_kernel = match kernel {
1209 Kernel::Auto => detect_best_batch_kernel(),
1210 other if other.is_batch() => other,
1211 other => return Err(DemandIndexError::InvalidKernelForBatch(other)),
1212 };
1213 demand_index_batch_par_slice(high, low, close, volume, sweep, batch_kernel.to_non_batch())
1214}
1215
1216#[inline]
1217pub fn demand_index_batch_slice(
1218 high: &[f64],
1219 low: &[f64],
1220 close: &[f64],
1221 volume: &[f64],
1222 sweep: &DemandIndexBatchRange,
1223 kernel: Kernel,
1224) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1225 demand_index_batch_inner(high, low, close, volume, sweep, kernel, false)
1226}
1227
1228#[inline]
1229pub fn demand_index_batch_par_slice(
1230 high: &[f64],
1231 low: &[f64],
1232 close: &[f64],
1233 volume: &[f64],
1234 sweep: &DemandIndexBatchRange,
1235 kernel: Kernel,
1236) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1237 demand_index_batch_inner(high, low, close, volume, sweep, kernel, true)
1238}
1239
1240fn demand_index_batch_inner(
1241 high: &[f64],
1242 low: &[f64],
1243 close: &[f64],
1244 volume: &[f64],
1245 sweep: &DemandIndexBatchRange,
1246 kernel: Kernel,
1247 parallel: bool,
1248) -> Result<DemandIndexBatchOutput, DemandIndexError> {
1249 if high.is_empty() || low.is_empty() || close.is_empty() || volume.is_empty() {
1250 return Err(DemandIndexError::EmptyInputData);
1251 }
1252 if high.len() != low.len() || high.len() != close.len() || high.len() != volume.len() {
1253 return Err(DemandIndexError::InconsistentSliceLengths {
1254 high_len: high.len(),
1255 low_len: low.len(),
1256 close_len: close.len(),
1257 volume_len: volume.len(),
1258 });
1259 }
1260
1261 let combos = expand_grid_demand_index(sweep)?;
1262 let rows = combos.len();
1263 let cols = high.len();
1264 let first =
1265 first_valid_ohlcv(high, low, close, volume).ok_or(DemandIndexError::AllValuesNaN)?;
1266 let valid = count_valid_ohlcv(high, low, close, volume);
1267
1268 let mut di_mu = make_uninit_matrix(rows, cols);
1269 let mut signal_mu = make_uninit_matrix(rows, cols);
1270 let mut di_warms = Vec::with_capacity(rows);
1271 let mut signal_warms = Vec::with_capacity(rows);
1272
1273 for combo in &combos {
1274 let len_bs = combo.len_bs.unwrap_or(19);
1275 let len_bs_ma = combo.len_bs_ma.unwrap_or(19);
1276 let len_di_ma = combo.len_di_ma.unwrap_or(19);
1277 let ma_type = MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))?;
1278 let needed = needed_valid_bars(ma_type, len_bs, len_bs_ma, len_di_ma);
1279 if valid < needed {
1280 return Err(DemandIndexError::NotEnoughValidData { needed, valid });
1281 }
1282 di_warms.push((first + di_warmup(ma_type, len_bs, len_bs_ma)).min(cols));
1283 signal_warms.push((first + signal_warmup(ma_type, len_bs, len_bs_ma, len_di_ma)).min(cols));
1284 }
1285
1286 init_matrix_prefixes(&mut di_mu, cols, &di_warms);
1287 init_matrix_prefixes(&mut signal_mu, cols, &signal_warms);
1288
1289 let mut di_guard = ManuallyDrop::new(di_mu);
1290 let mut signal_guard = ManuallyDrop::new(signal_mu);
1291 let di_out = unsafe {
1292 std::slice::from_raw_parts_mut(di_guard.as_mut_ptr() as *mut f64, di_guard.len())
1293 };
1294 let signal_out = unsafe {
1295 std::slice::from_raw_parts_mut(signal_guard.as_mut_ptr() as *mut f64, signal_guard.len())
1296 };
1297
1298 if parallel {
1299 #[cfg(not(target_arch = "wasm32"))]
1300 {
1301 di_out
1302 .par_chunks_mut(cols)
1303 .zip(signal_out.par_chunks_mut(cols))
1304 .zip(combos.par_iter())
1305 .for_each(|((dst_di, dst_signal), combo)| {
1306 let ma_type = MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))
1307 .unwrap_or(MaType::Ema);
1308 demand_index_compute_into(
1309 high,
1310 low,
1311 close,
1312 volume,
1313 combo.len_bs.unwrap_or(19),
1314 combo.len_bs_ma.unwrap_or(19),
1315 combo.len_di_ma.unwrap_or(19),
1316 ma_type,
1317 dst_di,
1318 dst_signal,
1319 );
1320 });
1321 }
1322 } else {
1323 let _ = kernel;
1324 for (row, combo) in combos.iter().enumerate() {
1325 let start = row * cols;
1326 let end = start + cols;
1327 demand_index_compute_into(
1328 high,
1329 low,
1330 close,
1331 volume,
1332 combo.len_bs.unwrap_or(19),
1333 combo.len_bs_ma.unwrap_or(19),
1334 combo.len_di_ma.unwrap_or(19),
1335 MaType::parse(combo.ma_type.as_deref().unwrap_or("ema"))?,
1336 &mut di_out[start..end],
1337 &mut signal_out[start..end],
1338 );
1339 }
1340 }
1341
1342 let demand_index = unsafe {
1343 Vec::from_raw_parts(
1344 di_guard.as_mut_ptr() as *mut f64,
1345 di_guard.len(),
1346 di_guard.capacity(),
1347 )
1348 };
1349 let signal = unsafe {
1350 Vec::from_raw_parts(
1351 signal_guard.as_mut_ptr() as *mut f64,
1352 signal_guard.len(),
1353 signal_guard.capacity(),
1354 )
1355 };
1356 core::mem::forget(di_guard);
1357 core::mem::forget(signal_guard);
1358
1359 Ok(DemandIndexBatchOutput {
1360 demand_index,
1361 signal,
1362 combos,
1363 rows,
1364 cols,
1365 })
1366}
1367
1368pub fn demand_index_batch_inner_into(
1369 high: &[f64],
1370 low: &[f64],
1371 close: &[f64],
1372 volume: &[f64],
1373 sweep: &DemandIndexBatchRange,
1374 kernel: Kernel,
1375 demand_index_out: &mut [f64],
1376 signal_out: &mut [f64],
1377) -> Result<Vec<DemandIndexParams>, DemandIndexError> {
1378 let out = demand_index_batch_inner(high, low, close, volume, sweep, kernel, false)?;
1379 let total = out.rows * out.cols;
1380 if demand_index_out.len() != total || signal_out.len() != total {
1381 return Err(DemandIndexError::OutputLengthMismatch {
1382 expected: total,
1383 demand_index_got: demand_index_out.len(),
1384 signal_got: signal_out.len(),
1385 });
1386 }
1387 demand_index_out.copy_from_slice(&out.demand_index);
1388 signal_out.copy_from_slice(&out.signal);
1389 Ok(out.combos)
1390}
1391
/// Python binding: computes the demand index and its signal line.
///
/// Returns a `(demand_index, signal)` tuple of 1-D float64 arrays. Unset lengths and
/// `ma_type` fall back to the Rust-side defaults. `kernel` is validated with
/// `validate_kernel`. Computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "demand_index")]
#[pyo3(signature = (high, low, close, volume, len_bs=None, len_bs_ma=None, len_di_ma=None, ma_type=None, kernel=None))]
pub fn demand_index_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    len_bs: Option<usize>,
    len_bs_ma: Option<usize>,
    len_di_ma: Option<usize>,
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (high, low, close, volume) = (
        high.as_slice()?,
        low.as_slice()?,
        close.as_slice()?,
        volume.as_slice()?,
    );
    let kern = validate_kernel(kernel, false)?;
    let params = DemandIndexParams {
        len_bs,
        len_bs_ma,
        len_di_ma,
        ma_type: ma_type.map(String::from),
    };
    let input = DemandIndexInput::from_slices(high, low, close, volume, params);

    // The GIL is released while the indicator runs.
    let computed = py.allow_threads(|| demand_index_with_kernel(&input, kern));
    let out = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    let di_array = out.demand_index.into_pyarray(py);
    let signal_array = out.signal.into_pyarray(py);
    Ok((di_array, signal_array))
}
1432
/// Python-facing wrapper around `DemandIndexStream`, exported to Python under the
/// name `DemandIndexStream` for incremental, one-bar-at-a-time evaluation.
#[cfg(feature = "python")]
#[pyclass(name = "DemandIndexStream")]
pub struct DemandIndexStreamPy {
    // Wrapped Rust stream; the `#[pymethods]` impl forwards `update` and
    // `warmup_period` to it.
    stream: DemandIndexStream,
}
1438
#[cfg(feature = "python")]
#[pymethods]
impl DemandIndexStreamPy {
    /// Builds a stream from explicit lengths and a moving-average type name.
    ///
    /// Raises `ValueError` if `DemandIndexStream::try_new` rejects the parameters.
    #[new]
    #[pyo3(signature = (len_bs=19, len_bs_ma=19, len_di_ma=19, ma_type="ema"))]
    fn new(len_bs: usize, len_bs_ma: usize, len_di_ma: usize, ma_type: &str) -> PyResult<Self> {
        let params = DemandIndexParams {
            len_bs: Some(len_bs),
            len_bs_ma: Some(len_bs_ma),
            len_di_ma: Some(len_di_ma),
            ma_type: Some(String::from(ma_type)),
        };
        DemandIndexStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one bar. Returns `(demand_index, signal)` once the stream yields a value,
    /// otherwise `None`.
    fn update(&mut self, high: f64, low: f64, close: f64, volume: f64) -> Option<(f64, f64)> {
        let inner = &mut self.stream;
        inner.update(high, low, close, volume)
    }

    /// Warmup period reported by the underlying stream.
    #[getter]
    fn warmup_period(&self) -> usize {
        let inner = &self.stream;
        inner.get_warmup_period()
    }
}
1464
/// Python binding for the parameter sweep.
///
/// Each `*_range` is a `(start, end, step)` tuple. Returns a dict with 2-D
/// `demand_index` and `signal` arrays of shape `(rows, cols)`, plus per-row `len_bs`,
/// `len_bs_ma` and `len_di_ma` (uint64 arrays), `ma_types` (list of str), `rows`
/// and `cols`. Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "demand_index_batch")]
#[pyo3(signature = (high, low, close, volume, len_bs_range, len_bs_ma_range, len_di_ma_range, ma_type=None, kernel=None))]
pub fn demand_index_batch_py<'py>(
    py: Python<'py>,
    high: PyReadonlyArray1<'py, f64>,
    low: PyReadonlyArray1<'py, f64>,
    close: PyReadonlyArray1<'py, f64>,
    volume: PyReadonlyArray1<'py, f64>,
    len_bs_range: (usize, usize, usize),
    len_bs_ma_range: (usize, usize, usize),
    len_di_ma_range: (usize, usize, usize),
    ma_type: Option<&str>,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let high = high.as_slice()?;
    let low = low.as_slice()?;
    let close = close.as_slice()?;
    let volume = volume.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = DemandIndexBatchRange {
        len_bs: len_bs_range,
        len_bs_ma: len_bs_ma_range,
        len_di_ma: len_di_ma_range,
        ma_type: Some(String::from(ma_type.unwrap_or("ema"))),
    };

    let computed =
        py.allow_threads(|| demand_index_batch_with_kernel(high, low, close, volume, &sweep, kern));
    let DemandIndexBatchOutput {
        demand_index,
        signal,
        combos,
        rows,
        cols,
    } = computed.map_err(|e| PyValueError::new_err(e.to_string()))?;

    // Per-row parameter columns; missing values are reported with the defaults.
    let len_bs: Vec<u64> = combos.iter().map(|p| p.len_bs.unwrap_or(19) as u64).collect();
    let len_bs_ma: Vec<u64> = combos.iter().map(|p| p.len_bs_ma.unwrap_or(19) as u64).collect();
    let len_di_ma: Vec<u64> = combos.iter().map(|p| p.len_di_ma.unwrap_or(19) as u64).collect();
    let ma_types: Vec<String> = combos
        .iter()
        .map(|p| p.ma_type.as_deref().unwrap_or("ema").to_string())
        .collect();

    let shape = (rows, cols);
    let dict = PyDict::new(py);
    dict.set_item("demand_index", demand_index.into_pyarray(py).reshape(shape)?)?;
    dict.set_item("signal", signal.into_pyarray(py).reshape(shape)?)?;
    dict.set_item("len_bs", len_bs.into_pyarray(py))?;
    dict.set_item("len_bs_ma", len_bs_ma.into_pyarray(py))?;
    dict.set_item("len_di_ma", len_di_ma.into_pyarray(py))?;
    dict.set_item("ma_types", ma_types)?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1541
/// Adds the demand index bindings to `module`: the `demand_index` and
/// `demand_index_batch` functions and the `DemandIndexStream` class.
#[cfg(feature = "python")]
pub fn register_demand_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(demand_index_py, module)?;
    module.add_function(single)?;
    let batch = wrap_pyfunction!(demand_index_batch_py, module)?;
    module.add_function(batch)?;
    module.add_class::<DemandIndexStreamPy>()
}
1549
/// wasm binding: computes the indicator and returns a JS object with
/// `demand_index` and `signal` properties, each a `Float64Array`.
///
/// Computation errors are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "demand_index_js")]
pub fn demand_index_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    len_bs: usize,
    len_bs_ma: usize,
    len_di_ma: usize,
    ma_type: &str,
) -> Result<js_sys::Object, JsValue> {
    let params = DemandIndexParams {
        len_bs: Some(len_bs),
        len_bs_ma: Some(len_bs_ma),
        len_di_ma: Some(len_di_ma),
        ma_type: Some(String::from(ma_type)),
    };
    let input = DemandIndexInput::from_slices(high, low, close, volume, params);
    let DemandIndexOutput {
        demand_index: di_values,
        signal: signal_values,
    } = demand_index(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    for (key, values) in [("demand_index", &di_values), ("signal", &signal_values)] {
        let array = js_sys::Float64Array::new_with_length(values.len() as u32);
        array.copy_from(values);
        js_sys::Reflect::set(&result, &JsValue::from_str(key), &array)?;
    }
    Ok(result)
}
1584
/// Reserves room for `len` `f64` values in wasm memory and hands the pointer to JS.
///
/// The memory is uninitialised. Release it with `demand_index_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_alloc(len: usize) -> *mut f64 {
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1593
/// Releases a buffer previously returned by `demand_index_alloc(len)`.
///
/// `len` must be the same value that was passed to `demand_index_alloc`. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` came from `Vec::<f64>::with_capacity(len)`
    // in `demand_index_alloc`, so `(ptr, capacity = len)` describes that allocation.
    // The length is 0 because the contents may never have been initialised. `f64` has
    // no drop glue, so no element needs to be dropped anyway.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1603
/// Computes the demand index and its signal line from raw wasm memory into
/// caller-owned output buffers.
///
/// Every pointer must address at least `len` `f64` values. An output buffer may be the
/// same buffer as one of the inputs (in-place use, detected by pointer equality). In
/// that case the result is first computed into temporary storage and then copied out,
/// so no mutable view overlaps a live input view. The two output pointers must differ.
///
/// # Errors
/// Returns a JS string for null pointers, identical output pointers, or any error
/// raised by the indicator computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    demand_index_ptr: *mut f64,
    signal_ptr: *mut f64,
    len: usize,
    len_bs: usize,
    len_bs_ma: usize,
    len_di_ma: usize,
    ma_type: &str,
) -> Result<(), JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || demand_index_ptr.is_null()
        || signal_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }
    if demand_index_ptr == signal_ptr {
        // Two `&mut` slices over the same memory would be undefined behaviour.
        return Err(JsValue::from_str(
            "demand_index_ptr and signal_ptr must point to distinct buffers",
        ));
    }

    let params = DemandIndexParams {
        len_bs: Some(len_bs),
        len_bs_ma: Some(len_bs_ma),
        len_di_ma: Some(len_di_ma),
        ma_type: Some(ma_type.to_string()),
    };
    let di_const = demand_index_ptr as *const f64;
    let signal_const = signal_ptr as *const f64;
    let outputs_alias_inputs = [high_ptr, low_ptr, close_ptr, volume_ptr]
        .iter()
        .any(|&p| p == di_const || p == signal_const);

    // SAFETY: the caller guarantees each input pointer addresses `len` readable f64s.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let input = DemandIndexInput::from_slices(high, low, close, volume, params);

    if outputs_alias_inputs {
        let out = demand_index_with_kernel(&input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: the input views are no longer used. Each output pointer addresses `len`
        // writable f64s, and the two outputs were checked to be distinct.
        unsafe {
            std::slice::from_raw_parts_mut(demand_index_ptr, len)
                .copy_from_slice(&out.demand_index);
            std::slice::from_raw_parts_mut(signal_ptr, len).copy_from_slice(&out.signal);
        }
    } else {
        // SAFETY: the outputs differ from every input pointer and from each other, and
        // each addresses `len` writable f64s.
        let (di_out, signal_out) = unsafe {
            (
                std::slice::from_raw_parts_mut(demand_index_ptr, len),
                std::slice::from_raw_parts_mut(signal_ptr, len),
            )
        };
        demand_index_into_slices(di_out, signal_out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1653
/// Sweep configuration accepted by `demand_index_batch_js`, deserialised from a JS value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemandIndexBatchConfig {
    /// `(start, end, step)` range mapped to `DemandIndexBatchRange::len_bs`.
    pub len_bs_range: (usize, usize, usize),
    /// `(start, end, step)` range mapped to `DemandIndexBatchRange::len_bs_ma`.
    pub len_bs_ma_range: (usize, usize, usize),
    /// `(start, end, step)` range mapped to `DemandIndexBatchRange::len_di_ma`.
    pub len_di_ma_range: (usize, usize, usize),
    /// Moving-average type name; `demand_index_batch_js` substitutes `"ema"` when absent.
    pub ma_type: Option<String>,
}
1662
/// Serialisable batch result returned to JS by `demand_index_batch_js`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DemandIndexBatchJsOutput {
    /// Row-major `rows * cols` demand index values (one row per combination).
    pub demand_index: Vec<f64>,
    /// Row-major `rows * cols` signal-line values, laid out like `demand_index`.
    pub signal: Vec<f64>,
    /// Parameter combination for each row, in row order.
    pub combos: Vec<DemandIndexParams>,
    /// Per-row `len_bs` (unset values reported as 19).
    pub len_bs: Vec<usize>,
    /// Per-row `len_bs_ma` (unset values reported as 19).
    pub len_bs_ma: Vec<usize>,
    /// Per-row `len_di_ma` (unset values reported as 19).
    pub len_di_ma: Vec<usize>,
    /// Per-row moving-average type name (unset values reported as `"ema"`).
    pub ma_types: Vec<String>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of bars per row (the input length).
    pub cols: usize,
}
1676
/// wasm binding for the parameter sweep.
///
/// `config` is deserialised into `DemandIndexBatchConfig`. The result is a serialised
/// `DemandIndexBatchJsOutput`. Rows are computed sequentially. Configuration,
/// computation and serialisation failures are returned as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "demand_index_batch_js")]
pub fn demand_index_batch_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    volume: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let DemandIndexBatchConfig {
        len_bs_range,
        len_bs_ma_range,
        len_di_ma_range,
        ma_type,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;
    let sweep = DemandIndexBatchRange {
        len_bs: len_bs_range,
        len_bs_ma: len_bs_ma_range,
        len_di_ma: len_di_ma_range,
        ma_type: Some(ma_type.unwrap_or_else(|| "ema".to_string())),
    };

    let kernel = detect_best_kernel();
    let DemandIndexBatchOutput {
        demand_index,
        signal,
        combos,
        rows,
        cols,
    } = demand_index_batch_inner(high, low, close, volume, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let len_bs: Vec<usize> = combos.iter().map(|p| p.len_bs.unwrap_or(19)).collect();
    let len_bs_ma: Vec<usize> = combos.iter().map(|p| p.len_bs_ma.unwrap_or(19)).collect();
    let len_di_ma: Vec<usize> = combos.iter().map(|p| p.len_di_ma.unwrap_or(19)).collect();
    let ma_types: Vec<String> = combos
        .iter()
        .map(|p| p.ma_type.as_deref().unwrap_or("ema").to_string())
        .collect();

    let payload = DemandIndexBatchJsOutput {
        demand_index,
        signal,
        combos,
        len_bs,
        len_bs_ma,
        len_di_ma,
        ma_types,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1729
/// Runs the parameter sweep over raw wasm memory and writes the demand index matrix
/// (row-major, `rows * len` values) to `out_ptr`. The signal matrix is not written.
///
/// Input pointers must address `len` f64s each, and `out_ptr` must address
/// `rows * len` f64s, where `rows` is the return value (the number of combinations).
/// The output view is only created after the inputs have been fully consumed, so
/// `out_ptr` may safely coincide with an input pointer.
///
/// # Errors
/// Returns a JS string for null pointers, sweep-expansion errors, `rows * len`
/// overflow, or any error raised by the batch computation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn demand_index_batch_into(
    high_ptr: *const f64,
    low_ptr: *const f64,
    close_ptr: *const f64,
    volume_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    len_bs_start: usize,
    len_bs_end: usize,
    len_bs_step: usize,
    len_bs_ma_start: usize,
    len_bs_ma_end: usize,
    len_bs_ma_step: usize,
    len_di_ma_start: usize,
    len_di_ma_end: usize,
    len_di_ma_step: usize,
    ma_type: &str,
) -> Result<usize, JsValue> {
    if high_ptr.is_null()
        || low_ptr.is_null()
        || close_ptr.is_null()
        || volume_ptr.is_null()
        || out_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DemandIndexBatchRange {
        len_bs: (len_bs_start, len_bs_end, len_bs_step),
        len_bs_ma: (len_bs_ma_start, len_bs_ma_end, len_bs_ma_step),
        len_di_ma: (len_di_ma_start, len_di_ma_end, len_di_ma_step),
        ma_type: Some(ma_type.to_string()),
    };
    let combos = expand_grid_demand_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // SAFETY: the caller guarantees each input pointer addresses `len` readable f64s.
    let (high, low, close, volume) = unsafe {
        (
            std::slice::from_raw_parts(high_ptr, len),
            std::slice::from_raw_parts(low_ptr, len),
            std::slice::from_raw_parts(close_ptr, len),
            std::slice::from_raw_parts(volume_ptr, len),
        )
    };
    let batch = demand_index_batch_inner(
        high,
        low,
        close,
        volume,
        &sweep,
        detect_best_kernel(),
        false,
    )
    .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // SAFETY: the input views are no longer used, so this mutable view never coexists
    // with a shared view of the same memory. The caller guarantees `out_ptr` addresses
    // `rows * len` writable f64s.
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    out.copy_from_slice(&batch.demand_index);
    Ok(rows)
}
1791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// High, low, close and volume series, in that order.
    type Ohlcv = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);

    /// Loads the OHLCV columns of the bundled 4h Bitfinex CSV fixture.
    fn load_ohlcv() -> Result<Ohlcv, Box<dyn Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        Ok((candles.high, candles.low, candles.close, candles.volume))
    }

    /// Wraps the four series in a slice-backed input with the given parameters.
    fn slices_input<'a>(
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
        volume: &'a [f64],
        params: DemandIndexParams,
    ) -> DemandIndexInput<'a> {
        DemandIndexInput::from_slices(high, low, close, volume, params)
    }

    /// Asserts element-wise equality. Two values match when both are NaN, when they
    /// are exactly equal (which also covers matching infinities, where `a - b` is NaN),
    /// or when they lie within an absolute tolerance of 1e-10.
    fn assert_series_eq(left: &[f64], right: &[f64]) {
        assert_eq!(left.len(), right.len(), "series lengths differ");
        for (idx, (&a, &b)) in left.iter().zip(right.iter()).enumerate() {
            let same = a == b || (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-10;
            assert!(same, "mismatch at {}: left={}, right={}", idx, a, b);
        }
    }

    #[test]
    fn demand_index_output_contract() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = slices_input(&high, &low, &close, &volume, DemandIndexParams::default());
        let out = demand_index_with_kernel(&input, Kernel::Scalar)?;
        assert_eq!(out.demand_index.len(), close.len());
        assert_eq!(out.signal.len(), close.len());
        assert!(out.demand_index.iter().any(|v| v.is_finite()));
        assert!(out.signal.iter().any(|v| v.is_finite()));
        for value in out.demand_index.iter().copied().filter(|v| v.is_finite()).take(64) {
            assert!((-100.0..=100.0).contains(&value), "out of range: {}", value);
        }
        Ok(())
    }

    #[test]
    fn demand_index_into_matches_api() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = slices_input(&high, &low, &close, &volume, DemandIndexParams::default());
        let expected = demand_index(&input)?;
        let mut di = vec![f64::NAN; close.len()];
        let mut signal = vec![f64::NAN; close.len()];
        demand_index_into_slices(&mut di, &mut signal, &input, Kernel::Auto)?;
        assert_series_eq(&di, &expected.demand_index);
        assert_series_eq(&signal, &expected.signal);
        Ok(())
    }

    #[test]
    fn demand_index_kernel_parity() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let input = slices_input(&high, &low, &close, &volume, DemandIndexParams::default());
        let auto = demand_index_with_kernel(&input, Kernel::Auto)?;
        let scalar = demand_index_with_kernel(&input, Kernel::Scalar)?;
        assert_series_eq(&auto.demand_index, &scalar.demand_index);
        assert_series_eq(&auto.signal, &scalar.signal);
        Ok(())
    }

    #[test]
    fn demand_index_invalid_len_bs() {
        let high = [1.0, 2.0, 3.0];
        let low = [0.5, 1.5, 2.5];
        let close = [0.75, 1.75, 2.75];
        let volume = [10.0, 11.0, 12.0];
        let params = DemandIndexParams {
            len_bs: Some(0),
            len_bs_ma: Some(19),
            len_di_ma: Some(19),
            ma_type: Some("ema".to_string()),
        };
        let input = slices_input(&high, &low, &close, &volume, params);
        assert!(matches!(
            demand_index(&input),
            Err(DemandIndexError::InvalidLenBs { .. })
        ));
    }

    #[test]
    fn demand_index_stream_matches_batch() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let params = DemandIndexParams {
            len_bs: Some(10),
            len_bs_ma: Some(10),
            len_di_ma: Some(8),
            ma_type: Some("ema".to_string()),
        };
        let batch = demand_index(&slices_input(&high, &low, &close, &volume, params.clone()))?;
        let mut stream = DemandIndexStream::try_new(params)?;
        // Bars without a stream value are recorded as NaN to line up with the batch output.
        let (di, signal): (Vec<f64>, Vec<f64>) = (0..close.len())
            .map(|i| {
                stream
                    .update(high[i], low[i], close[i], volume[i])
                    .unwrap_or((f64::NAN, f64::NAN))
            })
            .unzip();
        assert_series_eq(&di, &batch.demand_index);
        assert_series_eq(&signal, &batch.signal);
        Ok(())
    }

    #[test]
    fn demand_index_batch_single_matches_single() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange::default();
        let batch =
            demand_index_batch_with_kernel(&high, &low, &close, &volume, &sweep, Kernel::Auto)?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        let single =
            demand_index(&slices_input(&high, &low, &close, &volume, DemandIndexParams::default()))?;
        assert_series_eq(&batch.demand_index[..close.len()], &single.demand_index);
        assert_series_eq(&batch.signal[..close.len()], &single.signal);
        Ok(())
    }

    #[test]
    fn demand_index_batch_serial_matches_parallel() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange {
            len_bs: (8, 10, 2),
            len_bs_ma: (10, 10, 1),
            len_di_ma: (8, 8, 1),
            ma_type: Some("ema".to_string()),
        };
        let serial =
            demand_index_batch_slice(&high, &low, &close, &volume, &sweep, Kernel::Scalar)?;
        let parallel =
            demand_index_batch_par_slice(&high, &low, &close, &volume, &sweep, Kernel::Scalar)?;
        assert_eq!(serial.rows, serial.combos.len());
        assert_eq!(serial.rows, parallel.rows);
        assert_eq!(serial.cols, close.len());
        assert_series_eq(&serial.demand_index, &parallel.demand_index);
        assert_series_eq(&serial.signal, &parallel.signal);
        // Every row must match the single-series computation for its parameters.
        for (row, combo) in serial.combos.iter().enumerate() {
            let single =
                demand_index(&slices_input(&high, &low, &close, &volume, combo.clone()))?;
            let span = row * serial.cols..(row + 1) * serial.cols;
            assert_series_eq(&serial.demand_index[span.clone()], &single.demand_index);
            assert_series_eq(&serial.signal[span], &single.signal);
        }
        Ok(())
    }

    #[test]
    fn demand_index_batch_inner_into_checks_buffers() -> Result<(), Box<dyn Error>> {
        let (high, low, close, volume) = load_ohlcv()?;
        let sweep = DemandIndexBatchRange::default();

        let mut short_di = vec![0.0; close.len() - 1];
        let mut short_signal = vec![0.0; close.len() - 1];
        let rejected = demand_index_batch_inner_into(
            &high,
            &low,
            &close,
            &volume,
            &sweep,
            Kernel::Scalar,
            &mut short_di,
            &mut short_signal,
        );
        assert!(matches!(
            rejected,
            Err(DemandIndexError::OutputLengthMismatch { .. })
        ));

        let expected =
            demand_index_batch_slice(&high, &low, &close, &volume, &sweep, Kernel::Scalar)?;
        let mut di = vec![f64::NAN; expected.demand_index.len()];
        let mut signal = vec![f64::NAN; expected.signal.len()];
        let combos = demand_index_batch_inner_into(
            &high,
            &low,
            &close,
            &volume,
            &sweep,
            Kernel::Scalar,
            &mut di,
            &mut signal,
        )?;
        assert_eq!(combos.len(), expected.rows);
        assert_series_eq(&di, &expected.demand_index);
        assert_series_eq(&signal, &expected.signal);
        Ok(())
    }

    #[test]
    fn demand_index_invalid_window_recovers() -> Result<(), Box<dyn Error>> {
        let (mut high, mut low, mut close, mut volume) = load_ohlcv()?;
        for series in [&mut high, &mut low, &mut close, &mut volume] {
            series[80] = f64::NAN;
        }

        let out =
            demand_index(&slices_input(&high, &low, &close, &volume, DemandIndexParams::default()))?;
        assert!(out.demand_index[80].is_nan());
        assert!(out.signal[80].is_nan());
        assert!(out.demand_index.iter().skip(140).any(|v| v.is_finite()));
        assert!(out.signal.iter().skip(160).any(|v| v.is_finite()));
        Ok(())
    }
}