1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::PyDict;
9#[cfg(feature = "python")]
10use pyo3::wrap_pyfunction;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17use crate::utilities::data_loader::{source_type, Candles};
18use crate::utilities::enums::Kernel;
19use crate::utilities::helpers::{
20 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
21 make_uninit_matrix,
22};
23#[cfg(feature = "python")]
24use crate::utilities::kernel_validation::validate_kernel;
25#[cfg(not(target_arch = "wasm32"))]
26use rayon::prelude::*;
27use std::convert::AsRef;
28use std::mem::ManuallyDrop;
29use thiserror::Error;
30
/// Default convergence factor: the adaptive max/min envelopes move `1/conv` of the way
/// toward each new sample, so larger values make them decay more slowly.
const DEFAULT_CONV: f64 = 50.0;
/// Default window length for the correlation of the log-spread against bar position.
const DEFAULT_LENGTH: usize = 20;
/// Absolute tolerance used when comparing and stepping floating-point parameters.
const FLOAT_TOL: f64 = 1e-12;
34
35impl<'a> AsRef<[f64]> for SqueezeIndexInput<'a> {
36 #[inline(always)]
37 fn as_ref(&self) -> &[f64] {
38 match &self.data {
39 SqueezeIndexData::Slice(slice) => slice,
40 SqueezeIndexData::Candles { candles, source } => source_type(candles, source),
41 }
42 }
43}
44
/// Where the indicator reads its input series from.
#[derive(Debug, Clone)]
pub enum SqueezeIndexData<'a> {
    /// A column of a candle set, selected by name via `source_type(candles, source)`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided series used as-is.
    Slice(&'a [f64]),
}
53
/// Result of a single-parameter run.
#[derive(Debug, Clone)]
pub struct SqueezeIndexOutput {
    /// One value per input sample; NaN where no index is available (warm-up, after a
    /// non-finite input, or where the correlation is undefined).
    pub values: Vec<f64>,
}
58
/// User-facing parameters. A `None` field falls back to the module default
/// (`DEFAULT_CONV` / `DEFAULT_LENGTH`) when the parameters are resolved.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SqueezeIndexParams {
    /// Convergence factor; must be finite and greater than 1.
    pub conv: Option<f64>,
    /// Correlation window length; must be at least 1 (and not exceed the data length).
    pub length: Option<usize>,
}
68
69impl Default for SqueezeIndexParams {
70 fn default() -> Self {
71 Self {
72 conv: Some(DEFAULT_CONV),
73 length: Some(DEFAULT_LENGTH),
74 }
75 }
76}
77
/// Input for a single Squeeze Index computation: a data source plus its parameters.
#[derive(Debug, Clone)]
pub struct SqueezeIndexInput<'a> {
    /// Series to evaluate (slice or candle column).
    pub data: SqueezeIndexData<'a>,
    /// Unresolved parameters; defaults apply to `None` fields.
    pub params: SqueezeIndexParams,
}
83
84impl<'a> SqueezeIndexInput<'a> {
85 #[inline]
86 pub fn from_candles(candles: &'a Candles, source: &'a str, params: SqueezeIndexParams) -> Self {
87 Self {
88 data: SqueezeIndexData::Candles { candles, source },
89 params,
90 }
91 }
92
93 #[inline]
94 pub fn from_slice(slice: &'a [f64], params: SqueezeIndexParams) -> Self {
95 Self {
96 data: SqueezeIndexData::Slice(slice),
97 params,
98 }
99 }
100
101 #[inline]
102 pub fn with_default_candles(candles: &'a Candles) -> Self {
103 Self::from_candles(candles, "close", SqueezeIndexParams::default())
104 }
105
106 #[inline]
107 pub fn get_conv(&self) -> f64 {
108 self.params.conv.unwrap_or(DEFAULT_CONV)
109 }
110
111 #[inline]
112 pub fn get_length(&self) -> usize {
113 self.params.length.unwrap_or(DEFAULT_LENGTH)
114 }
115}
116
/// Fluent builder for single runs and streams. Unset parameters are resolved to the
/// module defaults when the builder is applied.
#[derive(Copy, Clone, Debug)]
pub struct SqueezeIndexBuilder {
    conv: Option<f64>,
    length: Option<usize>,
    /// Kernel passed to `squeeze_index_with_kernel` by `apply`/`apply_slice`.
    kernel: Kernel,
}
123
124impl Default for SqueezeIndexBuilder {
125 fn default() -> Self {
126 Self {
127 conv: None,
128 length: None,
129 kernel: Kernel::Auto,
130 }
131 }
132}
133
134impl SqueezeIndexBuilder {
135 #[inline]
136 pub fn new() -> Self {
137 Self::default()
138 }
139
140 #[inline]
141 pub fn conv(mut self, conv: f64) -> Self {
142 self.conv = Some(conv);
143 self
144 }
145
146 #[inline]
147 pub fn length(mut self, length: usize) -> Self {
148 self.length = Some(length);
149 self
150 }
151
152 #[inline]
153 pub fn kernel(mut self, kernel: Kernel) -> Self {
154 self.kernel = kernel;
155 self
156 }
157
158 #[inline]
159 pub fn apply(
160 self,
161 candles: &Candles,
162 source: &str,
163 ) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
164 let input = SqueezeIndexInput::from_candles(
165 candles,
166 source,
167 SqueezeIndexParams {
168 conv: self.conv,
169 length: self.length,
170 },
171 );
172 squeeze_index_with_kernel(&input, self.kernel)
173 }
174
175 #[inline]
176 pub fn apply_slice(self, data: &[f64]) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
177 let input = SqueezeIndexInput::from_slice(
178 data,
179 SqueezeIndexParams {
180 conv: self.conv,
181 length: self.length,
182 },
183 );
184 squeeze_index_with_kernel(&input, self.kernel)
185 }
186
187 #[inline]
188 pub fn into_stream(self) -> Result<SqueezeIndexStream, SqueezeIndexError> {
189 SqueezeIndexStream::try_new(SqueezeIndexParams {
190 conv: self.conv,
191 length: self.length,
192 })
193 }
194}
195
/// Errors reported by the Squeeze Index entry points.
#[derive(Debug, Error)]
pub enum SqueezeIndexError {
    /// The input series has no elements.
    #[error("squeeze_index: Input data slice is empty.")]
    EmptyInputData,
    /// The input series contains no finite value (NaN and ±inf both count as invalid).
    #[error("squeeze_index: All values are NaN.")]
    AllValuesNaN,
    /// `conv` is non-finite or not greater than 1.
    #[error("squeeze_index: Invalid conv: {conv}")]
    InvalidConv { conv: f64 },
    /// `length` is zero or exceeds the data length (`data_len` is 0 when no data length
    /// applies, e.g. streams and grid expansion).
    #[error("squeeze_index: Invalid length: length = {length}, data length = {data_len}")]
    InvalidLength { length: usize, data_len: usize },
    /// Fewer finite samples than the (largest) requested window.
    #[error("squeeze_index: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-provided output buffer has the wrong length.
    #[error("squeeze_index: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch sweep range could not be expanded.
    #[error("squeeze_index: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("squeeze_index: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
219
/// Parameters after defaults have been applied and validated by `resolve_params`:
/// `conv` is finite and `> 1`, `length >= 1`.
#[derive(Debug, Clone, Copy)]
struct ResolvedParams {
    conv: f64,
    length: usize,
}
225
/// Incremental Squeeze Index calculator: feed one sample at a time via `update`.
#[derive(Debug, Clone)]
pub struct SqueezeIndexStream {
    params: ResolvedParams,
    /// Upper adaptive envelope (0.0 after construction/reset).
    max_state: f64,
    /// Lower adaptive envelope (0.0 after construction/reset).
    min_state: f64,
    /// Ring buffer of the last `length` log-spreads; invalid entries are stored as 0.0.
    ring_vals: Vec<f64>,
    /// 1 where the matching `ring_vals` entry was finite, 0 otherwise.
    ring_valid: Vec<u8>,
    /// Slot of the oldest entry once the window is full.
    head: usize,
    /// Entries written since the last reset, capped at `length`.
    filled: usize,
    /// Number of valid entries currently in the window.
    valid_count: usize,
    /// Running Σx over the window.
    sum_x: f64,
    /// Running Σx² over the window.
    sum_x2: f64,
    /// Running Σ k·x_k, with k = 0 for the oldest entry up to length - 1 for the newest.
    weighted: f64,
}
240
241impl SqueezeIndexStream {
242 pub fn try_new(params: SqueezeIndexParams) -> Result<Self, SqueezeIndexError> {
243 let params = resolve_params(¶ms, 0)?;
244 Ok(Self::new_resolved(params))
245 }
246
247 #[inline]
248 fn new_resolved(params: ResolvedParams) -> Self {
249 Self {
250 params,
251 max_state: 0.0,
252 min_state: 0.0,
253 ring_vals: vec![0.0; params.length],
254 ring_valid: vec![0u8; params.length],
255 head: 0,
256 filled: 0,
257 valid_count: 0,
258 sum_x: 0.0,
259 sum_x2: 0.0,
260 weighted: 0.0,
261 }
262 }
263
264 #[inline]
265 fn reset(&mut self) {
266 *self = Self::new_resolved(self.params);
267 }
268
269 #[inline]
270 pub fn get_warmup_period(&self) -> usize {
271 self.params.length.saturating_sub(1)
272 }
273
274 #[inline]
275 pub fn update(&mut self, value: f64) -> Option<f64> {
276 if !value.is_finite() {
277 self.reset();
278 return None;
279 }
280
281 let conv = self.params.conv;
282 let max_next = value.max(self.max_state - (self.max_state - value) / conv);
283 let min_next = value.min(self.min_state + (value - self.min_state) / conv);
284 self.max_state = max_next;
285 self.min_state = min_next;
286
287 let spread = max_next - min_next;
288 let diff = if spread > 0.0 { spread.ln() } else { f64::NAN };
289 self.push_diff(diff)
290 }
291
292 #[inline]
293 fn push_diff(&mut self, diff: f64) -> Option<f64> {
294 let is_valid = diff.is_finite();
295 let value = if is_valid { diff } else { 0.0 };
296 let n = self.params.length;
297
298 if self.filled < n {
299 let pos = self.filled;
300 self.ring_vals[pos] = value;
301 self.ring_valid[pos] = if is_valid { 1 } else { 0 };
302 self.sum_x += value;
303 self.sum_x2 += value * value;
304 self.weighted += pos as f64 * value;
305 if is_valid {
306 self.valid_count += 1;
307 }
308 self.filled += 1;
309 if self.filled < n {
310 return None;
311 }
312 } else {
313 let old_value = self.ring_vals[self.head];
314 let old_valid = self.ring_valid[self.head] as usize;
315 let old_sum = self.sum_x;
316
317 self.weighted =
318 self.weighted - old_sum + old_value + (n.saturating_sub(1) as f64) * value;
319 self.sum_x = old_sum - old_value + value;
320 self.sum_x2 = self.sum_x2 - old_value * old_value + value * value;
321 self.valid_count = self.valid_count + if is_valid { 1 } else { 0 } - old_valid;
322
323 self.ring_vals[self.head] = value;
324 self.ring_valid[self.head] = if is_valid { 1 } else { 0 };
325 self.head += 1;
326 if self.head == n {
327 self.head = 0;
328 }
329 }
330
331 if self.valid_count != n {
332 return Some(f64::NAN);
333 }
334
335 Some(psi_from_corr(self.sum_x, self.sum_x2, self.weighted, n))
336 }
337}
338
339#[inline]
340pub fn squeeze_index(input: &SqueezeIndexInput) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
341 squeeze_index_with_kernel(input, Kernel::Auto)
342}
343
/// A sample counts as valid when it is neither NaN nor infinite.
#[inline(always)]
fn valid_value(value: f64) -> bool {
    !(value.is_nan() || value.is_infinite())
}
348
349#[inline(always)]
350fn count_valid_values(data: &[f64]) -> usize {
351 data.iter().filter(|&&v| valid_value(v)).count()
352}
353
354#[inline(always)]
355fn first_valid_value(data: &[f64]) -> usize {
356 let mut i = 0usize;
357 while i < data.len() {
358 if valid_value(data[i]) {
359 return i;
360 }
361 i += 1;
362 }
363 data.len()
364}
365
/// Maps the Pearson correlation between window values and their positions `0..length`
/// onto a nominal 0–100 scale, `psi = 50 - 50·corr`.
///
/// `sum_x`, `sum_x2` and `weighted` are Σx, Σx² and Σk·x_k over the window. Returns NaN
/// when the correlation is undefined (zero variance, e.g. `length == 1`, or a non-finite
/// denominator).
#[inline(always)]
fn psi_from_corr(sum_x: f64, sum_x2: f64, weighted: f64, length: usize) -> f64 {
    let len = length as f64;
    // Closed forms of Σy and Σy² for y = 0, 1, ..., len - 1.
    let y_sum = len * (len - 1.0) * 0.5;
    let y_sq_sum = (len - 1.0) * len * (2.0 * len - 1.0) / 6.0;
    let x_spread = len * sum_x2 - sum_x * sum_x;
    let y_spread = len * y_sq_sum - y_sum * y_sum;
    let product = x_spread * y_spread;
    let defined = product > 0.0 && product.is_finite();
    if !defined {
        return f64::NAN;
    }
    let covariance = len * weighted - sum_x * y_sum;
    let corr = covariance / product.sqrt();
    -50.0 * corr + 50.0
}
380
381#[inline]
382fn resolve_params(
383 params: &SqueezeIndexParams,
384 data_len: usize,
385) -> Result<ResolvedParams, SqueezeIndexError> {
386 let conv = params.conv.unwrap_or(DEFAULT_CONV);
387 if !conv.is_finite() || conv <= 1.0 {
388 return Err(SqueezeIndexError::InvalidConv { conv });
389 }
390 let length = params.length.unwrap_or(DEFAULT_LENGTH);
391 if length == 0 || (data_len > 0 && length > data_len) {
392 return Err(SqueezeIndexError::InvalidLength { length, data_len });
393 }
394 Ok(ResolvedParams { conv, length })
395}
396
397#[inline]
398fn squeeze_index_prepare<'a>(
399 input: &'a SqueezeIndexInput,
400 kernel: Kernel,
401) -> Result<(&'a [f64], ResolvedParams, Kernel), SqueezeIndexError> {
402 let data = input.as_ref();
403 let len = data.len();
404 if len == 0 {
405 return Err(SqueezeIndexError::EmptyInputData);
406 }
407
408 let first = first_valid_value(data);
409 if first >= len {
410 return Err(SqueezeIndexError::AllValuesNaN);
411 }
412
413 let params = resolve_params(&input.params, len)?;
414 let valid = count_valid_values(data);
415 if valid < params.length {
416 return Err(SqueezeIndexError::NotEnoughValidData {
417 needed: params.length,
418 valid,
419 });
420 }
421
422 let chosen = match kernel {
423 Kernel::Auto => detect_best_kernel(),
424 other => other.to_non_batch(),
425 };
426
427 Ok((data, params, chosen))
428}
429
430#[inline(always)]
431fn squeeze_index_row_from_slice(data: &[f64], params: ResolvedParams, out: &mut [f64]) {
432 out.fill(f64::NAN);
433 let mut stream = SqueezeIndexStream::new_resolved(params);
434 for (i, &value) in data.iter().enumerate() {
435 if let Some(out_value) = stream.update(value) {
436 out[i] = out_value;
437 }
438 }
439}
440
441#[inline]
442pub fn squeeze_index_with_kernel(
443 input: &SqueezeIndexInput,
444 kernel: Kernel,
445) -> Result<SqueezeIndexOutput, SqueezeIndexError> {
446 let (data, params, _chosen) = squeeze_index_prepare(input, kernel)?;
447 let mut values = alloc_with_nan_prefix(data.len(), params.length.saturating_sub(1));
448 squeeze_index_row_from_slice(data, params, &mut values);
449 Ok(SqueezeIndexOutput { values })
450}
451
452#[inline]
453pub fn squeeze_index_into_slice(
454 dst: &mut [f64],
455 input: &SqueezeIndexInput,
456 kernel: Kernel,
457) -> Result<(), SqueezeIndexError> {
458 let (data, params, _chosen) = squeeze_index_prepare(input, kernel)?;
459 if dst.len() != data.len() {
460 return Err(SqueezeIndexError::OutputLengthMismatch {
461 expected: data.len(),
462 got: dst.len(),
463 });
464 }
465 squeeze_index_row_from_slice(data, params, dst);
466 Ok(())
467}
468
469#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
470#[inline]
471pub fn squeeze_index_into(
472 input: &SqueezeIndexInput,
473 out: &mut [f64],
474) -> Result<(), SqueezeIndexError> {
475 squeeze_index_into_slice(out, input, Kernel::Auto)
476}
477
/// Parameter sweep for batch runs, given as `(start, end, step)` per axis.
///
/// A range whose start equals its end (within `FLOAT_TOL` for `conv`) yields that single
/// value regardless of step; otherwise a zero step is rejected. Values run from `start`
/// toward `end` (the sign of the `conv` step is ignored).
#[derive(Clone, Debug)]
pub struct SqueezeIndexBatchRange {
    pub conv: (f64, f64, f64),
    pub length: (usize, usize, usize),
}
483
484impl Default for SqueezeIndexBatchRange {
485 fn default() -> Self {
486 Self {
487 conv: (DEFAULT_CONV, DEFAULT_CONV, 0.0),
488 length: (DEFAULT_LENGTH, DEFAULT_LENGTH, 0),
489 }
490 }
491}
492
/// Fluent builder for batch (parameter-sweep) runs.
#[derive(Clone, Debug, Default)]
pub struct SqueezeIndexBatchBuilder {
    range: SqueezeIndexBatchRange,
    /// Batch kernel; `Kernel::default()` until set. NOTE(review): assumed to be `Auto` —
    /// confirm against the `Kernel` definition.
    kernel: Kernel,
}
498
499impl SqueezeIndexBatchBuilder {
500 #[inline]
501 pub fn new() -> Self {
502 Self::default()
503 }
504
505 #[inline]
506 pub fn kernel(mut self, kernel: Kernel) -> Self {
507 self.kernel = kernel;
508 self
509 }
510
511 #[inline]
512 pub fn conv_range(mut self, start: f64, end: f64, step: f64) -> Self {
513 self.range.conv = (start, end, step);
514 self
515 }
516
517 #[inline]
518 pub fn conv_static(mut self, conv: f64) -> Self {
519 self.range.conv = (conv, conv, 0.0);
520 self
521 }
522
523 #[inline]
524 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
525 self.range.length = (start, end, step);
526 self
527 }
528
529 #[inline]
530 pub fn length_static(mut self, length: usize) -> Self {
531 self.range.length = (length, length, 0);
532 self
533 }
534
535 #[inline]
536 pub fn apply_slice(self, data: &[f64]) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
537 squeeze_index_batch_with_kernel(data, &self.range, self.kernel)
538 }
539
540 #[inline]
541 pub fn apply_candles(
542 self,
543 candles: &Candles,
544 source: &str,
545 ) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
546 self.apply_slice(source_type(candles, source))
547 }
548
549 #[inline]
550 pub fn with_default_candles(
551 candles: &Candles,
552 ) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
553 SqueezeIndexBatchBuilder::new()
554 .kernel(Kernel::Auto)
555 .apply_candles(candles, "close")
556 }
557}
558
/// Result of a batch run: a row-major `rows × cols` matrix where row `r` holds the
/// Squeeze Index computed with `combos[r]` over all `cols` input samples.
#[derive(Clone, Debug)]
pub struct SqueezeIndexBatchOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SqueezeIndexParams>,
    pub rows: usize,
    pub cols: usize,
}
566
567impl SqueezeIndexBatchOutput {
568 pub fn row_for_params(&self, params: &SqueezeIndexParams) -> Option<usize> {
569 self.combos.iter().position(|combo| {
570 (combo.conv.unwrap_or(DEFAULT_CONV) - params.conv.unwrap_or(DEFAULT_CONV)).abs()
571 <= FLOAT_TOL
572 && combo.length.unwrap_or(DEFAULT_LENGTH) == params.length.unwrap_or(DEFAULT_LENGTH)
573 })
574 }
575
576 pub fn values_for(&self, params: &SqueezeIndexParams) -> Option<&[f64]> {
577 self.row_for_params(params).and_then(|row| {
578 row.checked_mul(self.cols)
579 .and_then(|start| self.values.get(start..start + self.cols))
580 })
581 }
582}
583
584fn expand_axis_usize(
585 start: usize,
586 end: usize,
587 step: usize,
588) -> Result<Vec<usize>, SqueezeIndexError> {
589 if start == end {
590 return Ok(vec![start]);
591 }
592 if step == 0 {
593 return Err(SqueezeIndexError::InvalidRange {
594 start: start.to_string(),
595 end: end.to_string(),
596 step: step.to_string(),
597 });
598 }
599
600 let mut out = Vec::new();
601 if start < end {
602 let mut x = start;
603 while x <= end {
604 out.push(x);
605 let next = x.saturating_add(step);
606 if next == x {
607 break;
608 }
609 x = next;
610 }
611 } else {
612 let mut x = start;
613 while x >= end {
614 out.push(x);
615 let next = x.saturating_sub(step);
616 if next == x {
617 break;
618 }
619 x = next;
620 }
621 }
622
623 if out.is_empty() {
624 return Err(SqueezeIndexError::InvalidRange {
625 start: start.to_string(),
626 end: end.to_string(),
627 step: step.to_string(),
628 });
629 }
630 Ok(out)
631}
632
633fn expand_axis_f64(start: f64, end: f64, step: f64) -> Result<Vec<f64>, SqueezeIndexError> {
634 if !start.is_finite() || !end.is_finite() || !step.is_finite() {
635 return Err(SqueezeIndexError::InvalidRange {
636 start: start.to_string(),
637 end: end.to_string(),
638 step: step.to_string(),
639 });
640 }
641 if (start - end).abs() <= FLOAT_TOL {
642 return Ok(vec![start]);
643 }
644 if step == 0.0 {
645 return Err(SqueezeIndexError::InvalidRange {
646 start: start.to_string(),
647 end: end.to_string(),
648 step: step.to_string(),
649 });
650 }
651
652 let mut out = Vec::new();
653 if start < end {
654 let mut x = start;
655 while x <= end + FLOAT_TOL {
656 out.push(x);
657 x += step.abs();
658 }
659 } else {
660 let mut x = start;
661 while x >= end - FLOAT_TOL {
662 out.push(x);
663 x -= step.abs();
664 }
665 }
666
667 if out.is_empty() {
668 return Err(SqueezeIndexError::InvalidRange {
669 start: start.to_string(),
670 end: end.to_string(),
671 step: step.to_string(),
672 });
673 }
674 Ok(out)
675}
676
677fn expand_grid_squeeze_index(
678 range: &SqueezeIndexBatchRange,
679) -> Result<Vec<SqueezeIndexParams>, SqueezeIndexError> {
680 let convs = expand_axis_f64(range.conv.0, range.conv.1, range.conv.2)?;
681 let lengths = expand_axis_usize(range.length.0, range.length.1, range.length.2)?;
682
683 let mut combos = Vec::with_capacity(convs.len().saturating_mul(lengths.len()));
684 for &conv in &convs {
685 if !conv.is_finite() || conv <= 1.0 {
686 return Err(SqueezeIndexError::InvalidConv { conv });
687 }
688 for &length in &lengths {
689 if length == 0 {
690 return Err(SqueezeIndexError::InvalidLength {
691 length,
692 data_len: 0,
693 });
694 }
695 combos.push(SqueezeIndexParams {
696 conv: Some(conv),
697 length: Some(length),
698 });
699 }
700 }
701 Ok(combos)
702}
703
704#[inline]
705pub fn squeeze_index_batch_with_kernel(
706 data: &[f64],
707 sweep: &SqueezeIndexBatchRange,
708 kernel: Kernel,
709) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
710 let batch_kernel = match kernel {
711 Kernel::Auto => detect_best_batch_kernel(),
712 other if other.is_batch() => other,
713 other => return Err(SqueezeIndexError::InvalidKernelForBatch(other)),
714 };
715 squeeze_index_batch_par_slice(data, sweep, batch_kernel.to_non_batch())
716}
717
718#[inline]
719pub fn squeeze_index_batch_slice(
720 data: &[f64],
721 sweep: &SqueezeIndexBatchRange,
722 kernel: Kernel,
723) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
724 squeeze_index_batch_inner(data, sweep, kernel, false)
725}
726
727#[inline]
728pub fn squeeze_index_batch_par_slice(
729 data: &[f64],
730 sweep: &SqueezeIndexBatchRange,
731 kernel: Kernel,
732) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
733 squeeze_index_batch_inner(data, sweep, kernel, true)
734}
735
736#[inline(always)]
737fn squeeze_index_batch_inner(
738 data: &[f64],
739 sweep: &SqueezeIndexBatchRange,
740 _kernel: Kernel,
741 parallel: bool,
742) -> Result<SqueezeIndexBatchOutput, SqueezeIndexError> {
743 let combos = expand_grid_squeeze_index(sweep)?;
744 let rows = combos.len();
745 let cols = data.len();
746 if cols == 0 {
747 return Err(SqueezeIndexError::EmptyInputData);
748 }
749
750 let first = first_valid_value(data);
751 if first >= cols {
752 return Err(SqueezeIndexError::AllValuesNaN);
753 }
754 let valid = count_valid_values(data);
755 let max_length = combos
756 .iter()
757 .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
758 .max()
759 .unwrap_or(0);
760 if valid < max_length {
761 return Err(SqueezeIndexError::NotEnoughValidData {
762 needed: max_length,
763 valid,
764 });
765 }
766
767 let mut buf_mu = make_uninit_matrix(rows, cols);
768 let warmups: Vec<usize> = combos
769 .iter()
770 .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH).saturating_sub(1))
771 .collect();
772 init_matrix_prefixes(&mut buf_mu, cols, &warmups);
773
774 let mut guard = ManuallyDrop::new(buf_mu);
775 let out: &mut [f64] =
776 unsafe { std::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
777
778 if parallel {
779 #[cfg(not(target_arch = "wasm32"))]
780 out.par_chunks_mut(cols)
781 .enumerate()
782 .for_each(|(row, out_row)| {
783 let params = resolve_params(&combos[row], cols).unwrap();
784 squeeze_index_row_from_slice(data, params, out_row);
785 });
786
787 #[cfg(target_arch = "wasm32")]
788 for (row, out_row) in out.chunks_mut(cols).enumerate() {
789 let params = resolve_params(&combos[row], cols).unwrap();
790 squeeze_index_row_from_slice(data, params, out_row);
791 }
792 } else {
793 for (row, out_row) in out.chunks_mut(cols).enumerate() {
794 let params = resolve_params(&combos[row], cols).unwrap();
795 squeeze_index_row_from_slice(data, params, out_row);
796 }
797 }
798
799 let values = unsafe {
800 Vec::from_raw_parts(
801 guard.as_mut_ptr() as *mut f64,
802 guard.len(),
803 guard.capacity(),
804 )
805 };
806
807 Ok(SqueezeIndexBatchOutput {
808 values,
809 combos,
810 rows,
811 cols,
812 })
813}
814
815#[inline(always)]
816pub fn squeeze_index_batch_inner_into(
817 data: &[f64],
818 sweep: &SqueezeIndexBatchRange,
819 _kernel: Kernel,
820 parallel: bool,
821 out: &mut [f64],
822) -> Result<Vec<SqueezeIndexParams>, SqueezeIndexError> {
823 let combos = expand_grid_squeeze_index(sweep)?;
824 let rows = combos.len();
825 let cols = data.len();
826 if cols == 0 {
827 return Err(SqueezeIndexError::EmptyInputData);
828 }
829
830 let total = rows
831 .checked_mul(cols)
832 .ok_or_else(|| SqueezeIndexError::OutputLengthMismatch {
833 expected: usize::MAX,
834 got: out.len(),
835 })?;
836 if out.len() != total {
837 return Err(SqueezeIndexError::OutputLengthMismatch {
838 expected: total,
839 got: out.len(),
840 });
841 }
842
843 let first = first_valid_value(data);
844 if first >= cols {
845 return Err(SqueezeIndexError::AllValuesNaN);
846 }
847 let valid = count_valid_values(data);
848 let max_length = combos
849 .iter()
850 .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
851 .max()
852 .unwrap_or(0);
853 if valid < max_length {
854 return Err(SqueezeIndexError::NotEnoughValidData {
855 needed: max_length,
856 valid,
857 });
858 }
859
860 if parallel {
861 #[cfg(not(target_arch = "wasm32"))]
862 out.par_chunks_mut(cols)
863 .enumerate()
864 .for_each(|(row, out_row)| {
865 let params = resolve_params(&combos[row], cols).unwrap();
866 squeeze_index_row_from_slice(data, params, out_row);
867 });
868
869 #[cfg(target_arch = "wasm32")]
870 for (row, out_row) in out.chunks_mut(cols).enumerate() {
871 let params = resolve_params(&combos[row], cols).unwrap();
872 squeeze_index_row_from_slice(data, params, out_row);
873 }
874 } else {
875 for (row, out_row) in out.chunks_mut(cols).enumerate() {
876 let params = resolve_params(&combos[row], cols).unwrap();
877 squeeze_index_row_from_slice(data, params, out_row);
878 }
879 }
880
881 Ok(combos)
882}
883
/// Python entry point: Squeeze Index of a contiguous 1-D float64 array.
///
/// Raises `ValueError` for invalid kernels, parameters, or input data.
#[cfg(feature = "python")]
#[pyfunction(name = "squeeze_index")]
#[pyo3(signature = (data, conv=DEFAULT_CONV, length=DEFAULT_LENGTH, kernel=None))]
pub fn squeeze_index_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    conv: f64,
    length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;
    let params = SqueezeIndexParams {
        conv: Some(conv),
        length: Some(length),
    };
    let input = SqueezeIndexInput::from_slice(series, params);
    // The computation does not touch Python objects, so release the GIL while it runs.
    let result = py.allow_threads(|| squeeze_index_with_kernel(&input, kern));
    match result {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
908
/// Python-visible wrapper (class name `SqueezeIndexStream`) around the Rust stream.
#[cfg(feature = "python")]
#[pyclass(name = "SqueezeIndexStream")]
pub struct SqueezeIndexStreamPy {
    stream: SqueezeIndexStream,
}
914
#[cfg(feature = "python")]
#[pymethods]
impl SqueezeIndexStreamPy {
    /// Creates a stream; raises `ValueError` for invalid `conv`/`length`.
    #[new]
    #[pyo3(signature = (conv=DEFAULT_CONV, length=DEFAULT_LENGTH))]
    fn new(conv: f64, length: usize) -> PyResult<Self> {
        let params = SqueezeIndexParams {
            conv: Some(conv),
            length: Some(length),
        };
        match SqueezeIndexStream::try_new(params) {
            Ok(stream) => Ok(SqueezeIndexStreamPy { stream }),
            Err(err) => Err(PyValueError::new_err(err.to_string())),
        }
    }

    /// Feeds one sample; returns `None` during warm-up or after a non-finite sample.
    fn update(&mut self, value: f64) -> Option<f64> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
933
/// Python entry point for parameter sweeps.
///
/// Returns a dict with `values` (shape `(rows, cols)`), `convs`, `lengths`, `rows` and
/// `cols`. Raises `ValueError` on invalid kernels, ranges, or data.
#[cfg(feature = "python")]
#[pyfunction(name = "squeeze_index_batch")]
#[pyo3(signature = (data, conv_range, length_range, kernel=None))]
pub fn squeeze_index_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    conv_range: (f64, f64, f64),
    length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let series = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = SqueezeIndexBatchRange {
        conv: conv_range,
        length: length_range,
    };
    let to_py_err = |e: SqueezeIndexError| PyValueError::new_err(e.to_string());

    // Expand the grid up front only to size the output array.
    let grid = expand_grid_squeeze_index(&sweep).map_err(to_py_err)?;
    let (rows, cols) = (grid.len(), series.len());
    let total = match rows.checked_mul(cols) {
        Some(count) => count,
        None => return Err(PyValueError::new_err("rows*cols overflow")),
    };

    // The array is left uninitialized; `squeeze_index_batch_inner_into` writes every slot.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = if matches!(kern, Kernel::Auto) {
                detect_best_batch_kernel()
            } else {
                kern
            };
            squeeze_index_batch_inner_into(
                series,
                &sweep,
                batch_kernel.to_non_batch(),
                true,
                slice_out,
            )
        })
        .map_err(to_py_err)?;

    let conv_values: Vec<f64> = combos
        .iter()
        .map(|combo| combo.conv.unwrap_or(DEFAULT_CONV))
        .collect();
    let length_values: Vec<u64> = combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH) as u64)
        .collect();

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item("convs", conv_values.into_pyarray(py))?;
    dict.set_item("lengths", length_values.into_pyarray(py))?;
    dict.set_item("rows", rows)?;
    dict.set_item("cols", cols)?;
    Ok(dict)
}
994
/// Registers the Squeeze Index functions and stream class on a Python module.
#[cfg(feature = "python")]
pub fn register_squeeze_index_module(module: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
    let single = wrap_pyfunction!(squeeze_index_py, module)?;
    module.add_function(single)?;
    let batch = wrap_pyfunction!(squeeze_index_batch_py, module)?;
    module.add_function(batch)?;
    module.add_class::<SqueezeIndexStreamPy>()
}
1002
/// wasm entry point: returns the Squeeze Index of `data` as a new array; errors are
/// reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "squeeze_index_js")]
pub fn squeeze_index_js(data: &[f64], conv: f64, length: usize) -> Result<Vec<f64>, JsValue> {
    let params = SqueezeIndexParams {
        conv: Some(conv),
        length: Some(length),
    };
    let input = SqueezeIndexInput::from_slice(data, params);
    let mut output = vec![0.0; data.len()];
    match squeeze_index_into_slice(&mut output, &input, Kernel::Auto) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1018
/// Allocates room for `len` f64 values (contents uninitialized) and hands ownership to the
/// caller; release it with `squeeze_index_free(ptr, len)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_alloc(len: usize) -> *mut f64 {
    let mut buffer = ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1027
/// Releases a buffer obtained from `squeeze_index_alloc(len)`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr`/`len` must come from `squeeze_index_alloc(len)`, which allocated a
    // `Vec<f64>` with capacity `len`. Length 0 is used because that memory was never
    // guaranteed to be initialized; `Vec::from_raw_parts` requires the first `length`
    // elements to be initialized. f64 has no destructor, so only the deallocation matters.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1037
/// wasm raw-pointer entry point: writes the Squeeze Index of `len` values at `data_ptr` into
/// `len` slots at `out_ptr`.
///
/// The buffers may alias or partially overlap; in that case the result is computed into a
/// temporary and copied out afterwards.
///
/// # Errors
/// Null pointers, or any error from `squeeze_index_into_slice`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_into(
    data_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    conv: f64,
    length: usize,
) -> Result<(), JsValue> {
    if data_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Compare byte ranges, not just start pointers: any overlap would otherwise create
    // aliasing `&`/`&mut` slices, and the row kernel NaN-fills the output before reading
    // the input.
    let bytes = len.saturating_mul(std::mem::size_of::<f64>());
    let data_start = data_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = bytes != 0
        && data_start < out_start.saturating_add(bytes)
        && out_start < data_start.saturating_add(bytes);

    // SAFETY: the caller guarantees both pointers address `len` valid f64 slots. The mutable
    // view of `out` is only formed while it cannot overlap a live read of `data`.
    unsafe {
        let data = std::slice::from_raw_parts(data_ptr, len);
        let input = SqueezeIndexInput::from_slice(
            data,
            SqueezeIndexParams {
                conv: Some(conv),
                length: Some(length),
            },
        );
        if overlaps {
            let mut tmp = vec![0.0; len];
            squeeze_index_into_slice(&mut tmp, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            squeeze_index_into_slice(out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }

    Ok(())
}
1074
/// Sweep configuration accepted by `squeeze_index_batch_js`; each range is
/// `(start, end, step)`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SqueezeIndexBatchConfig {
    pub conv_range: (f64, f64, f64),
    pub length_range: (usize, usize, usize),
}
1081
/// Serialized result of `squeeze_index_batch_js`: the row-major matrix plus per-row
/// parameters, both as full `combos` and as flat `convs` / `lengths` arrays.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SqueezeIndexBatchJsOutput {
    pub values: Vec<f64>,
    pub combos: Vec<SqueezeIndexParams>,
    pub convs: Vec<f64>,
    pub lengths: Vec<usize>,
    pub rows: usize,
    pub cols: usize,
}
1092
/// wasm entry point for parameter sweeps; `config` must deserialize to
/// `SqueezeIndexBatchConfig`. Rows are computed serially.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = "squeeze_index_batch_js")]
pub fn squeeze_index_batch_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SqueezeIndexBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(value) => value,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {e}"))),
    };
    let sweep = SqueezeIndexBatchRange {
        conv: parsed.conv_range,
        length: parsed.length_range,
    };
    let SqueezeIndexBatchOutput {
        values,
        combos,
        rows,
        cols,
    } = squeeze_index_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let convs: Vec<f64> = combos
        .iter()
        .map(|combo| combo.conv.unwrap_or(DEFAULT_CONV))
        .collect();
    let lengths: Vec<usize> = combos
        .iter()
        .map(|combo| combo.length.unwrap_or(DEFAULT_LENGTH))
        .collect();

    let payload = SqueezeIndexBatchJsOutput {
        values,
        combos,
        convs,
        lengths,
        rows,
        cols,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&e.to_string()))
}
1123
/// wasm raw-pointer entry point for parameter sweeps.
///
/// Reads `len` values at `data_ptr` and writes a row-major `rows × len` matrix to `out_ptr`,
/// returning `rows`. If the output region overlaps the input, the input is copied first.
///
/// # Errors
/// Null pointers, grid-expansion errors, `rows * len` overflow, or any batch error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn squeeze_index_batch_into(
    data_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    conv_start: f64,
    conv_end: f64,
    conv_step: f64,
    length_start: usize,
    length_end: usize,
    length_step: usize,
) -> Result<usize, JsValue> {
    if data_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = SqueezeIndexBatchRange {
        conv: (conv_start, conv_end, conv_step),
        length: (length_start, length_end, length_step),
    };
    let combos =
        expand_grid_squeeze_index(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let total = rows
        .checked_mul(len)
        .ok_or_else(|| JsValue::from_str("rows*cols overflow"))?;

    // Overlap check on byte ranges: each output row is NaN-filled before the input is read,
    // so an overlapping output would corrupt the input (and alias `&`/`&mut` slices).
    let elem = std::mem::size_of::<f64>();
    let data_bytes = len.saturating_mul(elem);
    let out_bytes = total.saturating_mul(elem);
    let data_start = data_ptr as usize;
    let out_start = out_ptr as usize;
    let overlaps = data_bytes != 0
        && out_bytes != 0
        && data_start < out_start.saturating_add(out_bytes)
        && out_start < data_start.saturating_add(data_bytes);

    // SAFETY: the caller guarantees `data_ptr` addresses `len` f64 values and `out_ptr`
    // addresses `rows * len` f64 slots. When the two overlap, the input is snapshotted
    // before the mutable output slice is formed and is not read through `data` afterwards.
    unsafe {
        let data = std::slice::from_raw_parts(data_ptr, len);
        let owned;
        let input: &[f64] = if overlaps {
            owned = data.to_vec();
            &owned
        } else {
            data
        };
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        squeeze_index_batch_inner_into(input, &sweep, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}
1161
// Integration tests against the bundled Bitfinex 4h fixture; they compare the single-run,
// streaming and batch paths with one another.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    // Closing prices from the CSV fixture (path is relative to the crate root).
    fn load_close() -> Result<Vec<f64>, Box<dyn Error>> {
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        Ok(candles.close)
    }

    // Element-wise comparison where NaN matches only NaN; a NaN against a number fails
    // because the NaN difference never satisfies `<= tol`.
    fn approx_eq_or_nan(lhs: &[f64], rhs: &[f64], tol: f64) {
        assert_eq!(lhs.len(), rhs.len());
        for (i, (&a, &b)) in lhs.iter().zip(rhs.iter()).enumerate() {
            if a.is_nan() && b.is_nan() {
                continue;
            }
            assert!(
                (a - b).abs() <= tol,
                "mismatch at {i}: lhs={a} rhs={b} tol={tol}"
            );
        }
    }

    // Output has one value per input, nothing before the length-20 warm-up (index 19),
    // and at least one finite value afterwards.
    #[test]
    fn squeeze_index_output_contract() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let input = SqueezeIndexInput::from_slice(
            &close,
            SqueezeIndexParams {
                conv: Some(50.0),
                length: Some(20),
            },
        );
        let out = squeeze_index_with_kernel(&input, Kernel::Scalar)?;
        assert_eq!(out.values.len(), close.len());
        let first_valid = out.values.iter().position(|v| !v.is_nan()).unwrap();
        assert!(first_valid >= 19);
        assert!(out.values[first_valid..].iter().any(|v| v.is_finite()));
        Ok(())
    }

    // Kernel selection must not change results.
    #[test]
    fn squeeze_index_auto_matches_scalar() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let input = SqueezeIndexInput::from_slice(
            &close,
            SqueezeIndexParams {
                conv: Some(50.0),
                length: Some(20),
            },
        );
        let auto = squeeze_index_with_kernel(&input, Kernel::Auto)?;
        let scalar = squeeze_index_with_kernel(&input, Kernel::Scalar)?;
        approx_eq_or_nan(&auto.values, &scalar.values, 1e-12);
        Ok(())
    }

    // A NaN at index 64 resets the stream. With length 10 the window refills from index 65,
    // so indices 65..=73 are warm-up (72 checked as NaN) and index 74 is the first
    // post-reset output. The streamed values must also match the batch path.
    #[test]
    fn squeeze_index_stream_matches_batch_and_recovers_after_nan() -> Result<(), Box<dyn Error>> {
        let mut close = load_close()?;
        close[64] = f64::NAN;
        let batch = squeeze_index_with_kernel(
            &SqueezeIndexInput::from_slice(
                &close,
                SqueezeIndexParams {
                    conv: Some(50.0),
                    length: Some(10),
                },
            ),
            Kernel::Scalar,
        )?;
        let mut stream = SqueezeIndexStream::try_new(SqueezeIndexParams {
            conv: Some(50.0),
            length: Some(10),
        })?;
        let mut streamed = vec![f64::NAN; close.len()];
        for (i, &value) in close.iter().enumerate() {
            if let Some(out) = stream.update(value) {
                streamed[i] = out;
            }
        }
        approx_eq_or_nan(&streamed, &batch.values, 1e-12);
        assert!(streamed[64].is_nan());
        assert!(streamed[72].is_nan());
        assert!(streamed[74].is_finite());
        Ok(())
    }

    // A one-combination sweep must reproduce the single-run result exactly.
    #[test]
    fn squeeze_index_batch_single_param_matches_single() -> Result<(), Box<dyn Error>> {
        let close = load_close()?;
        let batch = squeeze_index_batch_with_kernel(
            &close,
            &SqueezeIndexBatchRange {
                conv: (50.0, 50.0, 0.0),
                length: (20, 20, 0),
            },
            Kernel::Auto,
        )?;
        assert_eq!(batch.rows, 1);
        assert_eq!(batch.cols, close.len());
        let single = squeeze_index_with_kernel(
            &SqueezeIndexInput::from_slice(
                &close,
                SqueezeIndexParams {
                    conv: Some(50.0),
                    length: Some(20),
                },
            ),
            Kernel::Auto,
        )?;
        approx_eq_or_nan(&batch.values, &single.values, 1e-12);
        Ok(())
    }
}