Skip to main content

vector_ta/indicators/
didi_index.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::mem::ManuallyDrop;
28use thiserror::Error;
29
/// Default window of the fast ("short") simple moving average.
const DEFAULT_SHORT_LENGTH: usize = 3;
/// Default window of the reference ("medium") simple moving average; both the short and
/// the long average are divided by it.
const DEFAULT_MEDIUM_LENGTH: usize = 8;
/// Default window of the slow ("long") simple moving average.
const DEFAULT_LONG_LENGTH: usize = 20;
33
34impl<'a> AsRef<[f64]> for DidiIndexInput<'a> {
35    #[inline(always)]
36    fn as_ref(&self) -> &[f64] {
37        match &self.data {
38            DidiIndexData::Slice(slice) => slice,
39            DidiIndexData::Candles { candles, source } => source_type(candles, source),
40        }
41    }
42}
43
44#[derive(Debug, Clone)]
45pub enum DidiIndexData<'a> {
46    Candles {
47        candles: &'a Candles,
48        source: &'a str,
49    },
50    Slice(&'a [f64]),
51}
52
/// Per-bar Didi Index series, all with the same length as the input.
///
/// Bars that produced no value (warm-up, non-finite input, zero medium average) hold NaN
/// in every series.
#[derive(Debug, Clone)]
pub struct DidiIndexOutput {
    /// Short SMA divided by the medium SMA.
    pub short: Vec<f64>,
    /// Long SMA divided by the medium SMA.
    pub long: Vec<f64>,
    /// `1.0` on a bar where `short` moves above `long`, otherwise `0.0`.
    pub crossover: Vec<f64>,
    /// `1.0` on a bar where `short` moves below `long`, otherwise `0.0`.
    pub crossunder: Vec<f64>,
}
60
61#[derive(Debug, Clone, PartialEq)]
62#[cfg_attr(
63    all(target_arch = "wasm32", feature = "wasm"),
64    derive(Serialize, Deserialize)
65)]
66pub struct DidiIndexParams {
67    pub short_length: Option<usize>,
68    pub medium_length: Option<usize>,
69    pub long_length: Option<usize>,
70}
71
72impl Default for DidiIndexParams {
73    fn default() -> Self {
74        Self {
75            short_length: Some(DEFAULT_SHORT_LENGTH),
76            medium_length: Some(DEFAULT_MEDIUM_LENGTH),
77            long_length: Some(DEFAULT_LONG_LENGTH),
78        }
79    }
80}
81
82#[derive(Debug, Clone)]
83pub struct DidiIndexInput<'a> {
84    pub data: DidiIndexData<'a>,
85    pub params: DidiIndexParams,
86}
87
88impl<'a> DidiIndexInput<'a> {
89    #[inline]
90    pub fn from_candles(candles: &'a Candles, source: &'a str, params: DidiIndexParams) -> Self {
91        Self {
92            data: DidiIndexData::Candles { candles, source },
93            params,
94        }
95    }
96
97    #[inline]
98    pub fn from_slice(slice: &'a [f64], params: DidiIndexParams) -> Self {
99        Self {
100            data: DidiIndexData::Slice(slice),
101            params,
102        }
103    }
104
105    #[inline]
106    pub fn with_default_candles(candles: &'a Candles) -> Self {
107        Self::from_candles(candles, "close", DidiIndexParams::default())
108    }
109
110    #[inline]
111    pub fn get_short_length(&self) -> usize {
112        self.params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH)
113    }
114
115    #[inline]
116    pub fn get_medium_length(&self) -> usize {
117        self.params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH)
118    }
119
120    #[inline]
121    pub fn get_long_length(&self) -> usize {
122        self.params.long_length.unwrap_or(DEFAULT_LONG_LENGTH)
123    }
124}
125
126#[derive(Copy, Clone, Debug)]
127pub struct DidiIndexBuilder {
128    short_length: Option<usize>,
129    medium_length: Option<usize>,
130    long_length: Option<usize>,
131    kernel: Kernel,
132}
133
134impl Default for DidiIndexBuilder {
135    fn default() -> Self {
136        Self {
137            short_length: None,
138            medium_length: None,
139            long_length: None,
140            kernel: Kernel::Auto,
141        }
142    }
143}
144
145impl DidiIndexBuilder {
146    #[inline]
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    #[inline]
152    pub fn short_length(mut self, short_length: usize) -> Self {
153        self.short_length = Some(short_length);
154        self
155    }
156
157    #[inline]
158    pub fn medium_length(mut self, medium_length: usize) -> Self {
159        self.medium_length = Some(medium_length);
160        self
161    }
162
163    #[inline]
164    pub fn long_length(mut self, long_length: usize) -> Self {
165        self.long_length = Some(long_length);
166        self
167    }
168
169    #[inline]
170    pub fn kernel(mut self, kernel: Kernel) -> Self {
171        self.kernel = kernel;
172        self
173    }
174
175    #[inline]
176    pub fn apply(self, candles: &Candles, source: &str) -> Result<DidiIndexOutput, DidiIndexError> {
177        let input = DidiIndexInput::from_candles(
178            candles,
179            source,
180            DidiIndexParams {
181                short_length: self.short_length,
182                medium_length: self.medium_length,
183                long_length: self.long_length,
184            },
185        );
186        didi_index_with_kernel(&input, self.kernel)
187    }
188
189    #[inline]
190    pub fn apply_slice(self, data: &[f64]) -> Result<DidiIndexOutput, DidiIndexError> {
191        let input = DidiIndexInput::from_slice(
192            data,
193            DidiIndexParams {
194                short_length: self.short_length,
195                medium_length: self.medium_length,
196                long_length: self.long_length,
197            },
198        );
199        didi_index_with_kernel(&input, self.kernel)
200    }
201
202    #[inline]
203    pub fn into_stream(self) -> Result<DidiIndexStream, DidiIndexError> {
204        DidiIndexStream::try_new(DidiIndexParams {
205            short_length: self.short_length,
206            medium_length: self.medium_length,
207            long_length: self.long_length,
208        })
209    }
210}
211
212#[derive(Debug, Error)]
213pub enum DidiIndexError {
214    #[error("didi_index: Input data slice is empty.")]
215    EmptyInputData,
216    #[error("didi_index: All values are NaN.")]
217    AllValuesNaN,
218    #[error(
219        "didi_index: Invalid short_length: short_length = {short_length}, data length = {data_len}"
220    )]
221    InvalidShortLength {
222        short_length: usize,
223        data_len: usize,
224    },
225    #[error("didi_index: Invalid medium_length: medium_length = {medium_length}, data length = {data_len}")]
226    InvalidMediumLength {
227        medium_length: usize,
228        data_len: usize,
229    },
230    #[error(
231        "didi_index: Invalid long_length: long_length = {long_length}, data length = {data_len}"
232    )]
233    InvalidLongLength { long_length: usize, data_len: usize },
234    #[error("didi_index: Not enough valid data: needed = {needed}, valid = {valid}")]
235    NotEnoughValidData { needed: usize, valid: usize },
236    #[error("didi_index: Output length mismatch: expected = {expected}, short = {short_got}, long = {long_got}, crossover = {crossover_got}, crossunder = {crossunder_got}")]
237    OutputLengthMismatch {
238        expected: usize,
239        short_got: usize,
240        long_got: usize,
241        crossover_got: usize,
242        crossunder_got: usize,
243    },
244    #[error("didi_index: Invalid range: start={start}, end={end}, step={step}")]
245    InvalidRange {
246        start: String,
247        end: String,
248        step: String,
249    },
250    #[error("didi_index: Invalid kernel for batch: {0:?}")]
251    InvalidKernelForBatch(Kernel),
252}
253
/// Fixed-length simple moving average over a ring buffer with a running sum.
#[derive(Debug, Clone)]
struct SmaWindow {
    // Number of samples averaged.
    period: usize,
    // Ring buffer of the most recent samples (at least one slot).
    values: Vec<f64>,
    // Next slot to overwrite.
    idx: usize,
    // Samples seen so far, saturating at `period`.
    count: usize,
    // Sum of the samples currently held.
    sum: f64,
}

impl SmaWindow {
    /// Creates an empty window. The buffer always has at least one slot, even when
    /// `period` is zero.
    #[inline]
    fn new(period: usize) -> Self {
        let slots = if period == 0 { 1 } else { period };
        SmaWindow {
            period,
            values: vec![0.0; slots],
            idx: 0,
            count: 0,
            sum: 0.0,
        }
    }

    /// Discards all samples. Stale buffer contents are never read, because the window
    /// refills before it evicts anything.
    #[inline]
    fn reset(&mut self) {
        self.sum = 0.0;
        self.count = 0;
        self.idx = 0;
    }

    /// Pushes `value`. Returns the mean once `period` samples have been seen, else `None`.
    #[inline]
    fn update(&mut self, value: f64) -> Option<f64> {
        if self.count < self.period {
            // Still filling: nothing to evict yet.
            self.count += 1;
            self.sum += value;
        } else {
            let evicted = self.values[self.idx];
            self.sum += value - evicted;
        }
        self.values[self.idx] = value;

        self.idx += 1;
        if self.idx == self.period {
            self.idx = 0;
        }

        if self.count == self.period {
            Some(self.sum / self.period as f64)
        } else {
            None
        }
    }
}
309
310#[derive(Debug, Clone)]
311pub struct DidiIndexStream {
312    short: SmaWindow,
313    medium: SmaWindow,
314    long: SmaWindow,
315    prev_short: f64,
316    prev_long: f64,
317    have_prev: bool,
318    warmup: usize,
319}
320
321impl DidiIndexStream {
322    pub fn try_new(params: DidiIndexParams) -> Result<Self, DidiIndexError> {
323        let short_length = params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
324        if short_length == 0 {
325            return Err(DidiIndexError::InvalidShortLength {
326                short_length,
327                data_len: 0,
328            });
329        }
330        let medium_length = params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
331        if medium_length == 0 {
332            return Err(DidiIndexError::InvalidMediumLength {
333                medium_length,
334                data_len: 0,
335            });
336        }
337        let long_length = params.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
338        if long_length == 0 {
339            return Err(DidiIndexError::InvalidLongLength {
340                long_length,
341                data_len: 0,
342            });
343        }
344        Ok(Self {
345            short: SmaWindow::new(short_length),
346            medium: SmaWindow::new(medium_length),
347            long: SmaWindow::new(long_length),
348            prev_short: f64::NAN,
349            prev_long: f64::NAN,
350            have_prev: false,
351            warmup: short_length.max(medium_length).max(long_length) - 1,
352        })
353    }
354
355    #[inline]
356    fn reset(&mut self) {
357        self.short.reset();
358        self.medium.reset();
359        self.long.reset();
360        self.prev_short = f64::NAN;
361        self.prev_long = f64::NAN;
362        self.have_prev = false;
363    }
364
365    #[inline]
366    pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
367        if !valid_value(value) {
368            self.reset();
369            return None;
370        }
371
372        let short_ma = self.short.update(value);
373        let medium_ma = self.medium.update(value);
374        let long_ma = self.long.update(value);
375        if short_ma.is_none() || medium_ma.is_none() || long_ma.is_none() {
376            self.have_prev = false;
377            return None;
378        }
379
380        let medium_ma = medium_ma.unwrap_or(f64::NAN);
381        if !medium_ma.is_finite() || medium_ma == 0.0 {
382            self.have_prev = false;
383            return Some((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
384        }
385
386        let short = short_ma.unwrap_or(f64::NAN) / medium_ma;
387        let long = long_ma.unwrap_or(f64::NAN) / medium_ma;
388        if !short.is_finite() || !long.is_finite() {
389            self.have_prev = false;
390            return Some((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
391        }
392
393        let crossover = if self.have_prev && short > long && self.prev_short <= self.prev_long {
394            1.0
395        } else {
396            0.0
397        };
398        let crossunder = if self.have_prev && short < long && self.prev_short >= self.prev_long {
399            1.0
400        } else {
401            0.0
402        };
403        self.prev_short = short;
404        self.prev_long = long;
405        self.have_prev = true;
406        Some((short, long, crossover, crossunder))
407    }
408
409    #[inline]
410    pub fn get_warmup_period(&self) -> usize {
411        self.warmup
412    }
413}
414
415#[inline]
416pub fn didi_index(input: &DidiIndexInput) -> Result<DidiIndexOutput, DidiIndexError> {
417    didi_index_with_kernel(input, Kernel::Auto)
418}
419
/// True when `value` can be fed to the moving averages, i.e. it is neither NaN nor ±∞.
#[inline(always)]
fn valid_value(value: f64) -> bool {
    !(value.is_nan() || value.is_infinite())
}

/// Index of the first finite element, or `data.len()` if there is none.
#[inline(always)]
fn first_valid_value(data: &[f64]) -> usize {
    data.iter()
        .position(|&v| valid_value(v))
        .unwrap_or(data.len())
}

/// Number of finite elements in `data` (they need not be contiguous).
#[inline(always)]
fn count_valid_values(data: &[f64]) -> usize {
    data.iter()
        .fold(0usize, |acc, &v| acc + usize::from(valid_value(v)))
}
441
442#[inline(always)]
443fn didi_index_row_from_slice(
444    data: &[f64],
445    params: &DidiIndexParams,
446    short_out: &mut [f64],
447    long_out: &mut [f64],
448    crossover_out: &mut [f64],
449    crossunder_out: &mut [f64],
450) -> Result<(), DidiIndexError> {
451    let mut stream = DidiIndexStream::try_new(params.clone())?;
452    for i in 0..data.len() {
453        match stream.update(data[i]) {
454            Some((short, long, crossover, crossunder)) => {
455                short_out[i] = short;
456                long_out[i] = long;
457                crossover_out[i] = crossover;
458                crossunder_out[i] = crossunder;
459            }
460            None => {
461                short_out[i] = f64::NAN;
462                long_out[i] = f64::NAN;
463                crossover_out[i] = f64::NAN;
464                crossunder_out[i] = f64::NAN;
465            }
466        }
467    }
468    Ok(())
469}
470
471#[inline(always)]
472fn didi_index_prepare<'a>(
473    input: &'a DidiIndexInput,
474    kernel: Kernel,
475) -> Result<(&'a [f64], usize, DidiIndexParams, Kernel), DidiIndexError> {
476    let data = input.as_ref();
477    if data.is_empty() {
478        return Err(DidiIndexError::EmptyInputData);
479    }
480
481    let first = first_valid_value(data);
482    if first >= data.len() {
483        return Err(DidiIndexError::AllValuesNaN);
484    }
485
486    let params = input.params.clone();
487    let short_length = params.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
488    let medium_length = params.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
489    let long_length = params.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
490    let len = data.len();
491    if short_length == 0 || short_length > len {
492        return Err(DidiIndexError::InvalidShortLength {
493            short_length,
494            data_len: len,
495        });
496    }
497    if medium_length == 0 || medium_length > len {
498        return Err(DidiIndexError::InvalidMediumLength {
499            medium_length,
500            data_len: len,
501        });
502    }
503    if long_length == 0 || long_length > len {
504        return Err(DidiIndexError::InvalidLongLength {
505            long_length,
506            data_len: len,
507        });
508    }
509
510    let needed = short_length.max(medium_length).max(long_length);
511    let valid = count_valid_values(data);
512    if valid < needed {
513        return Err(DidiIndexError::NotEnoughValidData { needed, valid });
514    }
515
516    let chosen = match kernel {
517        Kernel::Auto => detect_best_kernel(),
518        other => other.to_non_batch(),
519    };
520    Ok((data, first, params, chosen))
521}
522
523#[inline]
524pub fn didi_index_with_kernel(
525    input: &DidiIndexInput,
526    kernel: Kernel,
527) -> Result<DidiIndexOutput, DidiIndexError> {
528    let (data, first, params, _chosen) = didi_index_prepare(input, kernel)?;
529    let mut short = alloc_with_nan_prefix(data.len(), first);
530    let mut long = alloc_with_nan_prefix(data.len(), first);
531    let mut crossover = alloc_with_nan_prefix(data.len(), first);
532    let mut crossunder = alloc_with_nan_prefix(data.len(), first);
533    didi_index_row_from_slice(
534        data,
535        &params,
536        &mut short,
537        &mut long,
538        &mut crossover,
539        &mut crossunder,
540    )?;
541    Ok(DidiIndexOutput {
542        short,
543        long,
544        crossover,
545        crossunder,
546    })
547}
548
549#[inline]
550pub fn didi_index_into_slices(
551    short_out: &mut [f64],
552    long_out: &mut [f64],
553    crossover_out: &mut [f64],
554    crossunder_out: &mut [f64],
555    input: &DidiIndexInput,
556    kernel: Kernel,
557) -> Result<(), DidiIndexError> {
558    let (data, _first, params, _chosen) = didi_index_prepare(input, kernel)?;
559    if short_out.len() != data.len()
560        || long_out.len() != data.len()
561        || crossover_out.len() != data.len()
562        || crossunder_out.len() != data.len()
563    {
564        return Err(DidiIndexError::OutputLengthMismatch {
565            expected: data.len(),
566            short_got: short_out.len(),
567            long_got: long_out.len(),
568            crossover_got: crossover_out.len(),
569            crossunder_got: crossunder_out.len(),
570        });
571    }
572    didi_index_row_from_slice(
573        data,
574        &params,
575        short_out,
576        long_out,
577        crossover_out,
578        crossunder_out,
579    )
580}
581
582#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
583#[inline]
584pub fn didi_index_into(
585    input: &DidiIndexInput,
586    short_out: &mut [f64],
587    long_out: &mut [f64],
588    crossover_out: &mut [f64],
589    crossunder_out: &mut [f64],
590) -> Result<(), DidiIndexError> {
591    didi_index_into_slices(
592        short_out,
593        long_out,
594        crossover_out,
595        crossunder_out,
596        input,
597        Kernel::Auto,
598    )
599}
600
601#[derive(Clone, Debug)]
602pub struct DidiIndexBatchRange {
603    pub short_length: (usize, usize, usize),
604    pub medium_length: (usize, usize, usize),
605    pub long_length: (usize, usize, usize),
606}
607
608impl Default for DidiIndexBatchRange {
609    fn default() -> Self {
610        Self {
611            short_length: (DEFAULT_SHORT_LENGTH, DEFAULT_SHORT_LENGTH, 0),
612            medium_length: (DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0),
613            long_length: (DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0),
614        }
615    }
616}
617
618#[derive(Clone, Debug)]
619pub struct DidiIndexBatchBuilder {
620    range: DidiIndexBatchRange,
621    kernel: Kernel,
622}
623
624impl Default for DidiIndexBatchBuilder {
625    fn default() -> Self {
626        Self {
627            range: DidiIndexBatchRange::default(),
628            kernel: Kernel::Auto,
629        }
630    }
631}
632
633impl DidiIndexBatchBuilder {
634    #[inline]
635    pub fn new() -> Self {
636        Self::default()
637    }
638
639    #[inline]
640    pub fn short_length_range(mut self, range: (usize, usize, usize)) -> Self {
641        self.range.short_length = range;
642        self
643    }
644
645    #[inline]
646    pub fn medium_length_range(mut self, range: (usize, usize, usize)) -> Self {
647        self.range.medium_length = range;
648        self
649    }
650
651    #[inline]
652    pub fn long_length_range(mut self, range: (usize, usize, usize)) -> Self {
653        self.range.long_length = range;
654        self
655    }
656
657    #[inline]
658    pub fn kernel(mut self, kernel: Kernel) -> Self {
659        self.kernel = kernel;
660        self
661    }
662
663    #[inline]
664    pub fn apply_slice(self, data: &[f64]) -> Result<DidiIndexBatchOutput, DidiIndexError> {
665        didi_index_batch_with_kernel(data, &self.range, self.kernel)
666    }
667
668    #[inline]
669    pub fn apply_candles(
670        self,
671        candles: &Candles,
672        source: &str,
673    ) -> Result<DidiIndexBatchOutput, DidiIndexError> {
674        self.apply_slice(source_type(candles, source))
675    }
676
677    #[inline]
678    pub fn with_default_candles(candles: &Candles) -> Result<DidiIndexBatchOutput, DidiIndexError> {
679        DidiIndexBatchBuilder::new().apply_candles(candles, "close")
680    }
681}
682
683#[derive(Clone, Debug)]
684pub struct DidiIndexBatchOutput {
685    pub short: Vec<f64>,
686    pub long: Vec<f64>,
687    pub crossover: Vec<f64>,
688    pub crossunder: Vec<f64>,
689    pub combos: Vec<DidiIndexParams>,
690    pub rows: usize,
691    pub cols: usize,
692}
693
694impl DidiIndexBatchOutput {
695    pub fn row_for_params(&self, params: &DidiIndexParams) -> Option<usize> {
696        self.combos.iter().position(|combo| combo == params)
697    }
698
699    pub fn short_for(&self, params: &DidiIndexParams) -> Option<&[f64]> {
700        self.row_for_params(params).and_then(|row| {
701            row.checked_mul(self.cols)
702                .and_then(|start| self.short.get(start..start + self.cols))
703        })
704    }
705
706    pub fn long_for(&self, params: &DidiIndexParams) -> Option<&[f64]> {
707        self.row_for_params(params).and_then(|row| {
708            row.checked_mul(self.cols)
709                .and_then(|start| self.long.get(start..start + self.cols))
710        })
711    }
712}
713
714#[inline(always)]
715fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, DidiIndexError> {
716    if step == 0 || start == end {
717        return Ok(vec![start]);
718    }
719    let step = step.max(1);
720    if start < end {
721        let mut out = Vec::new();
722        let mut x = start;
723        while x <= end {
724            out.push(x);
725            match x.checked_add(step) {
726                Some(next) if next != x => x = next,
727                _ => break,
728            }
729        }
730        if out.is_empty() {
731            return Err(DidiIndexError::InvalidRange {
732                start: start.to_string(),
733                end: end.to_string(),
734                step: step.to_string(),
735            });
736        }
737        Ok(out)
738    } else {
739        let mut out = Vec::new();
740        let mut x = start;
741        loop {
742            out.push(x);
743            if x == end {
744                break;
745            }
746            let next = x.saturating_sub(step);
747            if next == x || next < end {
748                break;
749            }
750            x = next;
751        }
752        if out.is_empty() {
753            return Err(DidiIndexError::InvalidRange {
754                start: start.to_string(),
755                end: end.to_string(),
756                step: step.to_string(),
757            });
758        }
759        Ok(out)
760    }
761}
762
763#[inline(always)]
764fn expand_grid_didi_index(
765    range: &DidiIndexBatchRange,
766) -> Result<Vec<DidiIndexParams>, DidiIndexError> {
767    let shorts = axis_usize(range.short_length)?;
768    let mediums = axis_usize(range.medium_length)?;
769    let longs = axis_usize(range.long_length)?;
770
771    if let Some(&short_length) = shorts.iter().find(|&&value| value == 0) {
772        return Err(DidiIndexError::InvalidShortLength {
773            short_length,
774            data_len: 0,
775        });
776    }
777    if let Some(&medium_length) = mediums.iter().find(|&&value| value == 0) {
778        return Err(DidiIndexError::InvalidMediumLength {
779            medium_length,
780            data_len: 0,
781        });
782    }
783    if let Some(&long_length) = longs.iter().find(|&&value| value == 0) {
784        return Err(DidiIndexError::InvalidLongLength {
785            long_length,
786            data_len: 0,
787        });
788    }
789
790    let mut out = Vec::with_capacity(shorts.len() * mediums.len() * longs.len());
791    for &short_length in &shorts {
792        for &medium_length in &mediums {
793            for &long_length in &longs {
794                out.push(DidiIndexParams {
795                    short_length: Some(short_length),
796                    medium_length: Some(medium_length),
797                    long_length: Some(long_length),
798                });
799            }
800        }
801    }
802    Ok(out)
803}
804
805pub fn didi_index_batch_with_kernel(
806    data: &[f64],
807    sweep: &DidiIndexBatchRange,
808    kernel: Kernel,
809) -> Result<DidiIndexBatchOutput, DidiIndexError> {
810    let batch_kernel = match kernel {
811        Kernel::Auto => detect_best_batch_kernel(),
812        other if other.is_batch() => other,
813        other => return Err(DidiIndexError::InvalidKernelForBatch(other)),
814    };
815    didi_index_batch_inner(data, sweep, batch_kernel.to_non_batch(), false)
816}
817
818#[inline]
819pub fn didi_index_batch_slice(
820    data: &[f64],
821    sweep: &DidiIndexBatchRange,
822) -> Result<DidiIndexBatchOutput, DidiIndexError> {
823    didi_index_batch_with_kernel(data, sweep, Kernel::Auto)
824}
825
826#[inline]
827pub fn didi_index_batch_par_slice(
828    data: &[f64],
829    sweep: &DidiIndexBatchRange,
830) -> Result<DidiIndexBatchOutput, DidiIndexError> {
831    #[cfg(not(target_arch = "wasm32"))]
832    {
833        let kernel = detect_best_batch_kernel().to_non_batch();
834        return didi_index_batch_inner(data, sweep, kernel, true);
835    }
836    #[cfg(target_arch = "wasm32")]
837    {
838        didi_index_batch_inner(data, sweep, detect_best_kernel(), false)
839    }
840}
841
842pub fn didi_index_batch_inner(
843    data: &[f64],
844    sweep: &DidiIndexBatchRange,
845    kernel: Kernel,
846    parallel: bool,
847) -> Result<DidiIndexBatchOutput, DidiIndexError> {
848    if data.is_empty() {
849        return Err(DidiIndexError::EmptyInputData);
850    }
851    let first = first_valid_value(data);
852    if first >= data.len() {
853        return Err(DidiIndexError::AllValuesNaN);
854    }
855
856    let combos = expand_grid_didi_index(sweep)?;
857    let rows = combos.len();
858    let cols = data.len();
859    let total = rows
860        .checked_mul(cols)
861        .ok_or_else(|| DidiIndexError::OutputLengthMismatch {
862            expected: usize::MAX,
863            short_got: 0,
864            long_got: 0,
865            crossover_got: 0,
866            crossunder_got: 0,
867        })?;
868
869    let valid = count_valid_values(data);
870    let mut warms = Vec::with_capacity(rows);
871    for combo in &combos {
872        let short_length = combo.short_length.unwrap_or(DEFAULT_SHORT_LENGTH);
873        let medium_length = combo.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH);
874        let long_length = combo.long_length.unwrap_or(DEFAULT_LONG_LENGTH);
875        let needed = short_length.max(medium_length).max(long_length);
876        if short_length > cols {
877            return Err(DidiIndexError::InvalidShortLength {
878                short_length,
879                data_len: cols,
880            });
881        }
882        if medium_length > cols {
883            return Err(DidiIndexError::InvalidMediumLength {
884                medium_length,
885                data_len: cols,
886            });
887        }
888        if long_length > cols {
889            return Err(DidiIndexError::InvalidLongLength {
890                long_length,
891                data_len: cols,
892            });
893        }
894        if valid < needed {
895            return Err(DidiIndexError::NotEnoughValidData { needed, valid });
896        }
897        warms.push((first + needed - 1).min(cols));
898    }
899
900    let mut short_mu = make_uninit_matrix(rows, cols);
901    let mut long_mu = make_uninit_matrix(rows, cols);
902    let mut crossover_mu = make_uninit_matrix(rows, cols);
903    let mut crossunder_mu = make_uninit_matrix(rows, cols);
904    init_matrix_prefixes(&mut short_mu, cols, &warms);
905    init_matrix_prefixes(&mut long_mu, cols, &warms);
906    init_matrix_prefixes(&mut crossover_mu, cols, &warms);
907    init_matrix_prefixes(&mut crossunder_mu, cols, &warms);
908
909    let mut short_guard = ManuallyDrop::new(short_mu);
910    let mut long_guard = ManuallyDrop::new(long_mu);
911    let mut crossover_guard = ManuallyDrop::new(crossover_mu);
912    let mut crossunder_guard = ManuallyDrop::new(crossunder_mu);
913
914    let short_out =
915        unsafe { std::slice::from_raw_parts_mut(short_guard.as_mut_ptr() as *mut f64, total) };
916    let long_out =
917        unsafe { std::slice::from_raw_parts_mut(long_guard.as_mut_ptr() as *mut f64, total) };
918    let crossover_out =
919        unsafe { std::slice::from_raw_parts_mut(crossover_guard.as_mut_ptr() as *mut f64, total) };
920    let crossunder_out =
921        unsafe { std::slice::from_raw_parts_mut(crossunder_guard.as_mut_ptr() as *mut f64, total) };
922
923    if parallel {
924        #[cfg(not(target_arch = "wasm32"))]
925        {
926            short_out
927                .par_chunks_mut(cols)
928                .zip(long_out.par_chunks_mut(cols))
929                .zip(crossover_out.par_chunks_mut(cols))
930                .zip(crossunder_out.par_chunks_mut(cols))
931                .zip(combos.par_iter())
932                .for_each(
933                    |((((dst_short, dst_long), dst_crossover), dst_crossunder), combo)| {
934                        let _ = didi_index_row_from_slice(
935                            data,
936                            combo,
937                            dst_short,
938                            dst_long,
939                            dst_crossover,
940                            dst_crossunder,
941                        );
942                    },
943                );
944        }
945    } else {
946        let _ = kernel;
947        for (row, combo) in combos.iter().enumerate() {
948            let start = row * cols;
949            let end = start + cols;
950            didi_index_row_from_slice(
951                data,
952                combo,
953                &mut short_out[start..end],
954                &mut long_out[start..end],
955                &mut crossover_out[start..end],
956                &mut crossunder_out[start..end],
957            )?;
958        }
959    }
960
961    let short = unsafe {
962        Vec::from_raw_parts(
963            short_guard.as_mut_ptr() as *mut f64,
964            short_guard.len(),
965            short_guard.capacity(),
966        )
967    };
968    let long = unsafe {
969        Vec::from_raw_parts(
970            long_guard.as_mut_ptr() as *mut f64,
971            long_guard.len(),
972            long_guard.capacity(),
973        )
974    };
975    let crossover = unsafe {
976        Vec::from_raw_parts(
977            crossover_guard.as_mut_ptr() as *mut f64,
978            crossover_guard.len(),
979            crossover_guard.capacity(),
980        )
981    };
982    let crossunder = unsafe {
983        Vec::from_raw_parts(
984            crossunder_guard.as_mut_ptr() as *mut f64,
985            crossunder_guard.len(),
986            crossunder_guard.capacity(),
987        )
988    };
989    core::mem::forget(short_guard);
990    core::mem::forget(long_guard);
991    core::mem::forget(crossover_guard);
992    core::mem::forget(crossunder_guard);
993
994    Ok(DidiIndexBatchOutput {
995        short,
996        long,
997        crossover,
998        crossunder,
999        combos,
1000        rows,
1001        cols,
1002    })
1003}
1004
1005pub fn didi_index_batch_inner_into(
1006    data: &[f64],
1007    sweep: &DidiIndexBatchRange,
1008    kernel: Kernel,
1009    short_out: &mut [f64],
1010    long_out: &mut [f64],
1011    crossover_out: &mut [f64],
1012    crossunder_out: &mut [f64],
1013) -> Result<Vec<DidiIndexParams>, DidiIndexError> {
1014    let out = didi_index_batch_inner(data, sweep, kernel, false)?;
1015    let total = out.rows * out.cols;
1016    if short_out.len() != total
1017        || long_out.len() != total
1018        || crossover_out.len() != total
1019        || crossunder_out.len() != total
1020    {
1021        return Err(DidiIndexError::OutputLengthMismatch {
1022            expected: total,
1023            short_got: short_out.len(),
1024            long_got: long_out.len(),
1025            crossover_got: crossover_out.len(),
1026            crossunder_got: crossunder_out.len(),
1027        });
1028    }
1029    short_out.copy_from_slice(&out.short);
1030    long_out.copy_from_slice(&out.long);
1031    crossover_out.copy_from_slice(&out.crossover);
1032    crossunder_out.copy_from_slice(&out.crossunder);
1033    Ok(out.combos)
1034}
1035
/// Python entry point for a single DIDI index computation.
///
/// Any length left as `None` is forwarded unresolved inside `DidiIndexParams`.
/// The optional `kernel` name is checked with `validate_kernel`. The
/// computation then runs with the GIL released.
///
/// Returns the `(short, long, crossover, crossunder)` series as four 1-D
/// NumPy arrays. Computation errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "didi_index")]
#[pyo3(signature = (data, short_length=None, medium_length=None, long_length=None, kernel=None))]
pub fn didi_index_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_length: Option<usize>,
    medium_length: Option<usize>,
    long_length: Option<usize>,
    kernel: Option<&str>,
) -> PyResult<(
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
    Bound<'py, PyArray1<f64>>,
)> {
    let values = data.as_slice()?;
    let chosen_kernel = validate_kernel(kernel, false)?;
    let params = DidiIndexParams {
        short_length,
        medium_length,
        long_length,
    };
    let input = DidiIndexInput::from_slice(values, params);

    // Pure Rust work: release the GIL while it runs.
    let computed = py.allow_threads(|| didi_index_with_kernel(&input, chosen_kernel));
    let series = match computed {
        Ok(series) => series,
        Err(err) => return Err(PyValueError::new_err(err.to_string())),
    };

    Ok((
        series.short.into_pyarray(py),
        series.long.into_pyarray(py),
        series.crossover.into_pyarray(py),
        series.crossunder.into_pyarray(py),
    ))
}
1072
/// Python class `DidiIndexStream`: a thin wrapper around the Rust
/// `DidiIndexStream` so values can be fed one at a time from Python.
#[cfg(feature = "python")]
#[pyclass(name = "DidiIndexStream")]
pub struct DidiIndexStreamPy {
    // Underlying streaming calculator; the Python methods delegate to it.
    inner: DidiIndexStream,
}
1078
#[cfg(feature = "python")]
#[pymethods]
impl DidiIndexStreamPy {
    /// Creates a stream with the given short/medium/long lengths.
    ///
    /// Each length defaults to the module constant when omitted. Invalid
    /// parameters raise `ValueError`.
    #[new]
    #[pyo3(signature = (short_length=DEFAULT_SHORT_LENGTH, medium_length=DEFAULT_MEDIUM_LENGTH, long_length=DEFAULT_LONG_LENGTH))]
    fn new(short_length: usize, medium_length: usize, long_length: usize) -> PyResult<Self> {
        let params = DidiIndexParams {
            short_length: Some(short_length),
            medium_length: Some(medium_length),
            long_length: Some(long_length),
        };
        match DidiIndexStream::try_new(params) {
            Ok(inner) => Ok(Self { inner }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one value.
    ///
    /// Returns `(short, long, crossover, crossunder)` when the stream produces
    /// a result, or `None` otherwise.
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64, f64)> {
        let stream = &mut self.inner;
        stream.update(value)
    }

    /// Warm-up period reported by the underlying stream.
    #[getter]
    fn warmup_period(&self) -> usize {
        let stream = &self.inner;
        stream.get_warmup_period()
    }
}
1103
/// Python entry point for a DIDI index parameter sweep.
///
/// Each `*_range` argument is a `(start, end, step)` triple. The defaults pin
/// each length to its module constant with a step of 0.
///
/// Returns a dict with:
/// - `short`, `long`, `crossover`, `crossunder`: `(rows, cols)` float arrays.
/// - `short_lengths`, `medium_lengths`, `long_lengths`: per-row `u64` arrays.
/// - `rows` and `cols`.
///
/// Errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "didi_index_batch")]
#[pyo3(signature = (data, short_length_range=(DEFAULT_SHORT_LENGTH, DEFAULT_SHORT_LENGTH, 0), medium_length_range=(DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0), long_length_range=(DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0), kernel=None))]
pub fn didi_index_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    short_length_range: (usize, usize, usize),
    medium_length_range: (usize, usize, usize),
    long_length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let values = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = DidiIndexBatchRange {
        short_length: short_length_range,
        medium_length: medium_length_range,
        long_length: long_length_range,
    };

    // Only the number of combinations is needed up front, to size the arrays.
    let rows = expand_grid_didi_index(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = values.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the arrays are created uninitialised. They are only handed to
    // Python after `didi_index_batch_inner_into` succeeds, and that function
    // copies a full `rows * cols` result into every one of them. On error they
    // are dropped without being exposed.
    let short_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let long_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let crossover_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let crossunder_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    // SAFETY: the arrays were just created here, so no other view of them exists.
    let short_slice = unsafe { short_arr.as_slice_mut()? };
    let long_slice = unsafe { long_arr.as_slice_mut()? };
    let crossover_slice = unsafe { crossover_arr.as_slice_mut()? };
    let crossunder_slice = unsafe { crossunder_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let resolved = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            didi_index_batch_inner_into(
                values,
                &sweep,
                resolved.to_non_batch(),
                short_slice,
                long_slice,
                crossover_slice,
                crossunder_slice,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    for (key, arr) in [
        ("short", &short_arr),
        ("long", &long_arr),
        ("crossover", &crossover_arr),
        ("crossunder", &crossunder_arr),
    ] {
        dict.set_item(key, arr.reshape((rows, cols))?)?;
    }

    let short_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.short_length.unwrap_or(DEFAULT_SHORT_LENGTH) as u64)
        .collect();
    let medium_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH) as u64)
        .collect();
    let long_lengths: Vec<u64> = combos
        .iter()
        .map(|p| p.long_length.unwrap_or(DEFAULT_LONG_LENGTH) as u64)
        .collect();
    for (key, lengths) in [
        ("short_lengths", short_lengths),
        ("medium_lengths", medium_lengths),
        ("long_lengths", long_lengths),
    ] {
        dict.set_item(key, lengths.into_pyarray(py))?;
    }

    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
1190
/// Adds the DIDI index Python API to `module`.
///
/// Registers the `didi_index` and `didi_index_batch` functions and the
/// `DidiIndexStream` class.
#[cfg(feature = "python")]
pub fn register_didi_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(didi_index_py, module)?;
    module.add_function(single)?;
    let batch = wrap_pyfunction!(didi_index_batch_py, module)?;
    module.add_function(batch)?;
    module.add_class::<DidiIndexStreamPy>()
}
1198
/// JavaScript entry point for a single DIDI index computation.
///
/// Returns a plain JS object whose `short`, `long`, `crossover` and
/// `crossunder` properties are `Float64Array` copies of the computed series.
/// Computation errors are returned as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "didi_index_js")]
pub fn didi_index_js(
    data: &[f64],
    short_length: usize,
    medium_length: usize,
    long_length: usize,
) -> Result<JsValue, JsValue> {
    let params = DidiIndexParams {
        short_length: Some(short_length),
        medium_length: Some(medium_length),
        long_length: Some(long_length),
    };
    let out = didi_index(&DidiIndexInput::from_slice(data, params))
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let result = js_sys::Object::new();
    let series: [(&str, &[f64]); 4] = [
        ("short", &out.short),
        ("long", &out.long),
        ("crossover", &out.crossover),
        ("crossunder", &out.crossunder),
    ];
    for (key, values) in series {
        let typed = js_sys::Float64Array::new_with_length(values.len() as u32);
        typed.copy_from(values);
        js_sys::Reflect::set(&result, &JsValue::from_str(key), &typed)?;
    }

    Ok(result.into())
}
1236
/// Allocates a zero-filled buffer of `len` `f64` values in WASM linear memory
/// and returns a pointer to its first element.
///
/// The buffer must be released with `didi_index_free(ptr, len)`, passing the
/// same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_alloc(len: usize) -> *mut f64 {
    // A boxed slice has exactly `len` elements of capacity, so the matching
    // `didi_index_free` can rebuild the allocation with the identical layout.
    // Zero-filling means the buffer never holds uninitialised memory, which
    // Rust must not read as `f64` values.
    let buf: Box<[f64]> = vec![0.0; len].into_boxed_slice();
    Box::into_raw(buf) as *mut f64
}

/// Releases a buffer obtained from `didi_index_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must equal the value passed to
/// `didi_index_alloc`, and each buffer may be freed only once.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_free(ptr: *mut f64, len: usize) {
    if !ptr.is_null() {
        // SAFETY: per the contract above, `ptr` came from `Box::into_raw` on a
        // `Box<[f64]>` of exactly `len` initialised elements and has not been
        // freed yet. Rebuilding the same fat pointer restores that box.
        unsafe {
            drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len)));
        }
    }
}
1255
/// Computes the DIDI index from `len` values at `data_ptr` and writes the four
/// output series into caller-provided buffers of `len` values each.
///
/// Output buffers may overlap the input. When any output shares memory with
/// the input, results are computed into temporaries first and copied out
/// afterwards. NOTE(review): output buffers overlapping *each other* are still
/// not supported.
///
/// Returns an error string for null pointers or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_into(
    data_ptr: *const f64,
    short_ptr: *mut f64,
    long_ptr: *mut f64,
    crossover_ptr: *mut f64,
    crossunder_ptr: *mut f64,
    len: usize,
    short_length: usize,
    medium_length: usize,
    long_length: usize,
) -> Result<(), JsValue> {
    if data_ptr.is_null()
        || short_ptr.is_null()
        || long_ptr.is_null()
        || crossover_ptr.is_null()
        || crossunder_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Every buffer spans `len` f64 values. Compare byte ranges rather than
    // start pointers: an output that begins a few elements into the input
    // still overlaps it, and pointer equality alone would miss that.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let data_start = data_ptr as usize;
    let data_end = data_start.saturating_add(span);
    let overlaps_input = |out: *mut f64| {
        let out_start = out as usize;
        out_start < data_end && data_start < out_start.saturating_add(span)
    };
    let aliased = overlaps_input(short_ptr)
        || overlaps_input(long_ptr)
        || overlaps_input(crossover_ptr)
        || overlaps_input(crossunder_ptr);

    // SAFETY: `data_ptr` is non-null (checked above), and the caller guarantees
    // it points to `len` initialised f64 values.
    let data = unsafe { std::slice::from_raw_parts(data_ptr, len) };
    let input = DidiIndexInput::from_slice(
        data,
        DidiIndexParams {
            short_length: Some(short_length),
            medium_length: Some(medium_length),
            long_length: Some(long_length),
        },
    );

    if aliased {
        let mut short_tmp = vec![0.0; len];
        let mut long_tmp = vec![0.0; len];
        let mut crossover_tmp = vec![0.0; len];
        let mut crossunder_tmp = vec![0.0; len];
        didi_index_into_slices(
            &mut short_tmp,
            &mut long_tmp,
            &mut crossover_tmp,
            &mut crossunder_tmp,
            &input,
            Kernel::Auto,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: every output pointer is non-null and writable for `len`
        // values. The input slice is no longer read, so overwriting memory it
        // shares with an output is sound.
        unsafe {
            std::slice::from_raw_parts_mut(short_ptr, len).copy_from_slice(&short_tmp);
            std::slice::from_raw_parts_mut(long_ptr, len).copy_from_slice(&long_tmp);
            std::slice::from_raw_parts_mut(crossover_ptr, len).copy_from_slice(&crossover_tmp);
            std::slice::from_raw_parts_mut(crossunder_ptr, len).copy_from_slice(&crossunder_tmp);
        }
    } else {
        // SAFETY: every output pointer is non-null, writable for `len` values,
        // and (per the overlap test above) disjoint from the input.
        let outcome = unsafe {
            didi_index_into_slices(
                std::slice::from_raw_parts_mut(short_ptr, len),
                std::slice::from_raw_parts_mut(long_ptr, len),
                std::slice::from_raw_parts_mut(crossover_ptr, len),
                std::slice::from_raw_parts_mut(crossunder_ptr, len),
                &input,
                Kernel::Auto,
            )
        };
        outcome.map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1324
/// Sweep configuration accepted by `didi_index_batch_js`, deserialised from a
/// JS object with serde.
///
/// Each range is a `(start, end, step)` triple.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DidiIndexBatchConfig {
    /// Required sweep for the short length.
    pub short_length_range: (usize, usize, usize),
    /// Optional sweep for the medium length. When absent, `didi_index_batch_js`
    /// uses `(DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0)`, presumably a
    /// single fixed value.
    pub medium_length_range: Option<(usize, usize, usize)>,
    /// Optional sweep for the long length. When absent, `didi_index_batch_js`
    /// uses `(DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0)`.
    pub long_length_range: Option<(usize, usize, usize)>,
}
1332
/// Result object returned to JavaScript by `didi_index_batch_js`.
///
/// The four series fields are flattened row-major `rows * cols` matrices, with
/// one row per entry of `combos`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct DidiIndexBatchJsOutput {
    pub short: Vec<f64>,
    pub long: Vec<f64>,
    pub crossover: Vec<f64>,
    pub crossunder: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<DidiIndexParams>,
    /// Per-row short length, with unset values replaced by `DEFAULT_SHORT_LENGTH`.
    pub short_lengths: Vec<usize>,
    /// Per-row medium length, with unset values replaced by `DEFAULT_MEDIUM_LENGTH`.
    pub medium_lengths: Vec<usize>,
    /// Per-row long length, with unset values replaced by `DEFAULT_LONG_LENGTH`.
    pub long_lengths: Vec<usize>,
    /// Number of parameter rows.
    pub rows: usize,
    /// Number of values per row.
    pub cols: usize,
}
1347
/// JavaScript entry point for a DIDI index parameter sweep.
///
/// `config` must deserialise into a `DidiIndexBatchConfig`. A missing medium
/// or long range falls back to the corresponding default length with step 0.
///
/// The result is serialised from a `DidiIndexBatchJsOutput`. Invalid
/// configuration, computation errors and serialisation errors are all returned
/// as JS strings.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "didi_index_batch_js")]
pub fn didi_index_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let DidiIndexBatchConfig {
        short_length_range,
        medium_length_range,
        long_length_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = DidiIndexBatchRange {
        short_length: short_length_range,
        medium_length: medium_length_range
            .unwrap_or((DEFAULT_MEDIUM_LENGTH, DEFAULT_MEDIUM_LENGTH, 0)),
        long_length: long_length_range.unwrap_or((DEFAULT_LONG_LENGTH, DEFAULT_LONG_LENGTH, 0)),
    };

    let kernel = detect_best_batch_kernel().to_non_batch();
    let out = didi_index_batch_inner(data, &sweep, kernel, false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let short_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.short_length.unwrap_or(DEFAULT_SHORT_LENGTH))
        .collect();
    let medium_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.medium_length.unwrap_or(DEFAULT_MEDIUM_LENGTH))
        .collect();
    let long_lengths: Vec<usize> = out
        .combos
        .iter()
        .map(|p| p.long_length.unwrap_or(DEFAULT_LONG_LENGTH))
        .collect();

    let payload = DidiIndexBatchJsOutput {
        short: out.short,
        long: out.long,
        crossover: out.crossover,
        crossunder: out.crossunder,
        combos: out.combos,
        short_lengths,
        medium_lengths,
        long_lengths,
        rows: out.rows,
        cols: out.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1399
/// Runs a DIDI index parameter sweep over `len` values at `data_ptr`.
///
/// The four result matrices are written into caller-provided buffers. Each
/// buffer must hold `rows * len` values, where `rows` is the number of
/// parameter combinations. That count is also the return value.
///
/// Output buffers may overlap the input. In that case the input is copied into
/// an owned buffer first, so no shared and mutable views of the same memory
/// exist at once. NOTE(review): output buffers overlapping *each other* are
/// still not supported.
///
/// Returns an error string for null pointers, an invalid sweep, size overflow,
/// or any computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn didi_index_batch_into(
    data_ptr: *const f64,
    short_ptr: *mut f64,
    long_ptr: *mut f64,
    crossover_ptr: *mut f64,
    crossunder_ptr: *mut f64,
    len: usize,
    short_start: usize,
    short_end: usize,
    short_step: usize,
    medium_start: usize,
    medium_end: usize,
    medium_step: usize,
    long_start: usize,
    long_end: usize,
    long_step: usize,
) -> Result<usize, JsValue> {
    if data_ptr.is_null()
        || short_ptr.is_null()
        || long_ptr.is_null()
        || crossover_ptr.is_null()
        || crossunder_ptr.is_null()
    {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = DidiIndexBatchRange {
        short_length: (short_start, short_end, short_step),
        medium_length: (medium_start, medium_end, medium_step),
        long_length: (long_start, long_end, long_step),
    };
    let combos = expand_grid_didi_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte ranges: the input covers `len` values, each output covers `total`.
    // Saturating arithmetic keeps the comparison well-defined for huge sizes.
    let elem = std::mem::size_of::<f64>();
    let data_start = data_ptr as usize;
    let data_end = data_start.saturating_add(len.saturating_mul(elem));
    let out_bytes = total.saturating_mul(elem);
    let overlaps_input = |out: *mut f64| {
        let out_start = out as usize;
        out_start < data_end && data_start < out_start.saturating_add(out_bytes)
    };
    let aliased = overlaps_input(short_ptr)
        || overlaps_input(long_ptr)
        || overlaps_input(crossover_ptr)
        || overlaps_input(crossunder_ptr);

    // SAFETY: `data_ptr` is non-null (checked above), and the caller guarantees
    // it points to `len` initialised f64 values.
    let borrowed = unsafe { std::slice::from_raw_parts(data_ptr, len) };
    // When the input shares memory with an output, work from an owned copy so
    // that `&[f64]` and `&mut [f64]` over the same bytes never coexist.
    let owned;
    let data: &[f64] = if aliased {
        owned = borrowed.to_vec();
        &owned
    } else {
        borrowed
    };

    // SAFETY: every output pointer is non-null and, per the caller contract,
    // writable for `total` values. `data` is disjoint from them: it is either
    // the original input (checked non-overlapping) or an owned copy.
    let outcome = unsafe {
        didi_index_batch_inner_into(
            data,
            &sweep,
            detect_best_batch_kernel().to_non_batch(),
            std::slice::from_raw_parts_mut(short_ptr, total),
            std::slice::from_raw_parts_mut(long_ptr, total),
            std::slice::from_raw_parts_mut(crossover_ptr, total),
            std::slice::from_raw_parts_mut(crossunder_ptr, total),
        )
    };
    outcome.map_err(|e| JsValue::from_str(&e.to_string()))?;
    Ok(rows)
}
1454
// Unit tests for the DIDI index: manual ratio checks, cross detection,
// stream/batch agreement, the batch API (including the `_into` variant), and
// parameter validation.
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-12
    }

    fn approx_eq_or_nan(a: f64, b: f64) -> bool {
        (a.is_nan() && b.is_nan()) || approx_eq(a, b)
    }

    #[test]
    fn didi_index_matches_manual_ratios() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = DidiIndexInput::from_slice(
            &data,
            DidiIndexParams {
                short_length: Some(2),
                medium_length: Some(3),
                long_length: Some(4),
            },
        );
        let out = didi_index(&input).unwrap();

        assert!(out.short[..3].iter().all(|v| v.is_nan()));
        assert!(out.long[..3].iter().all(|v| v.is_nan()));
        assert!(approx_eq(out.short[3], 3.5 / 3.0));
        assert!(approx_eq(out.long[3], 2.5 / 3.0));
        assert!(approx_eq(out.short[4], 4.5 / 4.0));
        assert!(approx_eq(out.long[4], 3.5 / 4.0));
        assert!(approx_eq(out.crossover[3], 0.0));
        assert!(approx_eq(out.crossunder[3], 0.0));
    }

    #[test]
    fn didi_index_detects_crossover_and_crossunder() {
        let cross_up = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let up_input = DidiIndexInput::from_slice(
            &cross_up,
            DidiIndexParams {
                short_length: Some(2),
                medium_length: Some(3),
                long_length: Some(4),
            },
        );
        let up_out = didi_index(&up_input).unwrap();
        assert!(approx_eq(up_out.crossover[6], 1.0));
        assert!(approx_eq(up_out.crossunder[6], 0.0));

        let cross_down = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let down_input = DidiIndexInput::from_slice(
            &cross_down,
            DidiIndexParams {
                short_length: Some(2),
                medium_length: Some(3),
                long_length: Some(4),
            },
        );
        let down_out = didi_index(&down_input).unwrap();
        assert!(approx_eq(down_out.crossunder[6], 1.0));
        assert!(approx_eq(down_out.crossover[6], 0.0));
    }

    #[test]
    fn didi_index_stream_matches_batch_with_reset() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, f64::NAN, 3.0, 4.0, 5.0, 6.0];
        let params = DidiIndexParams {
            short_length: Some(2),
            medium_length: Some(3),
            long_length: Some(4),
        };
        let input = DidiIndexInput::from_slice(&data, params.clone());
        let batch = didi_index(&input).unwrap();
        let mut stream = DidiIndexStream::try_new(params).unwrap();

        let mut short = Vec::new();
        let mut long = Vec::new();
        let mut crossover = Vec::new();
        let mut crossunder = Vec::new();
        for &value in &data {
            match stream.update(value) {
                Some((s, l, co, cu)) => {
                    short.push(s);
                    long.push(l);
                    crossover.push(co);
                    crossunder.push(cu);
                }
                None => {
                    short.push(f64::NAN);
                    long.push(f64::NAN);
                    crossover.push(f64::NAN);
                    crossunder.push(f64::NAN);
                }
            }
        }

        assert_eq!(stream.get_warmup_period(), 3);
        for i in 0..data.len() {
            assert!(approx_eq_or_nan(batch.short[i], short[i]));
            assert!(approx_eq_or_nan(batch.long[i], long[i]));
            assert!(approx_eq_or_nan(batch.crossover[i], crossover[i]));
            assert!(approx_eq_or_nan(batch.crossunder[i], crossunder[i]));
        }
        assert!(batch.short[5].is_nan());
        assert!(batch.short[8].is_nan());
        assert!(batch.short[9].is_finite());
    }

    #[test]
    fn didi_index_batch_default_row_matches_single() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let batch = didi_index_batch_slice(
            &data,
            &DidiIndexBatchRange {
                short_length: (2, 2, 0),
                medium_length: (3, 3, 0),
                long_length: (4, 4, 0),
            },
        )
        .unwrap();
        let single = didi_index(&DidiIndexInput::from_slice(
            &data,
            DidiIndexParams {
                short_length: Some(2),
                medium_length: Some(3),
                long_length: Some(4),
            },
        ))
        .unwrap();

        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, data.len());
        assert_eq!(batch.short.len(), data.len());
        for i in 0..data.len() {
            assert!(approx_eq_or_nan(batch.short[i], single.short[i]));
            assert!(approx_eq_or_nan(batch.long[i], single.long[i]));
        }
    }

    // The `_into` variant must produce the same matrices as the allocating
    // batch API when given correctly sized buffers.
    #[test]
    fn didi_index_batch_inner_into_matches_batch_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let sweep = DidiIndexBatchRange {
            short_length: (2, 2, 0),
            medium_length: (3, 3, 0),
            long_length: (4, 4, 0),
        };
        let expected = didi_index_batch_slice(&data, &sweep).unwrap();
        let total = expected.rows * expected.cols;
        let mut short = vec![0.0; total];
        let mut long = vec![0.0; total];
        let mut crossover = vec![0.0; total];
        let mut crossunder = vec![0.0; total];

        let combos = didi_index_batch_inner_into(
            &data,
            &sweep,
            detect_best_batch_kernel().to_non_batch(),
            &mut short,
            &mut long,
            &mut crossover,
            &mut crossunder,
        )
        .unwrap();

        assert_eq!(combos.len(), expected.rows);
        for i in 0..total {
            assert!(approx_eq_or_nan(short[i], expected.short[i]));
            assert!(approx_eq_or_nan(long[i], expected.long[i]));
            assert!(approx_eq_or_nan(crossover[i], expected.crossover[i]));
            assert!(approx_eq_or_nan(crossunder[i], expected.crossunder[i]));
        }
    }

    // A wrongly sized destination must be reported rather than panicking in
    // `copy_from_slice`.
    #[test]
    fn didi_index_batch_inner_into_rejects_wrong_output_length() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let sweep = DidiIndexBatchRange {
            short_length: (2, 2, 0),
            medium_length: (3, 3, 0),
            long_length: (4, 4, 0),
        };
        let n = data.len();
        let mut short = vec![0.0; n];
        let mut long = vec![0.0; n - 1];
        let mut crossover = vec![0.0; n];
        let mut crossunder = vec![0.0; n];

        let err = didi_index_batch_inner_into(
            &data,
            &sweep,
            detect_best_batch_kernel().to_non_batch(),
            &mut short,
            &mut long,
            &mut crossover,
            &mut crossunder,
        )
        .unwrap_err();

        assert!(matches!(
            err,
            DidiIndexError::OutputLengthMismatch {
                expected: 9,
                long_got: 8,
                ..
            }
        ));
    }

    #[test]
    fn didi_index_rejects_invalid_lengths() {
        let data = [1.0, 2.0, 3.0];
        let err = didi_index(&DidiIndexInput::from_slice(
            &data,
            DidiIndexParams {
                short_length: Some(0),
                medium_length: Some(2),
                long_length: Some(3),
            },
        ))
        .unwrap_err();
        assert!(matches!(err, DidiIndexError::InvalidShortLength { .. }));
    }
}