Skip to main content

vector_ta/indicators/
squeeze_index.rs

1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21    make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::mem::ManuallyDrop;
29use thiserror::Error;
30
/// Default divisor controlling how fast the max/min envelopes decay toward the price (`conv`).
const DEFAULT_CONV: f64 = 50.0;
/// Default number of samples in the correlation window (`length`).
const DEFAULT_LENGTH: usize = 20;
/// Absolute tolerance used when comparing floating-point parameters (range expansion and
/// batch-row lookup).
const FLOAT_TOL: f64 = 1e-12;
34
35impl<'a> AsRef<[f64]> for SqueezeIndexInput<'a> {
36    #[inline(always)]
37    fn as_ref(&self) -> &[f64] {
38        match &self.data {
39            SqueezeIndexData::Slice(slice) => slice,
40            SqueezeIndexData::Candles { candles, source } => source_type(candles, source),
41        }
42    }
43}
44
45#[derive(Debug, Clone)]
46pub enum SqueezeIndexData<'a> {
47    Candles {
48        candles: &'a Candles,
49        source: &'a str,
50    },
51    Slice(&'a [f64]),
52}
53
/// Result of a single squeeze-index run: one value per input sample (NaN where undefined).
#[derive(Debug, Clone)]
pub struct SqueezeIndexOutput {
    pub values: Vec<f64>,
}
58
59#[derive(Debug, Clone, PartialEq)]
60#[cfg_attr(
61    all(target_arch = "wasm32", feature = "wasm"),
62    derive(Serialize, Deserialize)
63)]
64pub struct SqueezeIndexParams {
65    pub conv: Option<f64>,
66    pub length: Option<usize>,
67}
68
69impl Default for SqueezeIndexParams {
70    fn default() -> Self {
71        Self {
72            conv: Some(DEFAULT_CONV),
73            length: Some(DEFAULT_LENGTH),
74        }
75    }
76}
77
78#[derive(Debug, Clone)]
79pub struct SqueezeIndexInput<'a> {
80    pub data: SqueezeIndexData<'a>,
81    pub params: SqueezeIndexParams,
82}
83
84impl<'a> SqueezeIndexInput<'a> {
85    #[inline]
86    pub fn from_candles(candles: &'a Candles, source: &'a str, params: SqueezeIndexParams) -> Self {
87        Self {
88            data: SqueezeIndexData::Candles { candles, source },
89            params,
90        }
91    }
92
93    #[inline]
94    pub fn from_slice(slice: &'a [f64], params: SqueezeIndexParams) -> Self {
95        Self {
96            data: SqueezeIndexData::Slice(slice),
97            params,
98        }
99    }
100
101    #[inline]
102    pub fn with_default_candles(candles: &'a Candles) -> Self {
103        Self::from_candles(candles, "close", SqueezeIndexParams::default())
104    }
105
106    #[inline]
107    pub fn get_conv(&self) -> f64 {
108        self.params.conv.unwrap_or(DEFAULT_CONV)
109    }
110
111    #[inline]
112    pub fn get_length(&self) -> usize {
113        self.params.length.unwrap_or(DEFAULT_LENGTH)
114    }
115}
116
117#[derive(Copy, Clone, Debug)]
118pub struct SqueezeIndexBuilder {
119    conv: Option<f64>,
120    length: Option<usize>,
121    kernel: Kernel,
122}
123
124impl Default for SqueezeIndexBuilder {
125    fn default() -> Self {
126        Self {
127            conv: None,
128            length: None,
129            kernel: Kernel::Auto,
130        }
131    }
132}
133
134impl SqueezeIndexBuilder {
135    #[inline]
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    #[inline]
141    pub fn conv(mut self, conv: f64) -> Self {
142        self.conv = Some(conv);
143        self
144    }
145
146    #[inline]
147    pub fn length(mut self, length: usize) -> Self {
148        self.length = Some(length);
149        self
150    }
151
152    #[inline]
153    pub fn kernel(mut self, kernel: Kernel) -> Self {
154        self.kernel = kernel;
155        self
156    }
157
158    #[inline]
159    pub fn apply(
160        self,
161        candles: &Candles,
162        source: &str,
163    ) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
164        let input = SqueezeIndexInput::from_candles(
165            candles,
166            source,
167            SqueezeIndexParams {
168                conv: self.conv,
169                length: self.length,
170            },
171        );
172        squeeze_index_with_kernel(&input, self.kernel)
173    }
174
175    #[inline]
176    pub fn apply_slice(self, data: &[f64]) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
177        let input = SqueezeIndexInput::from_slice(
178            data,
179            SqueezeIndexParams {
180                conv: self.conv,
181                length: self.length,
182            },
183        );
184        squeeze_index_with_kernel(&input, self.kernel)
185    }
186
187    #[inline]
188    pub fn into_stream(self) -> Result<SqueezeIndexStream, SqueezeIndexError> {
189        SqueezeIndexStream::try_new(SqueezeIndexParams {
190            conv: self.conv,
191            length: self.length,
192        })
193    }
194}
195
196#[derive(Debug, Error)]
197pub enum SqueezeIndexError {
198    #[error("squeeze_index: Input data slice is empty.")]
199    EmptyInputData,
200    #[error("squeeze_index: All values are NaN.")]
201    AllValuesNaN,
202    #[error("squeeze_index: Invalid conv: {conv}")]
203    InvalidConv { conv: f64 },
204    #[error("squeeze_index: Invalid length: length = {length}, data length = {data_len}")]
205    InvalidLength { length: usize, data_len: usize },
206    #[error("squeeze_index: Not enough valid data: needed = {needed}, valid = {valid}")]
207    NotEnoughValidData { needed: usize, valid: usize },
208    #[error("squeeze_index: Output length mismatch: expected = {expected}, got = {got}")]
209    OutputLengthMismatch { expected: usize, got: usize },
210    #[error("squeeze_index: Invalid range: start={start}, end={end}, step={step}")]
211    InvalidRange {
212        start: String,
213        end: String,
214        step: String,
215    },
216    #[error("squeeze_index: Invalid kernel for batch: {0:?}")]
217    InvalidKernelForBatch(Kernel),
218}
219
/// Parameters after defaults have been applied and validated by `resolve_params`.
#[derive(Debug, Clone, Copy)]
struct ResolvedParams {
    /// Envelope convergence divisor (> 1, finite).
    conv: f64,
    /// Correlation window length (>= 1).
    length: usize,
}
225
226#[derive(Debug, Clone)]
227pub struct SqueezeIndexStream {
228    params: ResolvedParams,
229    max_state: f64,
230    min_state: f64,
231    ring_vals: Vec<f64>,
232    ring_valid: Vec<u8>,
233    head: usize,
234    filled: usize,
235    valid_count: usize,
236    sum_x: f64,
237    sum_x2: f64,
238    weighted: f64,
239}
240
241impl SqueezeIndexStream {
242    pub fn try_new(params: SqueezeIndexParams) -> Result<Self, SqueezeIndexError> {
243        let params = resolve_params(&params, 0)?;
244        Ok(Self::new_resolved(params))
245    }
246
247    #[inline]
248    fn new_resolved(params: ResolvedParams) -> Self {
249        Self {
250            params,
251            max_state: 0.0,
252            min_state: 0.0,
253            ring_vals: vec![0.0; params.length],
254            ring_valid: vec![0u8; params.length],
255            head: 0,
256            filled: 0,
257            valid_count: 0,
258            sum_x: 0.0,
259            sum_x2: 0.0,
260            weighted: 0.0,
261        }
262    }
263
264    #[inline]
265    fn reset(&mut self) {
266        *self = Self::new_resolved(self.params);
267    }
268
269    #[inline]
270    pub fn get_warmup_period(&self) -> usize {
271        self.params.length.saturating_sub(1)
272    }
273
274    #[inline]
275    pub fn update(&mut self, value: f64) -> Option<f64> {
276        if !value.is_finite() {
277            self.reset();
278            return None;
279        }
280
281        let conv = self.params.conv;
282        let max_next = value.max(self.max_state - (self.max_state - value) / conv);
283        let min_next = value.min(self.min_state + (value - self.min_state) / conv);
284        self.max_state = max_next;
285        self.min_state = min_next;
286
287        let spread = max_next - min_next;
288        let diff = if spread > 0.0 { spread.ln() } else { f64::NAN };
289        self.push_diff(diff)
290    }
291
292    #[inline]
293    fn push_diff(&mut self, diff: f64) -> Option<f64> {
294        let is_valid = diff.is_finite();
295        let value = if is_valid { diff } else { 0.0 };
296        let n = self.params.length;
297
298        if self.filled < n {
299            let pos = self.filled;
300            self.ring_vals[pos] = value;
301            self.ring_valid[pos] = if is_valid { 1 } else { 0 };
302            self.sum_x += value;
303            self.sum_x2 += value * value;
304            self.weighted += pos as f64 * value;
305            if is_valid {
306                self.valid_count += 1;
307            }
308            self.filled += 1;
309            if self.filled < n {
310                return None;
311            }
312        } else {
313            let old_value = self.ring_vals[self.head];
314            let old_valid = self.ring_valid[self.head] as usize;
315            let old_sum = self.sum_x;
316
317            self.weighted =
318                self.weighted - old_sum + old_value + (n.saturating_sub(1) as f64) * value;
319            self.sum_x = old_sum - old_value + value;
320            self.sum_x2 = self.sum_x2 - old_value * old_value + value * value;
321            self.valid_count = self.valid_count + if is_valid { 1 } else { 0 } - old_valid;
322
323            self.ring_vals[self.head] = value;
324            self.ring_valid[self.head] = if is_valid { 1 } else { 0 };
325            self.head += 1;
326            if self.head == n {
327                self.head = 0;
328            }
329        }
330
331        if self.valid_count != n {
332            return Some(f64::NAN);
333        }
334
335        Some(psi_from_corr(self.sum_x, self.sum_x2, self.weighted, n))
336    }
337}
338
339#[inline]
340pub fn squeeze_index(input: &SqueezeIndexInput) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
341    squeeze_index_with_kernel(input, Kernel::Auto)
342}
343
/// A sample is usable when it is finite (NaN and ±inf are rejected).
#[inline(always)]
fn valid_value(value: f64) -> bool {
    f64::is_finite(value)
}

/// Number of usable samples in `data`.
#[inline(always)]
fn count_valid_values(data: &[f64]) -> usize {
    data.iter().copied().filter(|&v| valid_value(v)).count()
}

/// Index of the first usable sample, or `data.len()` when there is none.
#[inline(always)]
fn first_valid_value(data: &[f64]) -> usize {
    data.iter()
        .position(|&v| valid_value(v))
        .unwrap_or(data.len())
}
365
/// Maps the Pearson correlation between the window samples and their positions
/// `0..length` onto `[0, 100]` as `50 - 50 * corr`.
///
/// `sum_x`, `sum_x2` and `weighted` are the window's sum, sum of squares and
/// position-weighted sum. Returns NaN when either variance term is zero or the
/// denominator is not finite.
#[inline(always)]
fn psi_from_corr(sum_x: f64, sum_x2: f64, weighted: f64, length: usize) -> f64 {
    let n = length as f64;
    // Closed forms for sum(i) and sum(i^2) over i = 0..n-1.
    let sum_pos = n * (n - 1.0) * 0.5;
    let sum_pos2 = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;

    let var_x = n * sum_x2 - sum_x * sum_x;
    let var_pos = n * sum_pos2 - sum_pos * sum_pos;
    let denom = var_x * var_pos;

    if denom.is_finite() && denom > 0.0 {
        let corr = (n * weighted - sum_x * sum_pos) / denom.sqrt();
        50.0 - 50.0 * corr
    } else {
        f64::NAN
    }
}
380
381#[inline]
382fn resolve_params(
383    params: &SqueezeIndexParams,
384    data_len: usize,
385) -> Result<ResolvedParams, SqueezeIndexError> {
386    let conv = params.conv.unwrap_or(DEFAULT_CONV);
387    if !conv.is_finite() || conv <= 1.0 {
388        return Err(SqueezeIndexError::InvalidConv { conv });
389    }
390    let length = params.length.unwrap_or(DEFAULT_LENGTH);
391    if length == 0 || (data_len > 0 && length > data_len) {
392        return Err(SqueezeIndexError::InvalidLength { length, data_len });
393    }
394    Ok(ResolvedParams { conv, length })
395}
396
397#[inline]
398fn squeeze_index_prepare<'a>(
399    input: &'a SqueezeIndexInput,
400    kernel: Kernel,
401) -> Result<(&'a [f64], ResolvedParams, Kernel), SqueezeIndexError> {
402    let data = input.as_ref();
403    let len = data.len();
404    if len == 0 {
405        return Err(SqueezeIndexError::EmptyInputData);
406    }
407
408    let first = first_valid_value(data);
409    if first >= len {
410        return Err(SqueezeIndexError::AllValuesNaN);
411    }
412
413    let params = resolve_params(&input.params, len)?;
414    let valid = count_valid_values(data);
415    if valid < params.length {
416        return Err(SqueezeIndexError::NotEnoughValidData {
417            needed: params.length,
418            valid,
419        });
420    }
421
422    let chosen = match kernel {
423        Kernel::Auto => detect_best_kernel(),
424        other => other.to_non_batch(),
425    };
426
427    Ok((data, params, chosen))
428}
429
430#[inline(always)]
431fn squeeze_index_row_from_slice(data: &[f64], params: ResolvedParams, out: &mut [f64]) {
432    out.fill(f64::NAN);
433    let mut stream = SqueezeIndexStream::new_resolved(params);
434    for (i, &value) in data.iter().enumerate() {
435        if let Some(out_value) = stream.update(value) {
436            out[i] = out_value;
437        }
438    }
439}
440
441#[inline]
442pub fn squeeze_index_with_kernel(
443    input: &SqueezeIndexInput,
444    kernel: Kernel,
445) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
446    let (data, params, _chosen) = squeeze_index_prepare(input, kernel)?;
447    let mut values = alloc_with_nan_prefix(data.len(), params.length.saturating_sub(1));
448    squeeze_index_row_from_slice(data, params, &mut values);
449    Ok(SqueezeIndexOutput { values })
450}
451
452#[inline]
453pub fn squeeze_index_into_slice(
454    dst: &mut [f64],
455    input: &SqueezeIndexInput,
456    kernel: Kernel,
457) -> Result<(), SqueezeIndexError> {
458    let (data, params, _chosen) = squeeze_index_prepare(input, kernel)?;
459    if dst.len() != data.len() {
460        return Err(SqueezeIndexError::OutputLengthMismatch {
461            expected: data.len(),
462            got: dst.len(),
463        });
464    }
465    squeeze_index_row_from_slice(data, params, dst);
466    Ok(())
467}
468
469#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
470#[inline]
471pub fn squeeze_index_into(
472    input: &SqueezeIndexInput,
473    out: &mut [f64],
474) -> Result<(), SqueezeIndexError> {
475    squeeze_index_into_slice(out, input, Kernel::Auto)
476}
477
478#[derive(Clone, Debug)]
479pub struct SqueezeIndexBatchRange {
480    pub conv: (f64, f64, f64),
481    pub length: (usize, usize, usize),
482}
483
484impl Default for SqueezeIndexBatchRange {
485    fn default() -> Self {
486        Self {
487            conv: (DEFAULT_CONV, DEFAULT_CONV, 0.0),
488            length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
489        }
490    }
491}
492
493#[derive(Clone, Debug, Default)]
494pub struct SqueezeIndexBatchBuilder {
495    range: SqueezeIndexBatchRange,
496    kernel: Kernel,
497}
498
499impl SqueezeIndexBatchBuilder {
500    #[inline]
501    pub fn new() -> Self {
502        Self::default()
503    }
504
505    #[inline]
506    pub fn kernel(mut self, kernel: Kernel) -> Self {
507        self.kernel = kernel;
508        self
509    }
510
511    #[inline]
512    pub fn conv_range(mut self, start: f64, end: f64, step: f64) -> Self {
513        self.range.conv = (start, end, step);
514        self
515    }
516
517    #[inline]
518    pub fn conv_static(mut self, conv: f64) -> Self {
519        self.range.conv = (conv, conv, 0.0);
520        self
521    }
522
523    #[inline]
524    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
525        self.range.length = (start, end, step);
526        self
527    }
528
529    #[inline]
530    pub fn length_static(mut self, length: usize) -> Self {
531        self.range.length = (length, length, 0);
532        self
533    }
534
535    #[inline]
536    pub fn apply_slice(self, data: &[f64]) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
537        squeeze_index_batch_with_kernel(data, &self.range, self.kernel)
538    }
539
540    #[inline]
541    pub fn apply_candles(
542        self,
543        candles: &Candles,
544        source: &str,
545    ) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
546        self.apply_slice(source_type(candles, source))
547    }
548
549    #[inline]
550    pub fn with_default_candles(
551        candles: &Candles,
552    ) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
553        SqueezeIndexBatchBuilder::new()
554            .kernel(Kernel::Auto)
555            .apply_candles(candles, "close")
556    }
557}
558
559#[derive(Clone, Debug)]
560pub struct SqueezeIndexBatchOutput {
561    pub values: Vec<f64>,
562    pub combos: Vec<SqueezeIndexParams>,
563    pub rows: usize,
564    pub cols: usize,
565}
566
567impl SqueezeIndexBatchOutput {
568    pub fn row_for_params(&self, params: &SqueezeIndexParams) -> Option<usize> {
569        self.combos.iter().position(|combo| {
570            (combo.conv.unwrap_or(DEFAULT_CONV) - params.conv.unwrap_or(DEFAULT_CONV)).abs()
571                <= FLOAT_TOL
572                && combo.length.unwrap_or(DEFAULT_LENGTH) == params.length.unwrap_or(DEFAULT_LENGTH)
573        })
574    }
575
576    pub fn values_for(&self, params: &SqueezeIndexParams) -> Option<&[f64]> {
577        self.row_for_params(params).and_then(|row| {
578            row.checked_mul(self.cols)
579                .and_then(|start| self.values.get(start..start + self.cols))
580        })
581    }
582}
583
584fn expand_axis_usize(
585    start: usize,
586    end: usize,
587    step: usize,
588) -> Result<Vec<usize>, SqueezeIndexError> {
589    if start == end {
590        return Ok(vec![start]);
591    }
592    if step == 0 {
593        return Err(SqueezeIndexError::InvalidRange {
594            start: start.to_string(),
595            end: end.to_string(),
596            step: step.to_string(),
597        });
598    }
599
600    let mut out = Vec::new();
601    if start < end {
602        let mut x = start;
603        while x <= end {
604            out.push(x);
605            let next = x.saturating_add(step);
606            if next == x {
607                break;
608            }
609            x = next;
610        }
611    } else {
612        let mut x = start;
613        while x >= end {
614            out.push(x);
615            let next = x.saturating_sub(step);
616            if next == x {
617                break;
618            }
619            x = next;
620        }
621    }
622
623    if out.is_empty() {
624        return Err(SqueezeIndexError::InvalidRange {
625            start: start.to_string(),
626            end: end.to_string(),
627            step: step.to_string(),
628        });
629    }
630    Ok(out)
631}
632
633fn expand_axis_f64(start: f64, end: f64, step: f64) -> Result<Vec<f64>, SqueezeIndexError> {
634    if !start.is_finite() || !end.is_finite() || !step.is_finite() {
635        return Err(SqueezeIndexError::InvalidRange {
636            start: start.to_string(),
637            end: end.to_string(),
638            step: step.to_string(),
639        });
640    }
641    if (start - end).abs() <= FLOAT_TOL {
642        return Ok(vec![start]);
643    }
644    if step == 0.0 {
645        return Err(SqueezeIndexError::InvalidRange {
646            start: start.to_string(),
647            end: end.to_string(),
648            step: step.to_string(),
649        });
650    }
651
652    let mut out = Vec::new();
653    if start < end {
654        let mut x = start;
655        while x <= end + FLOAT_TOL {
656            out.push(x);
657            x += step.abs();
658        }
659    } else {
660        let mut x = start;
661        while x >= end - FLOAT_TOL {
662            out.push(x);
663            x -= step.abs();
664        }
665    }
666
667    if out.is_empty() {
668        return Err(SqueezeIndexError::InvalidRange {
669            start: start.to_string(),
670            end: end.to_string(),
671            step: step.to_string(),
672        });
673    }
674    Ok(out)
675}
676
677fn expand_grid_squeeze_index(
678    range: &SqueezeIndexBatchRange,
679) -> Result<Vec<SqueezeIndexParams>, SqueezeIndexError> {
680    let convs = expand_axis_f64(range.conv.0, range.conv.1, range.conv.2)?;
681    let lengths = expand_axis_usize(range.length.0, range.length.1, range.length.2)?;
682
683    let mut combos = Vec::with_capacity(convs.len().saturating_mul(lengths.len()));
684    for &conv in &convs {
685        if !conv.is_finite() || conv <= 1.0 {
686            return Err(SqueezeIndexError::InvalidConv { conv });
687        }
688        for &length in &lengths {
689            if length == 0 {
690                return Err(SqueezeIndexError::InvalidLength {
691                    length,
692                    data_len: 0,
693                });
694            }
695            combos.push(SqueezeIndexParams {
696                conv: Some(conv),
697                length: Some(length),
698            });
699        }
700    }
701    Ok(combos)
702}
703
704#[inline]
705pub fn squeeze_index_batch_with_kernel(
706    data: &[f64],
707    sweep: &SqueezeIndexBatchRange,
708    kernel: Kernel,
709) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
710    let batch_kernel = match kernel {
711        Kernel::Auto => detect_best_batch_kernel(),
712        other if other.is_batch() => other,
713        other => return Err(SqueezeIndexError::InvalidKernelForBatch(other)),
714    };
715    squeeze_index_batch_par_slice(data, sweep, batch_kernel.to_non_batch())
716}
717
718#[inline]
719pub fn squeeze_index_batch_slice(
720    data: &[f64],
721    sweep: &SqueezeIndexBatchRange,
722    kernel: Kernel,
723) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
724    squeeze_index_batch_inner(data, sweep, kernel, false)
725}
726
727#[inline]
728pub fn squeeze_index_batch_par_slice(
729    data: &[f64],
730    sweep: &SqueezeIndexBatchRange,
731    kernel: Kernel,
732) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
733    squeeze_index_batch_inner(data, sweep, kernel, true)
734}
735
736#[inline(always)]
737fn squeeze_index_batch_inner(
738    data: &[f64],
739    sweep: &SqueezeIndexBatchRange,
740    _kernel: Kernel,
741    parallel: bool,
742) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
743    let combos = expand_grid_squeeze_index(sweep)?;
744    let rows = combos.len();
745    let cols = data.len();
746    if cols == 0 {
747        return Err(SqueezeIndexError::EmptyInputData);
748    }
749
750    let first = first_valid_value(data);
751    if first >= cols {
752        return Err(SqueezeIndexError::AllValuesNaN);
753    }
754    let valid = count_valid_values(data);
755    let max_length = combos
756        .iter()
757        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
758        .max()
759        .unwrap_or(0);
760    if valid < max_length {
761        return Err(SqueezeIndexError::NotEnoughValidData {
762            needed: max_length,
763            valid,
764        });
765    }
766
767    let mut buf_mu = make_uninit_matrix(rows, cols);
768    let warmups: Vec<usize> = combos
769        .iter()
770        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH).saturating_sub(1))
771        .collect();
772    init_matrix_prefixes(&mut buf_mu, cols, &warmups);
773
774    let mut guard = ManuallyDrop::new(buf_mu);
775    let out: &mut [f64] =
776        unsafe { std::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
777
778    if parallel {
779        #[cfg(not(target_arch = "wasm32"))]
780        out.par_chunks_mut(cols)
781            .enumerate()
782            .for_each(|(row, out_row)| {
783                let params = resolve_params(&combos[row], cols).unwrap();
784                squeeze_index_row_from_slice(data, params, out_row);
785            });
786
787        #[cfg(target_arch = "wasm32")]
788        for (row, out_row) in out.chunks_mut(cols).enumerate() {
789            let params = resolve_params(&combos[row], cols).unwrap();
790            squeeze_index_row_from_slice(data, params, out_row);
791        }
792    } else {
793        for (row, out_row) in out.chunks_mut(cols).enumerate() {
794            let params = resolve_params(&combos[row], cols).unwrap();
795            squeeze_index_row_from_slice(data, params, out_row);
796        }
797    }
798
799    let values = unsafe {
800        Vec::from_raw_parts(
801            guard.as_mut_ptr() as *mut f64,
802            guard.len(),
803            guard.capacity(),
804        )
805    };
806
807    Ok(SqueezeIndexBatchOutput {
808        values,
809        combos,
810        rows,
811        cols,
812    })
813}
814
815#[inline(always)]
816pub fn squeeze_index_batch_inner_into(
817    data: &[f64],
818    sweep: &SqueezeIndexBatchRange,
819    _kernel: Kernel,
820    parallel: bool,
821    out: &mut [f64],
822) -> Result<Vec<SqueezeIndexParams>, SqueezeIndexError> {
823    let combos = expand_grid_squeeze_index(sweep)?;
824    let rows = combos.len();
825    let cols = data.len();
826    if cols == 0 {
827        return Err(SqueezeIndexError::EmptyInputData);
828    }
829
830    let total = rows
831        .checked_mul(cols)
832        .ok_or_else(|| SqueezeIndexError::OutputLengthMismatch {
833            expected: usize::MAX,
834            got: out.len(),
835        })?;
836    if out.len() != total {
837        return Err(SqueezeIndexError::OutputLengthMismatch {
838            expected: total,
839            got: out.len(),
840        });
841    }
842
843    let first = first_valid_value(data);
844    if first >= cols {
845        return Err(SqueezeIndexError::AllValuesNaN);
846    }
847    let valid = count_valid_values(data);
848    let max_length = combos
849        .iter()
850        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
851        .max()
852        .unwrap_or(0);
853    if valid < max_length {
854        return Err(SqueezeIndexError::NotEnoughValidData {
855            needed: max_length,
856            valid,
857        });
858    }
859
860    if parallel {
861        #[cfg(not(target_arch = "wasm32"))]
862        out.par_chunks_mut(cols)
863            .enumerate()
864            .for_each(|(row, out_row)| {
865                let params = resolve_params(&combos[row], cols).unwrap();
866                squeeze_index_row_from_slice(data, params, out_row);
867            });
868
869        #[cfg(target_arch = "wasm32")]
870        for (row, out_row) in out.chunks_mut(cols).enumerate() {
871            let params = resolve_params(&combos[row], cols).unwrap();
872            squeeze_index_row_from_slice(data, params, out_row);
873        }
874    } else {
875        for (row, out_row) in out.chunks_mut(cols).enumerate() {
876            let params = resolve_params(&combos[row], cols).unwrap();
877            squeeze_index_row_from_slice(data, params, out_row);
878        }
879    }
880
881    Ok(combos)
882}
883
// Python `squeeze_index(data, conv=50.0, length=20, kernel=None)`: computes the index
// for a float64 array and returns a new array. Library errors surface as ValueError;
// the computation itself runs with the GIL released.
#[cfg(feature = "python")]
#[pyfunction(name = "squeeze_index")]
#[pyo3(signature = (data, conv=DEFAULT_CONV, length=DEFAULT_LENGTH, kernel=None))]
pub fn squeeze_index_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    conv: f64,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let kernel = validate_kernel(kernel, false)?;
    let params = SqueezeIndexParams {
        conv: Some(conv),
        length: Some(length),
    };
    let input = SqueezeIndexInput::from_slice(series, params);
    let result = py.allow_threads(|| squeeze_index_with_kernel(&input, kernel));
    let output = result.map_err(|e| PyValueError::new_err(e.to_string()))?;
    Ok(output.values.into_pyarray(py))
}
908
// Python class `SqueezeIndexStream`: thin wrapper over the Rust incremental calculator.
#[cfg(feature = "python")]
#[pyclass(name = "SqueezeIndexStream")]
pub struct SqueezeIndexStreamPy {
    stream: SqueezeIndexStream,
}
914
#[cfg(feature = "python")]
#[pymethods]
impl SqueezeIndexStreamPy {
    // Constructor; invalid parameters raise ValueError.
    #[new]
    #[pyo3(signature = (conv=DEFAULT_CONV, length=DEFAULT_LENGTH))]
    fn new(conv: f64, length: usize) -> PyResult<Self> {
        let params = SqueezeIndexParams {
            conv: Some(conv),
            length: Some(length),
        };
        match SqueezeIndexStream::try_new(params) {
            Ok(stream) => Ok(Self { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value; returns None during warm-up and right after a non-finite input.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
933
// Python `squeeze_index_batch(data, conv_range, length_range, kernel=None)`: runs the
// parameter sweep and returns a dict with a (rows, cols) `values` matrix plus per-row
// `convs` / `lengths` and the `rows` / `cols` counts.
#[cfg(feature = "python")]
#[pyfunction(name = "squeeze_index_batch")]
#[pyo3(signature = (data, conv_range, length_range, kernel=None))]
pub fn squeeze_index_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    conv_range: (f64, f64, f64),
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let kernel = validate_kernel(kernel, true)?;
    let sweep = SqueezeIndexBatchRange {
        conv: conv_range,
        length: length_range,
    };
    let to_py_err = |e: SqueezeIndexError| PyValueError::new_err(e.to_string());

    // Expand once up front to size the output array.
    let rows = expand_grid_squeeze_index(&sweep).map_err(to_py_err)?.len();
    let cols = series.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;

    // SAFETY: the array is freshly allocated and uninitialised; on success every element
    // is written by `squeeze_index_batch_inner_into` before the array reaches Python, and
    // on error it is discarded without being returned.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch = if let Kernel::Auto = kernel {
                detect_best_batch_kernel()
            } else {
                kernel
            };
            squeeze_index_batch_inner_into(series, &sweep, batch.to_non_batch(), true, slice_out)
        })
        .map_err(to_py_err)?;

    let convs: Vec<f64> = combos
        .iter()
        .map(|combo| combo.conv.unwrap_or(DEFAULT_CONV))
        .collect();
    let lengths: Vec<u64> = combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("convs", convs.into_pyarray(py))?;
    dict.set_item("lengths", lengths.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
994
// Registers the squeeze-index functions and the stream class on a Python module.
#[cfg(feature = "python")]
pub fn register_squeeze_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(squeeze_index_py, module)?;
    module.add_function(single)?;
    let batch = wrap_pyfunction!(squeeze_index_batch_py, module)?;
    module.add_function(batch)?;
    module.add_class::<SqueezeIndexStreamPy>()
}
1002
// wasm export: computes the index for `data` and returns a new array; errors become
// JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "squeeze_index_js")]
pub fn squeeze_index_js(data: &[f64], conv: f64, length: usize) -> Result<Vec<f64>, JsValue> {
    let params = SqueezeIndexParams {
        conv: Some(conv),
        length: Some(length),
    };
    let input = SqueezeIndexInput::from_slice(data, params);
    let mut output = vec![0.0; data.len()];
    match squeeze_index_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1018
// wasm export: reserves an (uninitialised) buffer for `len` f64 values and hands its
// pointer to JS. Release it with `squeeze_index_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_alloc(len: usize) -> *mut f64 {
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1027
/// Releases a buffer obtained from `squeeze_index_alloc`.
///
/// `len` must be the value that was passed to `squeeze_index_alloc` for this pointer. A null
/// pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` comes from a leaked `Vec<f64>` created with `Vec::with_capacity(len)`, so
    // (ptr, capacity = len) describes that allocation. The length is set to 0 instead of `len`
    // because the slots may never have been written, and `Vec::from_raw_parts` requires the
    // first `length` elements to be initialised. `f64` has no destructor, so dropping the Vec
    // only deallocates the buffer.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1037
/// Computes the squeeze index of `len` samples at `data_ptr` and writes them to `out_ptr`.
///
/// Callers must pass pointers valid for `len` `f64` reads (`data_ptr`) and writes (`out_ptr`),
/// e.g. buffers from `squeeze_index_alloc`. The two buffers may be identical or partially
/// overlap. In that case the result is computed into a scratch buffer and copied out at the
/// end, so the input is never overwritten while it is still being read.
///
/// # Errors
/// Returns an error if either pointer is null, or forwards the error message from
/// `squeeze_index_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_into(
    data_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    conv: f64,
    length: usize,
) -> Result<(), JsValue> {
    if data_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Compare byte ranges, not just `data_ptr == out_ptr`, so partially overlapping buffers
    // also take the scratch-buffer path.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let data_addr = data_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps =
        data_addr < out_addr.saturating_add(span) && out_addr < data_addr.saturating_add(span);

    // SAFETY: the caller guarantees `data_ptr` is valid for `len` reads and `out_ptr` for `len`
    // writes. When the ranges overlap, the mutable view over `out_ptr` is created only after
    // the last read through `data`, so a shared and a mutable view of the same memory are never
    // used together.
    unsafe {
        let data = std::slice::from_raw_parts(data_ptr, len);
        let input = SqueezeIndexInput::from_slice(
            data,
            SqueezeIndexParams {
                conv: Some(conv),
                length: Some(length),
            },
        );
        if overlaps {
            let mut scratch = vec![0.0; len];
            squeeze_index_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            squeeze_index_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1074
/// Parameter sweep accepted by `squeeze_index_batch_js`, deserialized from a JS value.
///
/// `squeeze_index_batch_js` copies both ranges unchanged into `SqueezeIndexBatchRange`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SqueezeIndexBatchConfig {
    /// `(start, end, step)` for the `conv` parameter.
    pub conv_range: (f64, f64, f64),
    /// `(start, end, step)` for the `length` parameter.
    pub length_range: (usize, usize, usize),
}
1081
/// Result object serialized back to JavaScript by `squeeze_index_batch_js`.
///
/// Field names and declaration order define the serialized shape seen by JS callers.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SqueezeIndexBatchJsOutput {
    /// Flattened `rows x cols` output matrix, one row per parameter combination
    /// (row-major, matching the Python binding's `reshape((rows, cols))`).
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<SqueezeIndexParams>,
    /// `conv` of each combo, with `None` resolved to `DEFAULT_CONV`.
    pub convs: Vec<f64>,
    /// `length` of each combo, with `None` resolved to `DEFAULT_LENGTH`.
    pub lengths: Vec<usize>,
    /// Number of rows, i.e. parameter combinations.
    pub rows: usize,
    /// Number of columns, i.e. input samples per row.
    pub cols: usize,
}
1092
/// JavaScript entry point for a parameter sweep over `data`.
///
/// `config` must deserialize into `SqueezeIndexBatchConfig`. The result is a
/// `SqueezeIndexBatchJsOutput`, whose `convs`/`lengths` hold each combo's parameters with
/// unset values resolved to the module defaults.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "squeeze_index_batch_js")]
pub fn squeeze_index_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SqueezeIndexBatchConfig {
        conv_range,
        length_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {e}")))?;

    let sweep = SqueezeIndexBatchRange {
        conv: conv_range,
        length: length_range,
    };
    let batch = squeeze_index_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // Resolve both optional parameters of every combo in a single pass.
    let (convs, lengths): (Vec<f64>, Vec<usize>) = batch
        .combos
        .iter()
        .map(|p| {
            (
                p.conv.unwrap_or(DEFAULT_CONV),
                p.length.unwrap_or(DEFAULT_LENGTH),
            )
        })
        .unzip();

    let payload = SqueezeIndexBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        convs,
        lengths,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1123
/// Runs a parameter sweep over `len` samples at `data_ptr`, writing a `rows x len` matrix to
/// `out_ptr`, and returns `rows` (the number of parameter combinations).
///
/// Callers must pass `data_ptr` valid for `len` reads and `out_ptr` valid for `rows * len`
/// writes. The buffers may overlap (e.g. an in-place call with a single combo). In that case
/// the input is copied first, so it is never overwritten while it is still being read.
///
/// # Errors
/// Returns an error if either pointer is null, if the sweep is invalid, if `rows * len`
/// overflows, or if the batch computation fails.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_batch_into(
    data_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    conv_start: f64,
    conv_end: f64,
    conv_step: f64,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    if data_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = SqueezeIndexBatchRange {
        conv: (conv_start, conv_end, conv_step),
        length: (length_start, length_end, length_step),
    };
    let combos =
        expand_grid_squeeze_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Byte-range overlap test between the input (len values) and the output (total values).
    let elem = std::mem::size_of::<f64>();
    let data_addr = data_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = data_addr < out_addr.saturating_add(total.saturating_mul(elem))
        && out_addr < data_addr.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `data_ptr` is valid for `len` reads. When the input overlaps
    // the output, it is copied into an owned buffer before any mutable view of `out_ptr`
    // exists. Otherwise the two slices cover disjoint memory.
    let owned;
    let data: &[f64] = if overlaps {
        owned = unsafe { std::slice::from_raw_parts(data_ptr, len) }.to_vec();
        &owned
    } else {
        unsafe { std::slice::from_raw_parts(data_ptr, len) }
    };

    // SAFETY: the caller guarantees `out_ptr` is valid for `total = rows * len` writes, and
    // `data` no longer aliases this region (see above).
    let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
    squeeze_index_batch_inner_into(data, &sweep, detect_best_kernel(), false, out)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    Ok(rows)
}
1161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    /// Candle fixture shared by every test in this module (path relative to the crate root).
    const CLOSE_FIXTURE: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Loads the close-price series from the shared candle fixture.
    fn load_close() -> Result<Vec<f64>, Box<dyn Error>> {
        let candles = read_candles_from_csv(CLOSE_FIXTURE)?;
        Ok(candles.close)
    }

    /// Builds fully specified parameters so each test states `conv` and `length` explicitly.
    fn params(conv: f64, length: usize) -> SqueezeIndexParams {
        SqueezeIndexParams {
            conv: Some(conv),
            length: Some(length),
        }
    }

    /// Asserts element-wise agreement within `tol`, treating a NaN/NaN pair as equal.
    fn approx_eq_or_nan(lhs: &[f64], rhs: &[f64], tol: f64) {
        assert_eq!(lhs.len(), rhs.len());
        let pairs = lhs.iter().copied().zip(rhs.iter().copied());
        for (i, (a, b)) in pairs.enumerate() {
            let both_nan = a.is_nan() && b.is_nan();
            if !both_nan {
                assert!(
                    (a - b).abs() <= tol,
                    "mismatch at {i}: lhs={a} rhs={b} tol={tol}"
                );
            }
        }
    }

    #[test]
    fn squeeze_index_output_contract() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let input = SqueezeIndexInput::from_slice(&close, params(50.0, 20));
        let out = squeeze_index_with_kernel(&input, Kernel::Scalar)?;
        assert_eq!(out.values.len(), close.len());

        // The warm-up region must cover at least the first 19 bars for length 20.
        let first_valid = out.values.iter().position(|v| !v.is_nan()).unwrap();
        assert!(first_valid >= 19);
        assert!(out.values[first_valid..].iter().any(|v| v.is_finite()));
        Ok(())
    }

    #[test]
    fn squeeze_index_auto_matches_scalar() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let input = SqueezeIndexInput::from_slice(&close, params(50.0, 20));
        let auto = squeeze_index_with_kernel(&input, Kernel::Auto)?;
        let scalar = squeeze_index_with_kernel(&input, Kernel::Scalar)?;
        approx_eq_or_nan(&auto.values, &scalar.values, 1e-12);
        Ok(())
    }

    #[test]
    fn squeeze_index_stream_matches_batch_and_recovers_after_nan() -> Result<(), Box<dyn Error>> {
        let mut close = load_close()?;
        // Inject a single gap so both paths must handle (and recover from) a NaN input.
        close[64] = f64::NAN;

        let batch = squeeze_index_with_kernel(
            &SqueezeIndexInput::from_slice(&close, params(50.0, 10)),
            Kernel::Scalar,
        )?;

        let mut stream = SqueezeIndexStream::try_new(params(50.0, 10))?;
        // Bars for which the stream emits nothing are recorded as NaN.
        let streamed: Vec<f64> = close
            .iter()
            .map(|&value| stream.update(value).unwrap_or(f64::NAN))
            .collect();

        approx_eq_or_nan(&streamed, &batch.values, 1e-12);
        assert!(streamed[64].is_nan());
        assert!(streamed[72].is_nan());
        assert!(streamed[74].is_finite());
        Ok(())
    }

    #[test]
    fn squeeze_index_batch_single_param_matches_single() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let sweep = SqueezeIndexBatchRange {
            conv: (50.0, 50.0, 0.0),
            length: (20, 20, 0),
        };
        let batch = squeeze_index_batch_with_kernel(&close, &sweep, Kernel::Auto)?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());

        let single = squeeze_index_with_kernel(
            &SqueezeIndexInput::from_slice(&close, params(50.0, 20)),
            Kernel::Auto,
        )?;
        approx_eq_or_nan(&batch.values, &single.values, 1e-12);
        Ok(())
    }
}