1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::mem::ManuallyDrop;
28use thiserror::Error;
29
/// Window length of the short moving average when `short_length` is left unset.
const DEFAULT_SHORT_LENGTH: usize = 3;
/// Window length of the medium moving average (the divisor of both output lines)
/// when `medium_length` is left unset.
const DEFAULT_MEDIUM_LENGTH: usize = 8;
/// Window length of the long moving average when `long_length` is left unset.
const DEFAULT_LONG_LENGTH: usize = 20;
33
34impl<'a> AsRef<[f64]> for DidiIndexInput<'a> {
35 #[inline(always)]
36 fn as_ref(&self) -> &[f64] {
37 match &self.data {
38 DidiIndexData::Slice(slice) => slice,
39 DidiIndexData::Candles { candles, source } => source_type(candles, source),
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
45pub enum DidiIndexData<'a> {
46 Candles {
47 candles: &'a Candles,
48 source: &'a str,
49 },
50 Slice(&'a [f64]),
51}
52
/// Result of a single Didi index run; every vector has the length of the input.
///
/// Positions for which no value could be produced hold `NaN` in all four vectors.
#[derive(Debug, Clone)]
pub struct DidiIndexOutput {
    /// Short moving average divided by the medium moving average.
    pub short: Vec<f64>,
    /// Long moving average divided by the medium moving average.
    pub long: Vec<f64>,
    /// `1.0` where `short` moves above `long` relative to the previous bar, else `0.0`.
    pub crossover: Vec<f64>,
    /// `1.0` where `short` moves below `long` relative to the previous bar, else `0.0`.
    pub crossunder: Vec<f64>,
}
60
/// Window lengths of the three moving averages.
///
/// A `None` field falls back to the corresponding `DEFAULT_*_LENGTH` constant
/// wherever the parameters are consumed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct DidiIndexParams {
    /// Length of the short moving average.
    pub short_length: Option<usize>,
    /// Length of the medium moving average used as the divisor.
    pub medium_length: Option<usize>,
    /// Length of the long moving average.
    pub long_length: Option<usize>,
}
71
72impl Default for DidiIndexParams {
73 fn default() -> Self {
74 Self {
75 short_length: Some(DEFAULT_SHORT_LENGTH),
76 medium_length: Some(DEFAULT_MEDIUM_LENGTH),
77 long_length: Some(DEFAULT_LONG_LENGTH),
78 }
79 }
80}
81
82#[derive(Debug, Clone)]
83pub struct DidiIndexInput<'a> {
84 pub data: DidiIndexData<'a>,
85 pub params: DidiIndexParams,
86}
87
88impl<'a> DidiIndexInput<'a> {
89 #[inline]
90 pub fn from_candles(candles: &'a Candles, source: &'a str, params: DidiIndexParams) -> Self {
91 Self {
92 data: DidiIndexData::Candles { candles, source },
93 params,
94 }
95 }
96
97 #[inline]
98 pub fn from_slice(slice: &'a [f64], params: DidiIndexParams) -> Self {
99 Self {
100 data: DidiIndexData::Slice(slice),
101 params,
102 }
103 }
104
105 #[inline]
106 pub fn with_default_candles(candles: &'a Candles) -> Self {
107 Self::from_candles(candles, "close", DidiIndexParams::default())
108 }
109
110 #[inline]
111 pub fn get_short_length(&self) -> usize {
112 self.params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH)
113 }
114
115 #[inline]
116 pub fn get_medium_length(&self) -> usize {
117 self.params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH)
118 }
119
120 #[inline]
121 pub fn get_long_length(&self) -> usize {
122 self.params.long_length.unwrap_or(DEFAULT_LONG_LENGTH)
123 }
124}
125
126#[derive(Copy, Clone, Debug)]
127pub struct DidiIndexBuilder {
128 short_length: Option<usize>,
129 medium_length: Option<usize>,
130 long_length: Option<usize>,
131 kernel: Kernel,
132}
133
134impl Default for DidiIndexBuilder {
135 fn default() -> Self {
136 Self {
137 short_length: None,
138 medium_length: None,
139 long_length: None,
140 kernel: Kernel::Auto,
141 }
142 }
143}
144
145impl DidiIndexBuilder {
146 #[inline]
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 #[inline]
152 pub fn short_length(mut self, short_length: usize) -> Self {
153 self.short_length = Some(short_length);
154 self
155 }
156
157 #[inline]
158 pub fn medium_length(mut self, medium_length: usize) -> Self {
159 self.medium_length = Some(medium_length);
160 self
161 }
162
163 #[inline]
164 pub fn long_length(mut self, long_length: usize) -> Self {
165 self.long_length = Some(long_length);
166 self
167 }
168
169 #[inline]
170 pub fn kernel(mut self, kernel: Kernel) -> Self {
171 self.kernel = kernel;
172 self
173 }
174
175 #[inline]
176 pub fn apply(self, candles: &Candles, source: &str) -> Result<DidiIndexOutput, DidiIndexError> {
177 let input = DidiIndexInput::from_candles(
178 candles,
179 source,
180 DidiIndexParams {
181 short_length: self.short_length,
182 medium_length: self.medium_length,
183 long_length: self.long_length,
184 },
185 );
186 didi_index_with_kernel(&input, self.kernel)
187 }
188
189 #[inline]
190 pub fn apply_slice(self, data: &[f64]) -> Result<DidiIndexOutput, DidiIndexError> {
191 let input = DidiIndexInput::from_slice(
192 data,
193 DidiIndexParams {
194 short_length: self.short_length,
195 medium_length: self.medium_length,
196 long_length: self.long_length,
197 },
198 );
199 didi_index_with_kernel(&input, self.kernel)
200 }
201
202 #[inline]
203 pub fn into_stream(self) -> Result<DidiIndexStream, DidiIndexError> {
204 DidiIndexStream::try_new(DidiIndexParams {
205 short_length: self.short_length,
206 medium_length: self.medium_length,
207 long_length: self.long_length,
208 })
209 }
210}
211
212#[derive(Debug, Error)]
213pub enum DidiIndexError {
214 #[error("didi_index: Input data slice is empty.")]
215 EmptyInputData,
216 #[error("didi_index: All values are NaN.")]
217 AllValuesNaN,
218 #[error(
219 "didi_index: Invalid short_length: short_length = {short_length}, data length = {data_len}"
220 )]
221 InvalidShortLength {
222 short_length: usize,
223 data_len: usize,
224 },
225 #[error("didi_index: Invalid medium_length: medium_length = {medium_length}, data length = {data_len}")]
226 InvalidMediumLength {
227 medium_length: usize,
228 data_len: usize,
229 },
230 #[error(
231 "didi_index: Invalid long_length: long_length = {long_length}, data length = {data_len}"
232 )]
233 InvalidLongLength { long_length: usize, data_len: usize },
234 #[error("didi_index: Not enough valid data: needed = {needed}, valid = {valid}")]
235 NotEnoughValidData { needed: usize, valid: usize },
236 #[error("didi_index: Output length mismatch: expected = {expected}, short = {short_got}, long = {long_got}, crossover = {crossover_got}, crossunder = {crossunder_got}")]
237 OutputLengthMismatch {
238 expected: usize,
239 short_got: usize,
240 long_got: usize,
241 crossover_got: usize,
242 crossunder_got: usize,
243 },
244 #[error("didi_index: Invalid range: start={start}, end={end}, step={step}")]
245 InvalidRange {
246 start: String,
247 end: String,
248 step: String,
249 },
250 #[error("didi_index: Invalid kernel for batch: {0:?}")]
251 InvalidKernelForBatch(Kernel),
252}
253
/// Fixed-length simple moving average backed by a ring buffer and a running sum.
#[derive(Debug, Clone)]
struct SmaWindow {
    /// Number of samples averaged once the window is full.
    period: usize,
    /// Ring buffer holding the most recent samples.
    values: Vec<f64>,
    /// Slot the next sample is written to.
    idx: usize,
    /// Samples seen since the last reset, capped at `period`.
    count: usize,
    /// Running sum of the samples currently inside the window.
    sum: f64,
}

impl SmaWindow {
    /// Creates an empty window; the buffer always holds at least one slot.
    #[inline]
    fn new(period: usize) -> Self {
        let capacity = period.max(1);
        Self {
            period,
            values: vec![0.0; capacity],
            idx: 0,
            count: 0,
            sum: 0.0,
        }
    }

    /// Forgets all samples. Stale buffer contents are overwritten while the
    /// window refills, so they never reach the sum again.
    #[inline]
    fn reset(&mut self) {
        self.sum = 0.0;
        self.count = 0;
        self.idx = 0;
    }

    /// Pushes `value` and returns the average once `period` samples are present,
    /// `None` while the window is still filling.
    #[inline]
    fn update(&mut self, value: f64) -> Option<f64> {
        let slot = self.idx;
        if self.count < self.period {
            // Still filling: nothing leaves the window yet.
            self.sum += value;
            self.count += 1;
        } else {
            // Full: replace the oldest sample's contribution in one step.
            self.sum += value - self.values[slot];
        }
        self.values[slot] = value;

        let next = slot + 1;
        self.idx = if next == self.period { 0 } else { next };

        (self.count == self.period).then(|| self.sum / self.period as f64)
    }
}
309
310#[derive(Debug, Clone)]
311pub struct DidiIndexStream {
312 short: SmaWindow,
313 medium: SmaWindow,
314 long: SmaWindow,
315 prev_short: f64,
316 prev_long: f64,
317 have_prev: bool,
318 warmup: usize,
319}
320
321impl DidiIndexStream {
322 pub fn try_new(params: DidiIndexParams) -> Result<Self, DidiIndexError> {
323 let short_length = params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
324 if short_length == 0 {
325 return Err(DidiIndexError::InvalidShortLength {
326 short_length,
327 data_len: 0,
328 });
329 }
330 let medium_length = params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
331 if medium_length == 0 {
332 return Err(DidiIndexError::InvalidMediumLength {
333 medium_length,
334 data_len: 0,
335 });
336 }
337 let long_length = params.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
338 if long_length == 0 {
339 return Err(DidiIndexError::InvalidLongLength {
340 long_length,
341 data_len: 0,
342 });
343 }
344 Ok(Self {
345 short: SmaWindow::new(short_length),
346 medium: SmaWindow::new(medium_length),
347 long: SmaWindow::new(long_length),
348 prev_short: f64::NAN,
349 prev_long: f64::NAN,
350 have_prev: false,
351 warmup: short_length.max(medium_length).max(long_length) - 1,
352 })
353 }
354
355 #[inline]
356 fn reset(&mut self) {
357 self.short.reset();
358 self.medium.reset();
359 self.long.reset();
360 self.prev_short = f64::NAN;
361 self.prev_long = f64::NAN;
362 self.have_prev = false;
363 }
364
365 #[inline]
366 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
367 if !valid_value(value) {
368 self.reset();
369 return None;
370 }
371
372 let short_ma = self.short.update(value);
373 let medium_ma = self.medium.update(value);
374 let long_ma = self.long.update(value);
375 if short_ma.is_none() || medium_ma.is_none() || long_ma.is_none() {
376 self.have_prev = false;
377 return None;
378 }
379
380 let medium_ma = medium_ma.unwrap_or(f64::NAN);
381 if !medium_ma.is_finite() || medium_ma == 0.0 {
382 self.have_prev = false;
383 return Some((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
384 }
385
386 let short = short_ma.unwrap_or(f64::NAN) / medium_ma;
387 let long = long_ma.unwrap_or(f64::NAN) / medium_ma;
388 if !short.is_finite() || !long.is_finite() {
389 self.have_prev = false;
390 return Some((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
391 }
392
393 let crossover = if self.have_prev && short > long && self.prev_short <= self.prev_long {
394 1.0
395 } else {
396 0.0
397 };
398 let crossunder = if self.have_prev && short < long && self.prev_short >= self.prev_long {
399 1.0
400 } else {
401 0.0
402 };
403 self.prev_short = short;
404 self.prev_long = long;
405 self.have_prev = true;
406 Some((short, long, crossover, crossunder))
407 }
408
409 #[inline]
410 pub fn get_warmup_period(&self) -> usize {
411 self.warmup
412 }
413}
414
415#[inline]
416pub fn didi_index(input: &DidiIndexInput) -> Result<DidiIndexOutput, DidiIndexError> {
417 didi_index_with_kernel(input, Kernel::Auto)
418}
419
/// A sample is usable when it is neither `NaN` nor infinite.
#[inline(always)]
fn valid_value(value: f64) -> bool {
    !(value.is_nan() || value.is_infinite())
}
424
425#[inline(always)]
426fn first_valid_value(data: &[f64]) -> usize {
427 let mut i = 0usize;
428 while i < data.len() {
429 if valid_value(data[i]) {
430 break;
431 }
432 i += 1;
433 }
434 i.min(data.len())
435}
436
437#[inline(always)]
438fn count_valid_values(data: &[f64]) -> usize {
439 data.iter().filter(|v| valid_value(**v)).count()
440}
441
442#[inline(always)]
443fn didi_index_row_from_slice(
444 data: &[f64],
445 params: &DidiIndexParams,
446 short_out: &mut [f64],
447 long_out: &mut [f64],
448 crossover_out: &mut [f64],
449 crossunder_out: &mut [f64],
450) -> Result<(), DidiIndexError> {
451 let mut stream = DidiIndexStream::try_new(params.clone())?;
452 for i in 0..data.len() {
453 match stream.update(data[i]) {
454 Some((short, long, crossover, crossunder)) => {
455 short_out[i] = short;
456 long_out[i] = long;
457 crossover_out[i] = crossover;
458 crossunder_out[i] = crossunder;
459 }
460 None => {
461 short_out[i] = f64::NAN;
462 long_out[i] = f64::NAN;
463 crossover_out[i] = f64::NAN;
464 crossunder_out[i] = f64::NAN;
465 }
466 }
467 }
468 Ok(())
469}
470
471#[inline(always)]
472fn didi_index_prepare<'a>(
473 input: &'a DidiIndexInput,
474 kernel: Kernel,
475) -> Result<(&'a [f64], usize, DidiIndexParams, Kernel), DidiIndexError> {
476 let data = input.as_ref();
477 if data.is_empty() {
478 return Err(DidiIndexError::EmptyInputData);
479 }
480
481 let first = first_valid_value(data);
482 if first >= data.len() {
483 return Err(DidiIndexError::AllValuesNaN);
484 }
485
486 let params = input.params.clone();
487 let short_length = params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
488 let medium_length = params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
489 let long_length = params.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
490 let len = data.len();
491 if short_length == 0 || short_length > len {
492 return Err(DidiIndexError::InvalidShortLength {
493 short_length,
494 data_len: len,
495 });
496 }
497 if medium_length == 0 || medium_length > len {
498 return Err(DidiIndexError::InvalidMediumLength {
499 medium_length,
500 data_len: len,
501 });
502 }
503 if long_length == 0 || long_length > len {
504 return Err(DidiIndexError::InvalidLongLength {
505 long_length,
506 data_len: len,
507 });
508 }
509
510 let needed = short_length.max(medium_length).max(long_length);
511 let valid = count_valid_values(data);
512 if valid < needed {
513 return Err(DidiIndexError::NotEnoughValidData { needed, valid });
514 }
515
516 let chosen = match kernel {
517 Kernel::Auto => detect_best_kernel(),
518 other => other.to_non_batch(),
519 };
520 Ok((data, first, params, chosen))
521}
522
523#[inline]
524pub fn didi_index_with_kernel(
525 input: &DidiIndexInput,
526 kernel: Kernel,
527) -> Result<DidiIndexOutput, DidiIndexError> {
528 let (data, first, params, _chosen) = didi_index_prepare(input, kernel)?;
529 let mut short = alloc_with_nan_prefix(data.len(), first);
530 let mut long = alloc_with_nan_prefix(data.len(), first);
531 let mut crossover = alloc_with_nan_prefix(data.len(), first);
532 let mut crossunder = alloc_with_nan_prefix(data.len(), first);
533 didi_index_row_from_slice(
534 data,
535 ¶ms,
536 &mut short,
537 &mut long,
538 &mut crossover,
539 &mut crossunder,
540 )?;
541 Ok(DidiIndexOutput {
542 short,
543 long,
544 crossover,
545 crossunder,
546 })
547}
548
549#[inline]
550pub fn didi_index_into_slices(
551 short_out: &mut [f64],
552 long_out: &mut [f64],
553 crossover_out: &mut [f64],
554 crossunder_out: &mut [f64],
555 input: &DidiIndexInput,
556 kernel: Kernel,
557) -> Result<(), DidiIndexError> {
558 let (data, _first, params, _chosen) = didi_index_prepare(input, kernel)?;
559 if short_out.len() != data.len()
560 || long_out.len() != data.len()
561 || crossover_out.len() != data.len()
562 || crossunder_out.len() != data.len()
563 {
564 return Err(DidiIndexError::OutputLengthMismatch {
565 expected: data.len(),
566 short_got: short_out.len(),
567 long_got: long_out.len(),
568 crossover_got: crossover_out.len(),
569 crossunder_got: crossunder_out.len(),
570 });
571 }
572 didi_index_row_from_slice(
573 data,
574 ¶ms,
575 short_out,
576 long_out,
577 crossover_out,
578 crossunder_out,
579 )
580}
581
582#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
583#[inline]
584pub fn didi_index_into(
585 input: &DidiIndexInput,
586 short_out: &mut [f64],
587 long_out: &mut [f64],
588 crossover_out: &mut [f64],
589 crossunder_out: &mut [f64],
590) -> Result<(), DidiIndexError> {
591 didi_index_into_slices(
592 short_out,
593 long_out,
594 crossover_out,
595 crossunder_out,
596 input,
597 Kernel::Auto,
598 )
599}
600
/// Parameter sweep for a batch run; each axis is a `(start, end, step)` triple.
///
/// A step of `0` or `start == end` expands to the single value `start`.
#[derive(Clone, Debug)]
pub struct DidiIndexBatchRange {
    /// Sweep over the short window length.
    pub short_length: (usize, usize, usize),
    /// Sweep over the medium window length.
    pub medium_length: (usize, usize, usize),
    /// Sweep over the long window length.
    pub long_length: (usize, usize, usize),
}
607
608impl Default for DidiIndexBatchRange {
609 fn default() -> Self {
610 Self {
611 short_length: (DEFAULT_SHORT_LENGTH, DEFAULT_SHORT_LENGTH, 0),
612 medium_length: (DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0),
613 long_length: (DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0),
614 }
615 }
616}
617
618#[derive(Clone, Debug)]
619pub struct DidiIndexBatchBuilder {
620 range: DidiIndexBatchRange,
621 kernel: Kernel,
622}
623
624impl Default for DidiIndexBatchBuilder {
625 fn default() -> Self {
626 Self {
627 range: DidiIndexBatchRange::default(),
628 kernel: Kernel::Auto,
629 }
630 }
631}
632
633impl DidiIndexBatchBuilder {
634 #[inline]
635 pub fn new() -> Self {
636 Self::default()
637 }
638
639 #[inline]
640 pub fn short_length_range(mut self, range: (usize, usize, usize)) -> Self {
641 self.range.short_length = range;
642 self
643 }
644
645 #[inline]
646 pub fn medium_length_range(mut self, range: (usize, usize, usize)) -> Self {
647 self.range.medium_length = range;
648 self
649 }
650
651 #[inline]
652 pub fn long_length_range(mut self, range: (usize, usize, usize)) -> Self {
653 self.range.long_length = range;
654 self
655 }
656
657 #[inline]
658 pub fn kernel(mut self, kernel: Kernel) -> Self {
659 self.kernel = kernel;
660 self
661 }
662
663 #[inline]
664 pub fn apply_slice(self, data: &[f64]) -> Result<DidiIndexBatchOutput, DidiIndexError> {
665 didi_index_batch_with_kernel(data, &self.range, self.kernel)
666 }
667
668 #[inline]
669 pub fn apply_candles(
670 self,
671 candles: &Candles,
672 source: &str,
673 ) -> Result<DidiIndexBatchOutput, DidiIndexError> {
674 self.apply_slice(source_type(candles, source))
675 }
676
677 #[inline]
678 pub fn with_default_candles(candles: &Candles) -> Result<DidiIndexBatchOutput, DidiIndexError> {
679 DidiIndexBatchBuilder::new().apply_candles(candles, "close")
680 }
681}
682
683#[derive(Clone, Debug)]
684pub struct DidiIndexBatchOutput {
685 pub short: Vec<f64>,
686 pub long: Vec<f64>,
687 pub crossover: Vec<f64>,
688 pub crossunder: Vec<f64>,
689 pub combos: Vec<DidiIndexParams>,
690 pub rows: usize,
691 pub cols: usize,
692}
693
694impl DidiIndexBatchOutput {
695 pub fn row_for_params(&self, params: &DidiIndexParams) -> Option<usize> {
696 self.combos.iter().position(|combo| combo == params)
697 }
698
699 pub fn short_for(&self, params: &DidiIndexParams) -> Option<&[f64]> {
700 self.row_for_params(params).and_then(|row| {
701 row.checked_mul(self.cols)
702 .and_then(|start| self.short.get(start..start + self.cols))
703 })
704 }
705
706 pub fn long_for(&self, params: &DidiIndexParams) -> Option<&[f64]> {
707 self.row_for_params(params).and_then(|row| {
708 row.checked_mul(self.cols)
709 .and_then(|start| self.long.get(start..start + self.cols))
710 })
711 }
712}
713
714#[inline(always)]
715fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, DidiIndexError> {
716 if step == 0 || start == end {
717 return Ok(vec![start]);
718 }
719 let step = step.max(1);
720 if start < end {
721 let mut out = Vec::new();
722 let mut x = start;
723 while x <= end {
724 out.push(x);
725 match x.checked_add(step) {
726 Some(next) if next != x => x = next,
727 _ => break,
728 }
729 }
730 if out.is_empty() {
731 return Err(DidiIndexError::InvalidRange {
732 start: start.to_string(),
733 end: end.to_string(),
734 step: step.to_string(),
735 });
736 }
737 Ok(out)
738 } else {
739 let mut out = Vec::new();
740 let mut x = start;
741 loop {
742 out.push(x);
743 if x == end {
744 break;
745 }
746 let next = x.saturating_sub(step);
747 if next == x || next < end {
748 break;
749 }
750 x = next;
751 }
752 if out.is_empty() {
753 return Err(DidiIndexError::InvalidRange {
754 start: start.to_string(),
755 end: end.to_string(),
756 step: step.to_string(),
757 });
758 }
759 Ok(out)
760 }
761}
762
763#[inline(always)]
764fn expand_grid_didi_index(
765 range: &DidiIndexBatchRange,
766) -> Result<Vec<DidiIndexParams>, DidiIndexError> {
767 let shorts = axis_usize(range.short_length)?;
768 let mediums = axis_usize(range.medium_length)?;
769 let longs = axis_usize(range.long_length)?;
770
771 if let Some(&short_length) = shorts.iter().find(|&&value| value == 0) {
772 return Err(DidiIndexError::InvalidShortLength {
773 short_length,
774 data_len: 0,
775 });
776 }
777 if let Some(&medium_length) = mediums.iter().find(|&&value| value == 0) {
778 return Err(DidiIndexError::InvalidMediumLength {
779 medium_length,
780 data_len: 0,
781 });
782 }
783 if let Some(&long_length) = longs.iter().find(|&&value| value == 0) {
784 return Err(DidiIndexError::InvalidLongLength {
785 long_length,
786 data_len: 0,
787 });
788 }
789
790 let mut out = Vec::with_capacity(shorts.len() * mediums.len() * longs.len());
791 for &short_length in &shorts {
792 for &medium_length in &mediums {
793 for &long_length in &longs {
794 out.push(DidiIndexParams {
795 short_length: Some(short_length),
796 medium_length: Some(medium_length),
797 long_length: Some(long_length),
798 });
799 }
800 }
801 }
802 Ok(out)
803}
804
805pub fn didi_index_batch_with_kernel(
806 data: &[f64],
807 sweep: &DidiIndexBatchRange,
808 kernel: Kernel,
809) -> Result<DidiIndexBatchOutput, DidiIndexError> {
810 let batch_kernel = match kernel {
811 Kernel::Auto => detect_best_batch_kernel(),
812 other if other.is_batch() => other,
813 other => return Err(DidiIndexError::InvalidKernelForBatch(other)),
814 };
815 didi_index_batch_inner(data, sweep, batch_kernel.to_non_batch(), false)
816}
817
818#[inline]
819pub fn didi_index_batch_slice(
820 data: &[f64],
821 sweep: &DidiIndexBatchRange,
822) -> Result<DidiIndexBatchOutput, DidiIndexError> {
823 didi_index_batch_with_kernel(data, sweep, Kernel::Auto)
824}
825
826#[inline]
827pub fn didi_index_batch_par_slice(
828 data: &[f64],
829 sweep: &DidiIndexBatchRange,
830) -> Result<DidiIndexBatchOutput, DidiIndexError> {
831 #[cfg(not(target_arch = "wasm32"))]
832 {
833 let kernel = detect_best_batch_kernel().to_non_batch();
834 return didi_index_batch_inner(data, sweep, kernel, true);
835 }
836 #[cfg(target_arch = "wasm32")]
837 {
838 didi_index_batch_inner(data, sweep, detect_best_kernel(), false)
839 }
840}
841
842pub fn didi_index_batch_inner(
843 data: &[f64],
844 sweep: &DidiIndexBatchRange,
845 kernel: Kernel,
846 parallel: bool,
847) -> Result<DidiIndexBatchOutput, DidiIndexError> {
848 if data.is_empty() {
849 return Err(DidiIndexError::EmptyInputData);
850 }
851 let first = first_valid_value(data);
852 if first >= data.len() {
853 return Err(DidiIndexError::AllValuesNaN);
854 }
855
856 let combos = expand_grid_didi_index(sweep)?;
857 let rows = combos.len();
858 let cols = data.len();
859 let total = rows
860 .checked_mul(cols)
861 .ok_or_else(|| DidiIndexError::OutputLengthMismatch {
862 expected: usize::MAX,
863 short_got: 0,
864 long_got: 0,
865 crossover_got: 0,
866 crossunder_got: 0,
867 })?;
868
869 let valid = count_valid_values(data);
870 let mut warms = Vec::with_capacity(rows);
871 for combo in &combos {
872 let short_length = combo.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
873 let medium_length = combo.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
874 let long_length = combo.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
875 let needed = short_length.max(medium_length).max(long_length);
876 if short_length > cols {
877 return Err(DidiIndexError::InvalidShortLength {
878 short_length,
879 data_len: cols,
880 });
881 }
882 if medium_length > cols {
883 return Err(DidiIndexError::InvalidMediumLength {
884 medium_length,
885 data_len: cols,
886 });
887 }
888 if long_length > cols {
889 return Err(DidiIndexError::InvalidLongLength {
890 long_length,
891 data_len: cols,
892 });
893 }
894 if valid < needed {
895 return Err(DidiIndexError::NotEnoughValidData { needed, valid });
896 }
897 warms.push((first + needed - 1).min(cols));
898 }
899
900 let mut short_mu = make_uninit_matrix(rows, cols);
901 let mut long_mu = make_uninit_matrix(rows, cols);
902 let mut crossover_mu = make_uninit_matrix(rows, cols);
903 let mut crossunder_mu = make_uninit_matrix(rows, cols);
904 init_matrix_prefixes(&mut short_mu, cols, &warms);
905 init_matrix_prefixes(&mut long_mu, cols, &warms);
906 init_matrix_prefixes(&mut crossover_mu, cols, &warms);
907 init_matrix_prefixes(&mut crossunder_mu, cols, &warms);
908
909 let mut short_guard = ManuallyDrop::new(short_mu);
910 let mut long_guard = ManuallyDrop::new(long_mu);
911 let mut crossover_guard = ManuallyDrop::new(crossover_mu);
912 let mut crossunder_guard = ManuallyDrop::new(crossunder_mu);
913
914 let short_out =
915 unsafe { std::slice::from_raw_parts_mut(short_guard.as_mut_ptr() as *mut f64, total) };
916 let long_out =
917 unsafe { std::slice::from_raw_parts_mut(long_guard.as_mut_ptr() as *mut f64, total) };
918 let crossover_out =
919 unsafe { std::slice::from_raw_parts_mut(crossover_guard.as_mut_ptr() as *mut f64, total) };
920 let crossunder_out =
921 unsafe { std::slice::from_raw_parts_mut(crossunder_guard.as_mut_ptr() as *mut f64, total) };
922
923 if parallel {
924 #[cfg(not(target_arch = "wasm32"))]
925 {
926 short_out
927 .par_chunks_mut(cols)
928 .zip(long_out.par_chunks_mut(cols))
929 .zip(crossover_out.par_chunks_mut(cols))
930 .zip(crossunder_out.par_chunks_mut(cols))
931 .zip(combos.par_iter())
932 .for_each(
933 |((((dst_short, dst_long), dst_crossover), dst_crossunder), combo)| {
934 let _ = didi_index_row_from_slice(
935 data,
936 combo,
937 dst_short,
938 dst_long,
939 dst_crossover,
940 dst_crossunder,
941 );
942 },
943 );
944 }
945 } else {
946 let _ = kernel;
947 for (row, combo) in combos.iter().enumerate() {
948 let start = row * cols;
949 let end = start + cols;
950 didi_index_row_from_slice(
951 data,
952 combo,
953 &mut short_out[start..end],
954 &mut long_out[start..end],
955 &mut crossover_out[start..end],
956 &mut crossunder_out[start..end],
957 )?;
958 }
959 }
960
961 let short = unsafe {
962 Vec::from_raw_parts(
963 short_guard.as_mut_ptr() as *mut f64,
964 short_guard.len(),
965 short_guard.capacity(),
966 )
967 };
968 let long = unsafe {
969 Vec::from_raw_parts(
970 long_guard.as_mut_ptr() as *mut f64,
971 long_guard.len(),
972 long_guard.capacity(),
973 )
974 };
975 let crossover = unsafe {
976 Vec::from_raw_parts(
977 crossover_guard.as_mut_ptr() as *mut f64,
978 crossover_guard.len(),
979 crossover_guard.capacity(),
980 )
981 };
982 let crossunder = unsafe {
983 Vec::from_raw_parts(
984 crossunder_guard.as_mut_ptr() as *mut f64,
985 crossunder_guard.len(),
986 crossunder_guard.capacity(),
987 )
988 };
989 core::mem::forget(short_guard);
990 core::mem::forget(long_guard);
991 core::mem::forget(crossover_guard);
992 core::mem::forget(crossunder_guard);
993
994 Ok(DidiIndexBatchOutput {
995 short,
996 long,
997 crossover,
998 crossunder,
999 combos,
1000 rows,
1001 cols,
1002 })
1003}
1004
1005pub fn didi_index_batch_inner_into(
1006 data: &[f64],
1007 sweep: &DidiIndexBatchRange,
1008 kernel: Kernel,
1009 short_out: &mut [f64],
1010 long_out: &mut [f64],
1011 crossover_out: &mut [f64],
1012 crossunder_out: &mut [f64],
1013) -> Result<Vec<DidiIndexParams>, DidiIndexError> {
1014 let out = didi_index_batch_inner(data, sweep, kernel, false)?;
1015 let total = out.rows * out.cols;
1016 if short_out.len() != total
1017 || long_out.len() != total
1018 || crossover_out.len() != total
1019 || crossunder_out.len() != total
1020 {
1021 return Err(DidiIndexError::OutputLengthMismatch {
1022 expected: total,
1023 short_got: short_out.len(),
1024 long_got: long_out.len(),
1025 crossover_got: crossover_out.len(),
1026 crossunder_got: crossunder_out.len(),
1027 });
1028 }
1029 short_out.copy_from_slice(&out.short);
1030 long_out.copy_from_slice(&out.long);
1031 crossover_out.copy_from_slice(&out.crossover);
1032 crossunder_out.copy_from_slice(&out.crossunder);
1033 Ok(out.combos)
1034}
1035
// Python entry point `didi_index(data, short_length=None, medium_length=None,
// long_length=None, kernel=None)`.
//
// Returns the tuple `(short, long, crossover, crossunder)` as four 1-D arrays.
// The computation itself runs with the GIL released; any `DidiIndexError` is
// surfaced as a Python `ValueError` carrying the error's message.
#[cfg(feature = "python")]
#[pyfunction(name = "didi_index")]
#[pyo3(signature = (data, short_length=None, medium_length=None, long_length=None, kernel=None))]
pub fn didi_index_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_length: Option<usize>,
    medium_length: Option<usize>,
    long_length: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let params = DidiIndexParams {
        short_length,
        medium_length,
        long_length,
    };
    let input = DidiIndexInput::from_slice(series, params);

    let DidiIndexOutput {
        short,
        long,
        crossover,
        crossunder,
    } = py
        .allow_threads(|| didi_index_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((
        short.into_pyarray(py),
        long.into_pyarray(py),
        crossover.into_pyarray(py),
        crossunder.into_pyarray(py),
    ))
}
1072
// Python-visible wrapper (`DidiIndexStream`) around the Rust streaming calculator.
#[cfg(feature = "python")]
#[pyclass(name = "DidiIndexStream")]
pub struct DidiIndexStreamPy {
    // The wrapped stream; all calls are forwarded to it.
    inner: DidiIndexStream,
}
1078
#[cfg(feature = "python")]
#[pymethods]
impl DidiIndexStreamPy {
    // Constructor; each length defaults to its `DEFAULT_*_LENGTH` constant.
    // A rejected length is raised as a Python `ValueError`.
    #[new]
    #[pyo3(signature = (short_length=DEFAULT_SHORT_LENGTH, medium_length=DEFAULT_MEDIUM_LENGTH, long_length=DEFAULT_LONG_LENGTH))]
    fn new(short_length: usize, medium_length: usize, long_length: usize) -> PyResult<Self> {
        let params = DidiIndexParams {
            short_length: Some(short_length),
            medium_length: Some(medium_length),
            long_length: Some(long_length),
        };
        match DidiIndexStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value; returns `(short, long, crossover, crossunder)` or `None`
    // exactly as the wrapped stream does.
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
        let stream = &mut self.inner;
        stream.update(value)
    }

    // Read-only property exposing the wrapped stream's warm-up period.
    #[getter]
    fn warmup_period(&self) -> usize {
        let stream = &self.inner;
        stream.get_warmup_period()
    }
}
1103
// Python entry point `didi_index_batch(data, short_length_range, medium_length_range,
// long_length_range, kernel=None)`.
//
// Returns a dict with the 2-D arrays "short", "long", "crossover" and
// "crossunder" (shape rows x cols), the per-row parameters "short_lengths",
// "medium_lengths" and "long_lengths", plus the scalars "rows" and "cols".
// Errors from the sweep or the computation are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "didi_index_batch")]
#[pyo3(signature = (data, short_length_range=(DEFAULT_SHORT_LENGTH, DEFAULT_SHORT_LENGTH, 0), medium_length_range=(DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0), long_length_range=(DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0), kernel=None))]
pub fn didi_index_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_length_range: (usize, usize, usize),
    medium_length_range: (usize, usize, usize),
    long_length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let to_value_error = |e: DidiIndexError| PyValueError::new_err(e.to_string());

    let data = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = DidiIndexBatchRange {
        short_length: short_length_range,
        medium_length: medium_length_range,
        long_length: long_length_range,
    };

    // The grid is expanded up front only to size the output arrays.
    let rows = expand_grid_didi_index(&sweep)
        .map_err(to_value_error)?
        .len();
    let cols = data.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // Uninitialised flat arrays; the batch routine overwrites every element
    // (or fails, in which case the arrays are never handed to Python).
    let short_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let long_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let crossover_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let crossunder_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let short_slice = unsafe { short_arr.as_slice_mut()? };
    let long_slice = unsafe { long_arr.as_slice_mut()? };
    let crossover_slice = unsafe { crossover_arr.as_slice_mut()? };
    let crossunder_slice = unsafe { crossunder_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            didi_index_batch_inner_into(
                data,
                &sweep,
                batch.to_non_batch(),
                short_slice,
                long_slice,
                crossover_slice,
                crossunder_slice,
            )
        })
        .map_err(to_value_error)?;

    let short_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.short_length.unwrap_or(DEFAULT_SHORT_LENGTH) as u64)
        .collect();
    let medium_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH) as u64)
        .collect();
    let long_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.long_length.unwrap_or(DEFAULT_LONG_LENGTH) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("short", short_arr.reshape((rows, cols))?)?;
    dict.set_item("long", long_arr.reshape((rows, cols))?)?;
    dict.set_item("crossover", crossover_arr.reshape((rows, cols))?)?;
    dict.set_item("crossunder", crossunder_arr.reshape((rows, cols))?)?;
    dict.set_item("short_lengths", short_lengths.into_pyarray(py))?;
    dict.set_item("medium_lengths", medium_lengths.into_pyarray(py))?;
    dict.set_item("long_lengths", long_lengths.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1190
/// Registers the Didi Index Python bindings on `module`.
///
/// Adds the `didi_index_py` and `didi_index_batch_py` functions followed by the
/// `DidiIndexStreamPy` class.
///
/// # Errors
///
/// Returns the first error raised while wrapping or registering an item.
#[cfg(feature = "python")]
pub fn register_didi_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single_fn = wrap_pyfunction!(didi_index_py, module)?;
    module.add_function(single_fn)?;

    let batch_fn = wrap_pyfunction!(didi_index_batch_py, module)?;
    module.add_function(batch_fn)?;

    module.add_class::<DidiIndexStreamPy>()
}
1198
/// Computes the Didi Index for `data` and returns the four series as a JS object.
///
/// The returned object has the keys `short`, `long`, `crossover` and
/// `crossunder`, each holding a `Float64Array` copy of the matching output
/// series.
///
/// # Errors
///
/// Returns the indicator error rendered as a JS string, or whatever error
/// `Reflect::set` reports while populating the result object.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "didi_index_js")]
pub fn didi_index_js(
    data: &[f64],
    short_length: usize,
    medium_length: usize,
    long_length: usize,
) -> Result<JsValue, JsValue> {
    let params = DidiIndexParams {
        short_length: Some(short_length),
        medium_length: Some(medium_length),
        long_length: Some(long_length),
    };
    let out = didi_index(&DidiIndexInput::from_slice(data, params))
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    // Export order matches the property insertion order of the JS object.
    let series: [(&str, &[f64]); 4] = [
        ("short", &out.short),
        ("long", &out.long),
        ("crossover", &out.crossover),
        ("crossunder", &out.crossunder),
    ];
    for (key, values) in series {
        // This item is only compiled for wasm32, where `usize` is 32 bits wide.
        let array = js_sys::Float64Array::new_with_length(values.len() as u32);
        array.copy_from(values);
        js_sys::Reflect::set(&result, &JsValue::from_str(key), &array)?;
    }

    Ok(result.into())
}
1236
/// Allocates room for `len` `f64` values and hands the raw pointer to the caller.
///
/// The backing `Vec` is never dropped here, so the memory stays alive until the
/// pointer is given back to `didi_index_free` together with the same `len`.
/// The contents are uninitialised.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the destructor, leaking the allocation on purpose.
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1245
/// Releases a buffer previously obtained from `didi_index_alloc`.
///
/// A null `ptr` is ignored. Otherwise the caller must pass a pointer returned
/// by `didi_index_alloc(len)` with the very same `len`, and must not use the
/// pointer afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` owns an allocation with capacity
    // `len` created by `Vec::<f64>::with_capacity`. The vector is rebuilt with
    // length 0 because the allocator never initialised any element, so no
    // element may be claimed as initialised; the capacity alone determines
    // what is deallocated.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1255
/// Computes the Didi Index from a raw input buffer into four raw output buffers.
///
/// Every pointer must reference `len` readable (input) or writable (output)
/// `f64` values. Buffers may overlap: whenever any two of the five regions
/// share memory, the result is computed into scratch vectors first and then
/// copied out in the order short, long, crossover, crossunder (so if two
/// output buffers overlap, the later copy wins).
///
/// # Errors
///
/// Returns `"Null pointer provided"` when any pointer is null, otherwise the
/// indicator error rendered as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_into(
    data_ptr: *const f64,
    short_ptr: *mut f64,
    long_ptr: *mut f64,
    crossover_ptr: *mut f64,
    crossunder_ptr: *mut f64,
    len: usize,
    short_length: usize,
    medium_length: usize,
    long_length: usize,
) -> Result<(), JsValue> {
    if data_ptr.is_null()
        || short_ptr.is_null()
        || long_ptr.is_null()
        || crossover_ptr.is_null()
        || crossunder_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Byte ranges of all five buffers. Comparing only the base pointers would
    // miss partially overlapping buffers and outputs that overlap each other,
    // both of which would otherwise produce aliased `&`/`&mut` slices below.
    let byte_len = len.saturating_mul(std::mem::size_of::<f64>());
    let span = |start: usize| (start, start.saturating_add(byte_len));
    let regions = [
        span(data_ptr as usize),
        span(short_ptr as usize),
        span(long_ptr as usize),
        span(crossover_ptr as usize),
        span(crossunder_ptr as usize),
    ];
    let alias = regions
        .iter()
        .enumerate()
        .any(|(i, a)| regions[i + 1..].iter().any(|b| a.0 < b.1 && b.0 < a.1));

    // SAFETY: all pointers are non-null and, per the caller contract, valid for
    // `len` elements. Mutable slices over caller memory are only created
    // simultaneously when the regions were proven disjoint above.
    unsafe {
        let data = std::slice::from_raw_parts(data_ptr, len);
        let input = DidiIndexInput::from_slice(
            data,
            DidiIndexParams {
                short_length: Some(short_length),
                medium_length: Some(medium_length),
                long_length: Some(long_length),
            },
        );
        if alias {
            let mut short_tmp = vec![0.0_f64; len];
            let mut long_tmp = vec![0.0_f64; len];
            let mut crossover_tmp = vec![0.0_f64; len];
            let mut crossunder_tmp = vec![0.0_f64; len];
            didi_index_into_slices(
                &mut short_tmp,
                &mut long_tmp,
                &mut crossover_tmp,
                &mut crossunder_tmp,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input is no longer read past this point, so writing over it
            // is fine. Each destination slice is a temporary that ends before
            // the next one is created.
            std::slice::from_raw_parts_mut(short_ptr, len).copy_from_slice(&short_tmp);
            std::slice::from_raw_parts_mut(long_ptr, len).copy_from_slice(&long_tmp);
            std::slice::from_raw_parts_mut(crossover_ptr, len).copy_from_slice(&crossover_tmp);
            std::slice::from_raw_parts_mut(crossunder_ptr, len).copy_from_slice(&crossunder_tmp);
        } else {
            didi_index_into_slices(
                std::slice::from_raw_parts_mut(short_ptr, len),
                std::slice::from_raw_parts_mut(long_ptr, len),
                std::slice::from_raw_parts_mut(crossover_ptr, len),
                std::slice::from_raw_parts_mut(crossunder_ptr, len),
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1324
/// Parameter sweep description accepted by `didi_index_batch_js`.
///
/// Each range is a `(start, end, step)` triple. Field order is kept as-is
/// because it is part of the serde representation.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DidiIndexBatchConfig {
    /// Sweep over the short length; always required.
    pub short_length_range: (usize, usize, usize),
    /// Sweep over the medium length; when absent, `didi_index_batch_js` pins it
    /// to `DEFAULT_MEDIUM_LENGTH` with a step of 0.
    pub medium_length_range: Option<(usize, usize, usize)>,
    /// Sweep over the long length; when absent, `didi_index_batch_js` pins it
    /// to `DEFAULT_LONG_LENGTH` with a step of 0.
    pub long_length_range: Option<(usize, usize, usize)>,
}
1332
/// Serializable result of `didi_index_batch_js`.
///
/// The four series vectors are taken verbatim from the batch computation;
/// presumably each is a flattened `rows x cols` matrix with one row per
/// parameter combination (TODO confirm layout against the batch kernel).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DidiIndexBatchJsOutput {
    /// Short series values for every combination.
    pub short: Vec<f64>,
    /// Long series values for every combination.
    pub long: Vec<f64>,
    /// Crossover flags for every combination.
    pub crossover: Vec<f64>,
    /// Crossunder flags for every combination.
    pub crossunder: Vec<f64>,
    /// Parameter combinations that were evaluated.
    pub combos: Vec<DidiIndexParams>,
    /// Short length of each combination (default substituted when unset).
    pub short_lengths: Vec<usize>,
    /// Medium length of each combination (default substituted when unset).
    pub medium_lengths: Vec<usize>,
    /// Long length of each combination (default substituted when unset).
    pub long_lengths: Vec<usize>,
    /// Row count reported by the batch computation.
    pub rows: usize,
    /// Column count reported by the batch computation.
    pub cols: usize,
}
1347
/// Runs a Didi Index parameter sweep over `data` and returns the result to JS.
///
/// `config` is deserialized into a `DidiIndexBatchConfig`; omitted medium/long
/// ranges collapse to their default length with a step of 0. The return value
/// is a serialized `DidiIndexBatchJsOutput`.
///
/// # Errors
///
/// Returns `"Invalid config: ..."` when `config` cannot be deserialized, or the
/// batch/serialization error rendered as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "didi_index_batch_js")]
pub fn didi_index_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let config: DidiIndexBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(parsed) => parsed,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };

    // A range whose start equals its end, with step 0, describes one fixed value.
    let fixed = |value: usize| (value, value, 0);
    let sweep = DidiIndexBatchRange {
        short_length: config.short_length_range,
        medium_length: config
            .medium_length_range
            .unwrap_or(fixed(DEFAULT_MEDIUM_LENGTH)),
        long_length: config
            .long_length_range
            .unwrap_or(fixed(DEFAULT_LONG_LENGTH)),
    };

    let kernel = detect_best_batch_kernel().to_non_batch();
    let out = didi_index_batch_inner(data, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Per-combination lengths are derived before `out` is taken apart below.
    let short_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.short_length.unwrap_or(DEFAULT_SHORT_LENGTH))
        .collect();
    let medium_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH))
        .collect();
    let long_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.long_length.unwrap_or(DEFAULT_LONG_LENGTH))
        .collect();

    let payload = DidiIndexBatchJsOutput {
        short: out.short,
        long: out.long,
        crossover: out.crossover,
        crossunder: out.crossunder,
        combos: out.combos,
        short_lengths,
        medium_lengths,
        long_lengths,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1399
/// Runs a Didi Index parameter sweep from a raw input buffer into raw output buffers.
///
/// `data_ptr` must reference `len` readable `f64` values. Each output pointer
/// must reference `rows * len` writable `f64` values, where `rows` is the
/// number of parameter combinations produced by the three `(start, end, step)`
/// ranges. Buffers may overlap: when any two regions share memory, results are
/// computed into scratch vectors and then copied out in the order short, long,
/// crossover, crossunder.
///
/// Returns the number of rows written.
///
/// # Errors
///
/// Returns `"Null pointer provided"` when any pointer is null,
/// `"rows*cols overflow"` when the output size does not fit in `usize`, or the
/// grid/batch error rendered as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_batch_into(
    data_ptr: *const f64,
    short_ptr: *mut f64,
    long_ptr: *mut f64,
    crossover_ptr: *mut f64,
    crossunder_ptr: *mut f64,
    len: usize,
    short_start: usize,
    short_end: usize,
    short_step: usize,
    medium_start: usize,
    medium_end: usize,
    medium_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<usize, JsValue> {
    if data_ptr.is_null()
        || short_ptr.is_null()
        || long_ptr.is_null()
        || crossover_ptr.is_null()
        || crossunder_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DidiIndexBatchRange {
        short_length: (short_start, short_end, short_step),
        medium_length: (medium_start, medium_end, medium_step),
        long_length: (long_start, long_end, long_step),
    };
    let combos = expand_grid_didi_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte ranges of the input (`len` elements) and the four outputs (`total`
    // elements each). Any overlap would otherwise yield aliased `&`/`&mut`
    // slices in the direct path below.
    let elem_size = std::mem::size_of::<f64>();
    let span = |start: usize, elems: usize| {
        (start, start.saturating_add(elems.saturating_mul(elem_size)))
    };
    let regions = [
        span(data_ptr as usize, len),
        span(short_ptr as usize, total),
        span(long_ptr as usize, total),
        span(crossover_ptr as usize, total),
        span(crossunder_ptr as usize, total),
    ];
    let alias = regions
        .iter()
        .enumerate()
        .any(|(i, a)| regions[i + 1..].iter().any(|b| a.0 < b.1 && b.0 < a.1));

    let kernel = detect_best_batch_kernel().to_non_batch();

    // SAFETY: all pointers are non-null and, per the caller contract, valid for
    // the element counts stated above. Mutable slices over caller memory are
    // only created simultaneously when the regions were proven disjoint.
    unsafe {
        let data = std::slice::from_raw_parts(data_ptr, len);
        if alias {
            let mut short_tmp = vec![0.0_f64; total];
            let mut long_tmp = vec![0.0_f64; total];
            let mut crossover_tmp = vec![0.0_f64; total];
            let mut crossunder_tmp = vec![0.0_f64; total];
            didi_index_batch_inner_into(
                data,
                &sweep,
                kernel,
                &mut short_tmp,
                &mut long_tmp,
                &mut crossover_tmp,
                &mut crossunder_tmp,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
            // The input is no longer read past this point. Each destination
            // slice is a temporary that ends before the next one is created.
            std::slice::from_raw_parts_mut(short_ptr, total).copy_from_slice(&short_tmp);
            std::slice::from_raw_parts_mut(long_ptr, total).copy_from_slice(&long_tmp);
            std::slice::from_raw_parts_mut(crossover_ptr, total).copy_from_slice(&crossover_tmp);
            std::slice::from_raw_parts_mut(crossunder_ptr, total)
                .copy_from_slice(&crossunder_tmp);
        } else {
            didi_index_batch_inner_into(
                data,
                &sweep,
                kernel,
                std::slice::from_raw_parts_mut(short_ptr, total),
                std::slice::from_raw_parts_mut(long_ptr, total),
                std::slice::from_raw_parts_mut(crossover_ptr, total),
                std::slice::from_raw_parts_mut(crossunder_ptr, total),
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(rows)
}
1454
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-difference comparison with a tolerance of `1e-12`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-12
    }

    /// Like `approx_eq`, but two NaNs also count as equal.
    fn approx_eq_or_nan(a: f64, b: f64) -> bool {
        if a.is_nan() && b.is_nan() {
            true
        } else {
            approx_eq(a, b)
        }
    }

    /// Builds fully specified parameters for the given window lengths.
    fn lengths(short: usize, medium: usize, long: usize) -> DidiIndexParams {
        DidiIndexParams {
            short_length: Some(short),
            medium_length: Some(medium),
            long_length: Some(long),
        }
    }

    #[test]
    fn didi_index_matches_manual_ratios() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = didi_index(&DidiIndexInput::from_slice(&data, lengths(2, 3, 4))).unwrap();

        // The first three outputs are warmup and must be NaN.
        assert!(out.short[..3].iter().all(|v| v.is_nan()));
        assert!(out.long[..3].iter().all(|v| v.is_nan()));

        // (index, expected short, expected long), computed by hand.
        let expected = [(3usize, 3.5 / 3.0, 2.5 / 3.0), (4usize, 4.5 / 4.0, 3.5 / 4.0)];
        for (idx, want_short, want_long) in expected {
            assert!(approx_eq(out.short[idx], want_short));
            assert!(approx_eq(out.long[idx], want_long));
        }

        assert!(approx_eq(out.crossover[3], 0.0));
        assert!(approx_eq(out.crossunder[3], 0.0));
    }

    #[test]
    fn didi_index_detects_crossover_and_crossunder() {
        // A V-shaped series triggers a crossover at index 6.
        let cross_up = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let up_out =
            didi_index(&DidiIndexInput::from_slice(&cross_up, lengths(2, 3, 4))).unwrap();
        assert!(approx_eq(up_out.crossover[6], 1.0));
        assert!(approx_eq(up_out.crossunder[6], 0.0));

        // The mirrored series triggers a crossunder at index 6.
        let cross_down = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let down_out =
            didi_index(&DidiIndexInput::from_slice(&cross_down, lengths(2, 3, 4))).unwrap();
        assert!(approx_eq(down_out.crossunder[6], 1.0));
        assert!(approx_eq(down_out.crossover[6], 0.0));
    }

    #[test]
    fn didi_index_stream_matches_batch_with_reset() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, f64::NAN, 3.0, 4.0, 5.0, 6.0];
        let batch = didi_index(&DidiIndexInput::from_slice(&data, lengths(2, 3, 4))).unwrap();
        let mut stream = DidiIndexStream::try_new(lengths(2, 3, 4)).unwrap();

        // A `None` from the stream is recorded as NaN in all four outputs.
        let streamed: Vec<(f64, f64, f64, f64)> = data
            .iter()
            .map(|&value| {
                stream
                    .update(value)
                    .unwrap_or((f64::NAN, f64::NAN, f64::NAN, f64::NAN))
            })
            .collect();

        assert_eq!(stream.get_warmup_period(), 3);
        for (i, &(short, long, crossover, crossunder)) in streamed.iter().enumerate() {
            assert!(approx_eq_or_nan(batch.short[i], short));
            assert!(approx_eq_or_nan(batch.long[i], long));
            assert!(approx_eq_or_nan(batch.crossover[i], crossover));
            assert!(approx_eq_or_nan(batch.crossunder[i], crossunder));
        }

        // The NaN at index 5 forces a fresh warmup: output returns only at index 9.
        assert!(batch.short[5].is_nan());
        assert!(batch.short[8].is_nan());
        assert!(batch.short[9].is_finite());
    }

    #[test]
    fn didi_index_batch_default_row_matches_single() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let sweep = DidiIndexBatchRange {
            short_length: (2, 2, 0),
            medium_length: (3, 3, 0),
            long_length: (4, 4, 0),
        };
        let batch = didi_index_batch_slice(&data, &sweep).unwrap();
        let single = didi_index(&DidiIndexInput::from_slice(&data, lengths(2, 3, 4))).unwrap();

        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, data.len());
        assert_eq!(batch.short.len(), data.len());
        for i in 0..data.len() {
            assert!(approx_eq_or_nan(batch.short[i], single.short[i]));
            assert!(approx_eq_or_nan(batch.long[i], single.long[i]));
        }
    }

    #[test]
    fn didi_index_rejects_invalid_lengths() {
        let data = [1.0, 2.0, 3.0];
        let outcome = didi_index(&DidiIndexInput::from_slice(&data, lengths(0, 2, 3)));
        let err = outcome.unwrap_err();
        assert!(matches!(err, DidiIndexError::InvalidShortLength { .. }));
    }
}