1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyUntypedArrayMethods};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7
8#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
9use serde::{Deserialize, Serialize};
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen::prelude::*;
12
13use crate::indicators::moving_averages::sma::{sma_with_kernel, SmaInput, SmaParams};
14use crate::utilities::data_loader::Candles;
15use crate::utilities::enums::Kernel;
16use crate::utilities::helpers::{alloc_with_nan_prefix, detect_best_kernel};
17use std::error::Error;
18use thiserror::Error;
19
/// Price data accepted by the TTM Squeeze indicator.
///
/// Either a borrowed [`Candles`] set (its `high`, `low` and `close` series are
/// read) or three borrowed slices of high, low and close prices.
#[derive(Debug, Clone)]
pub enum TtmSqueezeData<'a> {
    /// Read the `high`, `low` and `close` series from a candle set.
    Candles {
        candles: &'a Candles,
    },
    /// Caller-provided parallel price series; `ttm_squeeze_with_kernel`
    /// rejects slices whose lengths differ.
    Slices {
        high: &'a [f64],
        low: &'a [f64],
        close: &'a [f64],
    },
}
31
/// Result of a TTM Squeeze computation: one value per input bar in each vector.
///
/// Leading warmup entries (before the first full window) are NaN.
#[derive(Debug, Clone)]
pub struct TtmSqueezeOutput {
    /// Least-squares regression of `close - (midpoint + SMA(close)) / 2` over
    /// each window, evaluated at the window's newest bar. `midpoint` is the
    /// middle of the window's highest high and lowest low.
    pub momentum: Vec<f64>,
    /// Squeeze state per bar:
    /// - `0.0`: Bollinger Bands extend outside the `kc_mult_low` Keltner channel.
    /// - `3.0`: the bands are within the `kc_mult_high` channel.
    /// - `2.0`: the bands are within the `kc_mult_mid` channel.
    /// - `1.0`: anything in between.
    pub squeeze: Vec<f64>,
}
37
/// Tunable parameters for the TTM Squeeze indicator.
///
/// Every field is optional. An unset field falls back to the classic setting
/// (length 20, BB 2.0, KC high/mid/low 1.0/1.5/2.0) when it is read through
/// `TtmSqueezeInput`'s getters.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct TtmSqueezeParams {
    /// Lookback window shared by the SMA, standard deviation, true-range
    /// average, high/low range and linear regression.
    pub length: Option<usize>,
    /// Bollinger Band standard-deviation multiplier.
    pub bb_mult: Option<f64>,
    /// Keltner multiplier whose channel marks squeeze level 3.
    pub kc_mult_high: Option<f64>,
    /// Keltner multiplier whose channel marks squeeze level 2.
    pub kc_mult_mid: Option<f64>,
    /// Keltner multiplier whose channel separates "no squeeze" (0) from level 1.
    pub kc_mult_low: Option<f64>,
}

impl Default for TtmSqueezeParams {
    /// Every field set explicitly to the classic values.
    fn default() -> Self {
        TtmSqueezeParams {
            length: Some(20),
            bb_mult: Some(2.0),
            kc_mult_low: Some(2.0),
            kc_mult_mid: Some(1.5),
            kc_mult_high: Some(1.0),
        }
    }
}
62
63#[derive(Debug, Clone)]
64pub struct TtmSqueezeInput<'a> {
65 pub data: TtmSqueezeData<'a>,
66 pub params: TtmSqueezeParams,
67}
68
69impl<'a> TtmSqueezeInput<'a> {
70 #[inline]
71 pub fn from_candles(candles: &'a Candles, params: TtmSqueezeParams) -> Self {
72 Self {
73 data: TtmSqueezeData::Candles { candles },
74 params,
75 }
76 }
77
78 #[inline]
79 pub fn from_slices(
80 high: &'a [f64],
81 low: &'a [f64],
82 close: &'a [f64],
83 params: TtmSqueezeParams,
84 ) -> Self {
85 Self {
86 data: TtmSqueezeData::Slices { high, low, close },
87 params,
88 }
89 }
90
91 #[inline]
92 pub fn with_default_candles(candles: &'a Candles) -> Self {
93 Self::from_candles(candles, TtmSqueezeParams::default())
94 }
95
96 #[inline]
97 pub fn get_length(&self) -> usize {
98 self.params.length.unwrap_or(20)
99 }
100
101 #[inline]
102 pub fn get_bb_mult(&self) -> f64 {
103 self.params.bb_mult.unwrap_or(2.0)
104 }
105
106 #[inline]
107 pub fn get_kc_mult_high(&self) -> f64 {
108 self.params.kc_mult_high.unwrap_or(1.0)
109 }
110
111 #[inline]
112 pub fn get_kc_mult_mid(&self) -> f64 {
113 self.params.kc_mult_mid.unwrap_or(1.5)
114 }
115
116 #[inline]
117 pub fn get_kc_mult_low(&self) -> f64 {
118 self.params.kc_mult_low.unwrap_or(2.0)
119 }
120}
121
122#[derive(Debug, Clone)]
123pub struct TtmSqueezeBuilder {
124 length: Option<usize>,
125 bb_mult: Option<f64>,
126 kc_mult_high: Option<f64>,
127 kc_mult_mid: Option<f64>,
128 kc_mult_low: Option<f64>,
129 kernel: Kernel,
130}
131
132impl Default for TtmSqueezeBuilder {
133 fn default() -> Self {
134 Self {
135 length: None,
136 bb_mult: None,
137 kc_mult_high: None,
138 kc_mult_mid: None,
139 kc_mult_low: None,
140 kernel: Kernel::Auto,
141 }
142 }
143}
144
145impl TtmSqueezeBuilder {
146 #[inline]
147 pub fn new() -> Self {
148 Self::default()
149 }
150
151 #[inline]
152 pub fn length(mut self, length: usize) -> Self {
153 self.length = Some(length);
154 self
155 }
156
157 #[inline]
158 pub fn bb_mult(mut self, mult: f64) -> Self {
159 self.bb_mult = Some(mult);
160 self
161 }
162
163 #[inline]
164 pub fn kc_mult_high(mut self, mult: f64) -> Self {
165 self.kc_mult_high = Some(mult);
166 self
167 }
168
169 #[inline]
170 pub fn kc_mult_mid(mut self, mult: f64) -> Self {
171 self.kc_mult_mid = Some(mult);
172 self
173 }
174
175 #[inline]
176 pub fn kc_mult_low(mut self, mult: f64) -> Self {
177 self.kc_mult_low = Some(mult);
178 self
179 }
180
181 #[inline]
182 pub fn kernel(mut self, kernel: Kernel) -> Self {
183 self.kernel = kernel;
184 self
185 }
186
187 #[inline]
188 pub fn build_params(self) -> TtmSqueezeParams {
189 TtmSqueezeParams {
190 length: self.length,
191 bb_mult: self.bb_mult,
192 kc_mult_high: self.kc_mult_high,
193 kc_mult_mid: self.kc_mult_mid,
194 kc_mult_low: self.kc_mult_low,
195 }
196 }
197
198 #[inline(always)]
199 pub fn apply(self, candles: &Candles) -> Result<TtmSqueezeOutput, TtmSqueezeError> {
200 let kernel = self.kernel;
201 let params = self.build_params();
202 let input = TtmSqueezeInput::from_candles(candles, params);
203 ttm_squeeze_with_kernel(&input, kernel)
204 }
205
206 #[inline(always)]
207 pub fn apply_slices(
208 self,
209 high: &[f64],
210 low: &[f64],
211 close: &[f64],
212 ) -> Result<TtmSqueezeOutput, TtmSqueezeError> {
213 let kernel = self.kernel;
214 let params = self.build_params();
215 let input = TtmSqueezeInput::from_slices(high, low, close, params);
216 ttm_squeeze_with_kernel(&input, kernel)
217 }
218
219 #[inline(always)]
220 pub fn into_stream(self) -> Result<TtmSqueezeStream, TtmSqueezeError> {
221 TtmSqueezeStream::try_new(self.build_params())
222 }
223}
224
/// Errors returned by the TTM Squeeze functions and stream constructor.
#[derive(Debug, Error)]
pub enum TtmSqueezeError {
    /// The close series (or candle close) is empty.
    #[error("ttm_squeeze: Input data slice is empty.")]
    EmptyInputData,

    /// Every close value is NaN.
    #[error("ttm_squeeze: All values are NaN.")]
    AllValuesNaN,

    /// `length` is 0 or longer than the input.
    #[error("ttm_squeeze: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer than `length` bars remain after the first non-NaN close.
    #[error("ttm_squeeze: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A destination buffer passed to `ttm_squeeze_into_slices` has the wrong length.
    #[error("ttm_squeeze: Output slice length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// The high, low and close series differ in length.
    #[error("ttm_squeeze: Inconsistent slice lengths - high={high}, low={low}, close={close}")]
    InconsistentSliceLengths {
        high: usize,
        low: usize,
        close: usize,
    },

    /// `bb_mult` was set but is not finite and > 0 (see `validate_params`).
    #[error("ttm_squeeze: Invalid bb_mult: must be positive")]
    InvalidBbMult { bb_mult: f64 },

    /// `kc_mult_high` was set but is not finite and > 0.
    #[error("ttm_squeeze: Invalid kc_mult_high: must be positive")]
    InvalidKcMultHigh { kc_mult_high: f64 },

    /// `kc_mult_mid` was set but is not finite and > 0.
    #[error("ttm_squeeze: Invalid kc_mult_mid: must be positive")]
    InvalidKcMultMid { kc_mult_mid: f64 },

    /// `kc_mult_low` was set but is not finite and > 0.
    #[error("ttm_squeeze: Invalid kc_mult_low: must be positive")]
    InvalidKcMultLow { kc_mult_low: f64 },

    /// A batch sweep range is malformed. Not constructed in the code visible
    /// here; presumably used by batch helpers elsewhere in the file.
    #[error("ttm_squeeze: Invalid range in batch sweep: start={start}, end={end}, step={step}")]
    InvalidRange { start: f64, end: f64, step: f64 },

    /// A kernel not accepted by a batch entry point. Not constructed in the
    /// code visible here.
    #[error("ttm_squeeze: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),

    /// An underlying `sma_with_kernel` call failed; holds its message.
    #[error("ttm_squeeze: SMA error: {0}")]
    SmaError(String),

    /// A linear-regression failure. Not constructed in the code visible here.
    #[error("ttm_squeeze: LinReg error: {0}")]
    LinRegError(String),
}
273
/// Population standard deviation of the non-NaN values in `data[start..=end]`,
/// measured around a caller-supplied `mean`.
///
/// Returns NaN unless at least two non-NaN samples are present. An empty range
/// (`start > end`) also yields NaN. Panics if the range reaches past the end
/// of `data`.
#[inline]
fn std_dev(data: &[f64], mean: f64, start: usize, end: usize) -> f64 {
    let (sum_sq, samples) = (start..=end)
        .map(|i| data[i])
        .filter(|v| !v.is_nan())
        .fold((0.0, 0usize), |(acc, n), v| {
            let delta = v - mean;
            (acc + delta * delta, n + 1)
        });

    if samples <= 1 {
        return f64::NAN;
    }
    (sum_sq / samples as f64).sqrt()
}
293
/// True range of one bar: the largest of `high - low`, `|high - prev_close|`
/// and `|low - prev_close|`. Without a previous close it is just `high - low`.
#[inline]
fn true_range(high: f64, low: f64, prev_close: Option<f64>) -> f64 {
    let span = high - low;
    if let Some(pc) = prev_close {
        let gap_from_high = (high - pc).abs();
        let gap_from_low = (low - pc).abs();
        span.max(gap_from_high).max(gap_from_low)
    } else {
        span
    }
}
306
307fn validate_params(params: &TtmSqueezeParams) -> Result<(), TtmSqueezeError> {
308 let ok = |x: f64| x.is_finite() && x > 0.0;
309
310 if let Some(bb) = params.bb_mult {
311 if !ok(bb) {
312 return Err(TtmSqueezeError::InvalidBbMult { bb_mult: bb });
313 }
314 }
315
316 if let Some(x) = params.kc_mult_high {
317 if !ok(x) {
318 return Err(TtmSqueezeError::InvalidKcMultHigh { kc_mult_high: x });
319 }
320 }
321
322 if let Some(x) = params.kc_mult_mid {
323 if !ok(x) {
324 return Err(TtmSqueezeError::InvalidKcMultMid { kc_mult_mid: x });
325 }
326 }
327
328 if let Some(x) = params.kc_mult_low {
329 if !ok(x) {
330 return Err(TtmSqueezeError::InvalidKcMultLow { kc_mult_low: x });
331 }
332 }
333
334 Ok(())
335}
336
/// Computes TTM Squeeze with automatic kernel selection.
///
/// Equivalent to `ttm_squeeze_with_kernel(input, Kernel::Auto)`; see that
/// function for output layout and errors.
#[inline]
pub fn ttm_squeeze(input: &TtmSqueezeInput) -> Result<TtmSqueezeOutput, TtmSqueezeError> {
    ttm_squeeze_with_kernel(input, Kernel::Auto)
}
341
342pub fn ttm_squeeze_with_kernel(
343 input: &TtmSqueezeInput,
344 kernel: Kernel,
345) -> Result<TtmSqueezeOutput, TtmSqueezeError> {
346 validate_params(&input.params)?;
347
348 let (high, low, close) = match &input.data {
349 TtmSqueezeData::Candles { candles } => {
350 if candles.close.is_empty() {
351 return Err(TtmSqueezeError::EmptyInputData);
352 }
353 (&candles.high[..], &candles.low[..], &candles.close[..])
354 }
355 TtmSqueezeData::Slices { high, low, close } => {
356 if high.len() != low.len() || low.len() != close.len() {
357 return Err(TtmSqueezeError::InconsistentSliceLengths {
358 high: high.len(),
359 low: low.len(),
360 close: close.len(),
361 });
362 }
363 if close.is_empty() {
364 return Err(TtmSqueezeError::EmptyInputData);
365 }
366 (*high, *low, *close)
367 }
368 };
369
370 let len = close.len();
371 let length = input.get_length();
372 let bb_mult = input.params.bb_mult.unwrap_or(2.0);
373 let kc_mult_high = input.params.kc_mult_high.unwrap_or(1.0);
374 let kc_mult_mid = input.params.kc_mult_mid.unwrap_or(1.5);
375 let kc_mult_low = input.params.kc_mult_low.unwrap_or(2.0);
376
377 if length == 0 || length > len {
378 return Err(TtmSqueezeError::InvalidPeriod {
379 period: length,
380 data_len: len,
381 });
382 }
383
384 let first = close
385 .iter()
386 .position(|&x| !x.is_nan())
387 .ok_or(TtmSqueezeError::AllValuesNaN)?;
388 if len - first < length {
389 return Err(TtmSqueezeError::NotEnoughValidData {
390 needed: length,
391 valid: len - first,
392 });
393 }
394
395 let warmup = first + length - 1;
396
397 let chosen = match kernel {
398 Kernel::Auto => detect_best_kernel(),
399 k => k,
400 };
401
402 if chosen == Kernel::Scalar
403 && length == 20
404 && bb_mult == 2.0
405 && kc_mult_high == 1.0
406 && kc_mult_mid == 1.5
407 && kc_mult_low == 2.0
408 {
409 let mut momentum = alloc_with_nan_prefix(len, warmup);
410 let mut squeeze = alloc_with_nan_prefix(len, warmup);
411
412 unsafe {
413 ttm_squeeze_scalar_classic(
414 high,
415 low,
416 close,
417 length,
418 bb_mult,
419 kc_mult_high,
420 kc_mult_mid,
421 kc_mult_low,
422 first,
423 warmup,
424 &mut momentum,
425 &mut squeeze,
426 )?;
427 }
428
429 return Ok(TtmSqueezeOutput { momentum, squeeze });
430 }
431
432 let sma_params = SmaParams {
433 period: Some(length),
434 };
435 let sma_input = SmaInput::from_slice(close, sma_params);
436 let sma_result = sma_with_kernel(&sma_input, kernel)
437 .map_err(|e| TtmSqueezeError::SmaError(e.to_string()))?;
438 let sma_values = sma_result.values;
439
440 let mut tr = alloc_with_nan_prefix(len, first);
441 for i in first..len {
442 tr[i] = if i == first {
443 high[i] - low[i]
444 } else {
445 let pc = close[i - 1];
446 let hl = high[i] - low[i];
447 let hc = (high[i] - pc).abs();
448 let lc = (low[i] - pc).abs();
449 hl.max(hc).max(lc)
450 };
451 }
452
453 let tr_sma_params = SmaParams {
454 period: Some(length),
455 };
456 let tr_sma_input = SmaInput::from_slice(&tr, tr_sma_params);
457 let tr_sma_result = sma_with_kernel(&tr_sma_input, kernel)
458 .map_err(|e| TtmSqueezeError::SmaError(e.to_string()))?;
459 let dev_kc = tr_sma_result.values;
460
461 let mut squeeze = alloc_with_nan_prefix(len, warmup);
462 let mut momentum = alloc_with_nan_prefix(len, warmup);
463
464 for i in warmup..len {
465 let m = sma_values[i];
466 let dkc = dev_kc[i];
467 if m.is_nan() || dkc.is_nan() {
468 continue;
469 }
470
471 let start = i + 1 - length;
472 let mut sum = 0.0;
473 let mut cnt = 0usize;
474 for j in start..=i {
475 let v = close[j];
476 if v.is_nan() {
477 continue;
478 }
479 let d = v - m;
480 sum += d * d;
481 cnt += 1;
482 }
483
484 if cnt > 1 {
485 let std = (sum / cnt as f64).sqrt();
486 let bb_upper = m + bb_mult * std;
487 let bb_lower = m - bb_mult * std;
488
489 let kc_upper_low = m + dkc * kc_mult_low;
490 let kc_lower_low = m - dkc * kc_mult_low;
491 let kc_upper_mid = m + dkc * kc_mult_mid;
492 let kc_lower_mid = m - dkc * kc_mult_mid;
493 let kc_upper_high = m + dkc * kc_mult_high;
494 let kc_lower_high = m - dkc * kc_mult_high;
495
496 let no_sqz = bb_lower < kc_lower_low || bb_upper > kc_upper_low;
497 squeeze[i] = if no_sqz {
498 0.0
499 } else if bb_lower >= kc_lower_high || bb_upper <= kc_upper_high {
500 3.0
501 } else if bb_lower >= kc_lower_mid || bb_upper <= kc_upper_mid {
502 2.0
503 } else {
504 1.0
505 };
506 }
507
508 let mut highest = f64::NEG_INFINITY;
509 let mut lowest = f64::INFINITY;
510 let mut has_valid = false;
511
512 for j in start..=i {
513 if high[j].is_finite() && low[j].is_finite() {
514 highest = highest.max(high[j]);
515 lowest = lowest.min(low[j]);
516 has_valid = true;
517 }
518 }
519
520 if has_valid {
521 let midpoint = (highest + lowest) * 0.5;
522
523 let avg = (midpoint + m) * 0.5;
524
525 let mut sx = 0.0;
526 let mut sy = 0.0;
527 let mut sxy = 0.0;
528 let mut sx2 = 0.0;
529 let mut n = 0.0;
530
531 for (k, j) in (start..=i).enumerate() {
532 let y = close[j] - avg;
533 if y.is_nan() {
534 continue;
535 }
536 let x = k as f64;
537 sx += x;
538 sy += y;
539 sxy += x * y;
540 sx2 += x * x;
541 n += 1.0;
542 }
543
544 if n >= 2.0 {
545 let slope = (n * sxy - sx * sy) / (n * sx2 - sx * sx);
546 let intercept = (sy - slope * sx) / n;
547 momentum[i] = intercept + slope * ((length - 1) as f64);
548 }
549 }
550 }
551
552 Ok(TtmSqueezeOutput { momentum, squeeze })
553}
554
555#[inline]
556pub fn ttm_squeeze_into_slices(
557 dst_momentum: &mut [f64],
558 dst_squeeze: &mut [f64],
559 input: &TtmSqueezeInput,
560 kernel: Kernel,
561) -> Result<(), TtmSqueezeError> {
562 validate_params(&input.params)?;
563
564 let (high, low, close) = match &input.data {
565 TtmSqueezeData::Candles { candles } => {
566 (&candles.high[..], &candles.low[..], &candles.close[..])
567 }
568 TtmSqueezeData::Slices { high, low, close } => (*high, *low, *close),
569 };
570
571 if close.is_empty() {
572 return Err(TtmSqueezeError::EmptyInputData);
573 }
574
575 if dst_momentum.len() != close.len() || dst_squeeze.len() != close.len() {
576 return Err(TtmSqueezeError::OutputLengthMismatch {
577 expected: close.len(),
578 got: dst_momentum.len().min(dst_squeeze.len()),
579 });
580 }
581
582 let len = close.len();
583 let length = input.get_length();
584 let bb_mult = input.get_bb_mult();
585 let kc_mult_high = input.get_kc_mult_high();
586 let kc_mult_mid = input.get_kc_mult_mid();
587 let kc_mult_low = input.get_kc_mult_low();
588
589 let first = close
590 .iter()
591 .position(|&x| !x.is_nan())
592 .ok_or(TtmSqueezeError::AllValuesNaN)?;
593
594 if length == 0 || length > len {
595 return Err(TtmSqueezeError::InvalidPeriod {
596 period: length,
597 data_len: len,
598 });
599 }
600
601 if len - first < length {
602 return Err(TtmSqueezeError::NotEnoughValidData {
603 needed: length,
604 valid: len - first,
605 });
606 }
607
608 let warmup = first + length - 1;
609
610 for i in 0..warmup {
611 dst_momentum[i] = f64::NAN;
612 dst_squeeze[i] = f64::NAN;
613 }
614
615 let sma_params = SmaParams {
616 period: Some(length),
617 };
618 let sma_input = SmaInput::from_slice(close, sma_params);
619 let sma_result = sma_with_kernel(&sma_input, kernel)
620 .map_err(|e| TtmSqueezeError::SmaError(e.to_string()))?;
621 let sma_values = sma_result.values;
622
623 let mut tr_values = alloc_with_nan_prefix(len, first);
624 for i in first..len {
625 tr_values[i] = if i == first {
626 high[i] - low[i]
627 } else {
628 true_range(high[i], low[i], Some(close[i - 1]))
629 };
630 }
631
632 let tr_sma_params = SmaParams {
633 period: Some(length),
634 };
635 let tr_sma_input = SmaInput::from_slice(&tr_values, tr_sma_params);
636 let tr_sma_result = sma_with_kernel(&tr_sma_input, kernel)
637 .map_err(|e| TtmSqueezeError::SmaError(e.to_string()))?;
638 let dev_kc = tr_sma_result.values;
639
640 for i in warmup..len {
641 let m = sma_values[i];
642 let dev_kc_val = dev_kc[i];
643
644 if m.is_nan() || dev_kc_val.is_nan() {
645 dst_squeeze[i] = f64::NAN;
646 continue;
647 }
648
649 let start = i + 1 - length;
650 let mut sum = 0.0;
651 let mut count = 0;
652
653 for j in start..=i {
654 if !close[j].is_nan() {
655 let d = close[j] - m;
656 sum += d * d;
657 count += 1;
658 }
659 }
660
661 let std = if count > 1 {
662 (sum / count as f64).sqrt()
663 } else {
664 f64::NAN
665 };
666
667 if std.is_nan() {
668 dst_squeeze[i] = f64::NAN;
669 continue;
670 }
671
672 let bb_upper = m + bb_mult * std;
673 let bb_lower = m - bb_mult * std;
674 let kc_upper_low = m + dev_kc_val * kc_mult_low;
675 let kc_lower_low = m - dev_kc_val * kc_mult_low;
676 let kc_upper_mid = m + dev_kc_val * kc_mult_mid;
677 let kc_lower_mid = m - dev_kc_val * kc_mult_mid;
678 let kc_upper_high = m + dev_kc_val * kc_mult_high;
679 let kc_lower_high = m - dev_kc_val * kc_mult_high;
680
681 let no_sqz = bb_lower < kc_lower_low || bb_upper > kc_upper_low;
682
683 dst_squeeze[i] = if no_sqz {
684 0.0
685 } else if bb_lower >= kc_lower_high || bb_upper <= kc_upper_high {
686 3.0
687 } else if bb_lower >= kc_lower_mid || bb_upper <= kc_upper_mid {
688 2.0
689 } else {
690 1.0
691 };
692 }
693
694 for end_idx in warmup..len {
695 let start_idx = end_idx + 1 - length;
696
697 let mut highest = f64::NEG_INFINITY;
698 let mut lowest = f64::INFINITY;
699 let mut has_valid = false;
700
701 for j in start_idx..=end_idx {
702 if high[j].is_finite() && low[j].is_finite() {
703 highest = highest.max(high[j]);
704 lowest = lowest.min(low[j]);
705 has_valid = true;
706 }
707 }
708
709 if !has_valid || sma_values[end_idx].is_nan() {
710 dst_momentum[end_idx] = f64::NAN;
711 continue;
712 }
713
714 let midpoint = (highest + lowest) * 0.5;
715 let avg = (midpoint + sma_values[end_idx]) / 2.0;
716
717 let mut sum_x = 0.0;
718 let mut sum_y = 0.0;
719 let mut sum_xy = 0.0;
720 let mut sum_x2 = 0.0;
721 let mut n = 0.0;
722
723 for (k, j) in (start_idx..=end_idx).enumerate() {
724 if close[j].is_nan() {
725 continue;
726 }
727 let x = k as f64;
728 let y = close[j] - avg;
729 sum_x += x;
730 sum_y += y;
731 sum_xy += x * y;
732 sum_x2 += x * x;
733 n += 1.0;
734 }
735
736 if n >= 2.0 {
737 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
738 let intercept = (sum_y - slope * sum_x) / n;
739 dst_momentum[end_idx] = intercept + slope * ((length - 1) as f64);
740 } else {
741 dst_momentum[end_idx] = f64::NAN;
742 }
743 }
744
745 Ok(())
746}
747
/// Alias for [`ttm_squeeze_into_slices`] with identical arguments, behaviour
/// and errors.
#[inline]
pub fn ttm_squeeze_into(
    dst_momentum: &mut [f64],
    dst_squeeze: &mut [f64],
    input: &TtmSqueezeInput,
    kernel: Kernel,
) -> Result<(), TtmSqueezeError> {
    ttm_squeeze_into_slices(dst_momentum, dst_squeeze, input, kernel)
}
757
/// Monotonic deque stored in a fixed-capacity ring buffer.
///
/// The front always holds the maximum (`is_max == true`) or minimum of the
/// values pushed and not yet expired. `push` does not check capacity: the
/// caller must `expire` old indices so that at most `cap` entries are live.
#[derive(Debug, Clone)]
struct MonoDeque {
    idx: Vec<usize>,
    val: Vec<f64>,
    head: usize,
    tail: usize,
    len: usize,
    cap: usize,
    is_max: bool,
}

impl MonoDeque {
    /// Empty deque with room for `cap` entries.
    #[inline(always)]
    fn new(cap: usize, is_max: bool) -> Self {
        MonoDeque {
            idx: vec![0usize; cap],
            val: vec![f64::NAN; cap],
            head: 0,
            tail: 0,
            len: 0,
            cap,
            is_max,
        }
    }

    /// Ring position after `pos`, wrapping at `cap`.
    #[inline(always)]
    fn next_slot(&self, pos: usize) -> usize {
        if pos + 1 == self.cap {
            0
        } else {
            pos + 1
        }
    }

    /// Ring position before `pos`, wrapping at 0.
    #[inline(always)]
    fn prev_slot(&self, pos: usize) -> usize {
        if pos == 0 {
            self.cap - 1
        } else {
            pos - 1
        }
    }

    /// Drops every entry without touching the storage.
    #[inline(always)]
    fn clear(&mut self) {
        self.len = 0;
        self.head = 0;
        self.tail = 0;
    }

    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Current extreme value; the deque must be non-empty.
    #[inline(always)]
    fn front_val(&self) -> f64 {
        debug_assert!(!self.is_empty());
        self.val[self.head]
    }

    /// Pops front entries whose index is older than `min_idx`.
    #[inline(always)]
    fn expire(&mut self, min_idx: usize) {
        while !self.is_empty() && self.idx[self.head] < min_idx {
            self.head = self.next_slot(self.head);
            self.len -= 1;
        }
    }

    /// Appends `(idx, value)` after discarding back entries it dominates.
    #[inline(always)]
    fn push(&mut self, idx: usize, value: f64) {
        while !self.is_empty() {
            let back = self.prev_slot(self.tail);
            let back_survives = if self.is_max {
                self.val[back] >= value
            } else {
                self.val[back] <= value
            };
            if back_survives {
                break;
            }
            self.tail = back;
            self.len -= 1;
        }
        self.idx[self.tail] = idx;
        self.val[self.tail] = value;
        self.tail = self.next_slot(self.tail);
        self.len += 1;
    }
}
846
/// Incremental (bar-by-bar) TTM Squeeze calculator.
///
/// Keeps running sums over the last `n = length` bars so each `update` is
/// O(1) amortized, plus monotonic deques for the window high/low.
#[derive(Debug, Clone)]
pub struct TtmSqueezeStream {
    // Parameters the stream was built with.
    params: TtmSqueezeParams,

    // Ring buffers of the last `n` highs, lows, closes and true ranges.
    hi: Vec<f64>,
    lo: Vec<f64>,
    cl: Vec<f64>,
    tr: Vec<f64>,

    // Next ring-buffer slot to overwrite.
    head: usize,
    // True once `n` bars have been received.
    filled: bool,
    // Number of bars received so far (absolute index of the next bar).
    t: usize,

    // Σ close over the window.
    sum0: f64,
    // Σ k·close with k = 0 for the oldest bar in the window.
    sum1: f64,
    // Σ close² over the window.
    sumsq: f64,
    // Σ true range over the window.
    tr_sum: f64,

    // Close of the previous bar (None before the first update).
    prev_close: Option<f64>,

    // Window length and derived regression constants for x = 0..n-1:
    // sx = Σx, sx2 = Σx², inv_den = 1 / (n·Σx² − (Σx)²) (0 when not positive),
    // half_nm1 = (n − 1) / 2.
    n: usize,
    n_f64: f64,
    inv_n: f64,
    sx: f64,
    sx2: f64,
    inv_den: f64,
    half_nm1: f64,

    // Squared multipliers, so the squeeze test compares variance against
    // squared channel widths without a square root.
    bb_sq: f64,
    kc_low_sq: f64,
    kc_mid_sq: f64,
    kc_high_sq: f64,

    // Monotonic deques giving the window's highest high and lowest low.
    max_q: MonoDeque,
    min_q: MonoDeque,
}
883
884impl TtmSqueezeStream {
885 pub fn try_new(params: TtmSqueezeParams) -> Result<Self, TtmSqueezeError> {
886 let n = params.length.unwrap_or(20);
887 if n == 0 {
888 return Err(TtmSqueezeError::InvalidPeriod {
889 period: 0,
890 data_len: 0,
891 });
892 }
893
894 let n_f64 = n as f64;
895 let inv_n = 1.0 / n_f64;
896 let sx = 0.5 * n_f64 * (n_f64 - 1.0);
897 let sx2 = (n_f64 - 1.0) * n_f64 * (2.0 * n_f64 - 1.0) / 6.0;
898 let den = n_f64 * sx2 - sx * sx;
899 let inv_den = if den > 0.0 { 1.0 / den } else { 0.0 };
900 let half_nm1 = 0.5 * (n_f64 - 1.0);
901
902 let bb = params.bb_mult.unwrap_or(2.0);
903 let kc_hi = params.kc_mult_high.unwrap_or(1.0);
904 let kc_md = params.kc_mult_mid.unwrap_or(1.5);
905 let kc_lo = params.kc_mult_low.unwrap_or(2.0);
906
907 Ok(Self {
908 params,
909 hi: vec![f64::NAN; n],
910 lo: vec![f64::NAN; n],
911 cl: vec![f64::NAN; n],
912 tr: vec![0.0; n],
913
914 head: 0,
915 filled: false,
916 t: 0,
917
918 sum0: 0.0,
919 sum1: 0.0,
920 sumsq: 0.0,
921 tr_sum: 0.0,
922
923 prev_close: None,
924
925 n,
926 n_f64,
927 inv_n,
928 sx,
929 sx2,
930 inv_den,
931 half_nm1,
932
933 bb_sq: bb * bb,
934 kc_low_sq: kc_lo * kc_lo,
935 kc_mid_sq: kc_md * kc_md,
936 kc_high_sq: kc_hi * kc_hi,
937
938 max_q: MonoDeque::new(n, true),
939 min_q: MonoDeque::new(n, false),
940 })
941 }
942
943 #[inline]
944 pub fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
945 let n = self.n;
946 let pos = self.head;
947
948 let tr_new = match self.prev_close {
949 Some(pc) => {
950 let hl = high - low;
951 let hc = (high - pc).abs();
952 let lc = (low - pc).abs();
953 if hl >= hc {
954 if hl >= lc {
955 hl
956 } else {
957 lc
958 }
959 } else if hc >= lc {
960 hc
961 } else {
962 lc
963 }
964 }
965 None => high - low,
966 };
967
968 if self.filled {
969 let min_idx = self.t + 1 - n;
970 self.max_q.expire(min_idx);
971 self.min_q.expire(min_idx);
972 }
973
974 self.max_q.push(self.t, high);
975 self.min_q.push(self.t, low);
976
977 let old_c = self.cl[pos];
978 let old_tr = self.tr[pos];
979
980 self.hi[pos] = high;
981 self.lo[pos] = low;
982 self.cl[pos] = close;
983 self.tr[pos] = tr_new;
984
985 if !self.filled {
986 self.sum0 += close;
987 self.sumsq = close.mul_add(close, self.sumsq);
988 self.sum1 += (self.t as f64) * close;
989 self.tr_sum += tr_new;
990
991 self.prev_close = Some(close);
992 self.head = (pos + 1) % n;
993 self.t += 1;
994
995 if self.t < n {
996 return None;
997 }
998
999 self.filled = true;
1000 return Some(self.emit());
1001 }
1002
1003 let sum0_old = self.sum0;
1004
1005 self.sum0 += close - old_c;
1006 self.sumsq = close.mul_add(close, self.sumsq - old_c * old_c);
1007
1008 self.sum1 = self.sum1 - sum0_old + old_c + (self.n_f64 - 1.0) * close;
1009
1010 self.tr_sum += tr_new - old_tr;
1011
1012 self.prev_close = Some(close);
1013 self.head = (pos + 1) % n;
1014 self.t += 1;
1015
1016 Some(self.emit())
1017 }
1018
1019 #[inline]
1020 fn emit(&self) -> (f64, f64) {
1021 let m = self.sum0 * self.inv_n;
1022 let var = (-m).mul_add(m, self.sumsq * self.inv_n);
1023 let var_pos = if var > 0.0 { var } else { 0.0 };
1024
1025 let dkc = self.tr_sum * self.inv_n;
1026 let dkc2 = dkc * dkc;
1027
1028 let bbv = self.bb_sq * var_pos;
1029 let t_low = self.kc_low_sq * dkc2;
1030 let t_mid = self.kc_mid_sq * dkc2;
1031 let t_hi = self.kc_high_sq * dkc2;
1032
1033 let sqz = if bbv > t_low {
1034 0.0
1035 } else if bbv <= t_hi {
1036 3.0
1037 } else if bbv <= t_mid {
1038 2.0
1039 } else {
1040 1.0
1041 };
1042
1043 let highest = if self.max_q.is_empty() {
1044 f64::NAN
1045 } else {
1046 self.max_q.front_val()
1047 };
1048 let lowest = if self.min_q.is_empty() {
1049 f64::NAN
1050 } else {
1051 self.min_q.front_val()
1052 };
1053
1054 let midpoint = 0.5 * (highest + lowest);
1055 let avg = 0.5 * (midpoint + m);
1056
1057 let sy = self.sum0 - avg * self.n_f64;
1058 let sxy = self.sum1 - avg * self.sx;
1059
1060 let mom = if self.n >= 2 && self.inv_den.is_finite() {
1061 let slope = self.n_f64.mul_add(sxy, -(self.sx * sy)) * self.inv_den;
1062 sy * self.inv_n + slope * self.half_nm1
1063 } else {
1064 f64::NAN
1065 };
1066
1067 (mom, sqz)
1068 }
1069
1070 pub fn reset(&mut self) {
1071 self.hi.fill(f64::NAN);
1072 self.lo.fill(f64::NAN);
1073 self.cl.fill(f64::NAN);
1074 self.tr.fill(0.0);
1075
1076 self.head = 0;
1077 self.filled = false;
1078 self.t = 0;
1079
1080 self.sum0 = 0.0;
1081 self.sum1 = 0.0;
1082 self.sumsq = 0.0;
1083 self.tr_sum = 0.0;
1084
1085 self.prev_close = None;
1086
1087 self.max_q.clear();
1088 self.min_q.clear();
1089 }
1090}
1091
/// Rolling O(n) TTM Squeeze kernel, used by `ttm_squeeze_with_kernel` for the
/// default parameter set on the scalar kernel.
///
/// Maintains running sums over a `length`-bar window plus two monotonic index
/// deques for the window's highest high and lowest low. The running sums are:
/// Σclose, Σk·close (k = 0 for the oldest bar), Σclose², and Σtrue range.
///
/// For each bar `i` from `warmup` to `close.len() - 1` it writes:
/// * `squeeze[i]`: 0.0 if `bb_mult² · var > kc_mult_low² · atr²`; otherwise
///   3.0 / 2.0 / 1.0 depending on whether it is within the `kc_mult_high` /
///   `kc_mult_mid` threshold. Here `var` is the population variance of close
///   and `atr` is the window mean of true range.
/// * `momentum[i]`: the least-squares line through
///   `close - (midpoint + mean) / 2`, evaluated at the newest bar.
///
/// Entries before `warmup` are not written. NaN values are not filtered; a
/// NaN close poisons the running sums for all later bars. Always returns
/// `Ok(())`. When `length < 2` or `warmup >= close.len()`, it returns early
/// without writing anything.
///
/// # Safety
/// Uses unchecked indexing. The caller must guarantee that `high`, `low`,
/// `momentum` and `squeeze` are each at least `close.len()` long.
/// For meaningful results the caller should also pass
/// `warmup == first + length - 1`, as `ttm_squeeze_with_kernel` does.
/// NOTE(review): other `first`/`warmup` combinations look memory-safe but
/// can underflow `start_idx` — confirm before relying on that.
#[inline(always)]
pub unsafe fn ttm_squeeze_scalar_classic(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    bb_mult: f64,
    kc_mult_high: f64,
    kc_mult_mid: f64,
    kc_mult_low: f64,
    first: usize,
    warmup: usize,
    momentum: &mut [f64],
    squeeze: &mut [f64],
) -> Result<(), TtmSqueezeError> {
    let len = close.len();
    if len == 0 || length < 2 || warmup >= len {
        return Ok(());
    }

    // Regression constants for x = 0..length-1.
    let n = length as f64;
    let sx = 0.5 * n * (n - 1.0);
    let sx2 = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
    let den = n * sx2 - sx * sx;
    let inv_den = 1.0 / den;
    let inv_n = 1.0 / n;
    let half_nm1 = 0.5 * (n - 1.0);

    // Ring buffers of the last `length` closes and true ranges, so the
    // outgoing value is known when the window slides.
    let mut cbuf = vec![0.0f64; length];
    let mut trbuf = vec![0.0f64; length];
    let mut cpos = 0usize;
    let mut trpos = 0usize;

    let mut sum0 = 0.0f64;
    let mut sum1 = 0.0f64;
    let mut sumsq = 0.0f64;
    let mut tr_sum = 0.0f64;

    // Monotonic deques of bar indices (ring buffers of capacity `length`):
    // max_q front = index of the window's highest high, min_q front = lowest low.
    let cap = length;
    let mut max_q = vec![0usize; cap];
    let mut min_q = vec![0usize; cap];
    let (mut max_head, mut max_tail, mut max_len) = (0usize, 0usize, 0usize);
    let (mut min_head, mut min_tail, mut min_len) = (0usize, 0usize, 0usize);

    // Squared multipliers: the squeeze test compares squared widths.
    let bb_sq = bb_mult * bb_mult;
    let kc_low_sq = kc_mult_low * kc_mult_low;
    let kc_mid_sq = kc_mult_mid * kc_mult_mid;
    let kc_high_sq = kc_mult_high * kc_mult_high;

    // Warmup: accumulate bars first..=warmup into the sums and deques.
    {
        let mut r = 0usize;
        let mut i = first;
        while i <= warmup {
            let c = *close.get_unchecked(i);
            *cbuf.get_unchecked_mut(cpos) = c;
            sum0 += c;
            sumsq = c.mul_add(c, sumsq);
            sum1 += (r as f64) * c;

            let tr_val = if i == first {
                *high.get_unchecked(i) - *low.get_unchecked(i)
            } else {
                let pc = *close.get_unchecked(i - 1);
                let hl = *high.get_unchecked(i) - *low.get_unchecked(i);
                let hc = (*high.get_unchecked(i) - pc).abs();
                let lc = (*low.get_unchecked(i) - pc).abs();
                if hl >= hc {
                    if hl >= lc {
                        hl
                    } else {
                        lc
                    }
                } else {
                    if hc >= lc {
                        hc
                    } else {
                        lc
                    }
                }
            };
            *trbuf.get_unchecked_mut(trpos) = tr_val;
            tr_sum += tr_val;

            // Push i onto max_q, discarding dominated (lower-high) entries.
            while max_len > 0 {
                let back_pos = if max_tail == 0 { cap - 1 } else { max_tail - 1 };
                let back_idx = *max_q.get_unchecked(back_pos);
                if *high.get_unchecked(i) <= *high.get_unchecked(back_idx) {
                    break;
                }
                max_tail = back_pos;
                max_len -= 1;
            }
            *max_q.get_unchecked_mut(max_tail) = i;
            max_tail += 1;
            if max_tail == cap {
                max_tail = 0;
            }
            max_len += 1;

            // Push i onto min_q, discarding dominated (higher-low) entries.
            while min_len > 0 {
                let back_pos = if min_tail == 0 { cap - 1 } else { min_tail - 1 };
                let back_idx = *min_q.get_unchecked(back_pos);
                if *low.get_unchecked(i) >= *low.get_unchecked(back_idx) {
                    break;
                }
                min_tail = back_pos;
                min_len -= 1;
            }
            *min_q.get_unchecked_mut(min_tail) = i;
            min_tail += 1;
            if min_tail == cap {
                min_tail = 0;
            }
            min_len += 1;

            cpos += 1;
            if cpos == length {
                cpos = 0;
            }
            trpos += 1;
            if trpos == length {
                trpos = 0;
            }
            r += 1;
            i += 1;
        }
    }

    // First full window: emit outputs at `warmup`.
    {
        let m = sum0 * inv_n;
        let var = (-m).mul_add(m, sumsq * inv_n);
        let var_pos = if var > 0.0 { var } else { 0.0 };
        let dkc = tr_sum * inv_n;
        let dkc2 = dkc * dkc;

        let bbv = bb_sq * var_pos;
        let t_low = kc_low_sq * dkc2;
        let t_mid = kc_mid_sq * dkc2;
        let t_high = kc_high_sq * dkc2;

        *squeeze.get_unchecked_mut(warmup) = if bbv > t_low {
            0.0
        } else if bbv <= t_high {
            3.0
        } else if bbv <= t_mid {
            2.0
        } else {
            1.0
        };

        let hi_idx = *max_q.get_unchecked(max_head);
        let lo_idx = *min_q.get_unchecked(min_head);
        let highest = *high.get_unchecked(hi_idx);
        let lowest = *low.get_unchecked(lo_idx);

        // Regression of (close - avg) on x via shifted sums;
        // intercept + slope·(n-1) == sy/n + slope·(n-1)/2.
        let midpoint = 0.5 * (highest + lowest);
        let avg = 0.5 * (midpoint + m);
        let sy = sum0 - avg * n;
        let sxy = sum1 - avg * sx;
        let slope = n.mul_add(sxy, -(sx * sy)) * inv_den;
        *momentum.get_unchecked_mut(warmup) = sy * inv_n + slope * half_nm1;
    }

    // Steady state: slide the window one bar at a time.
    let mut i = warmup + 1;
    while i < len {
        let start_idx = i + 1 - length;

        // Expire deque fronts that fell out of [start_idx, i].
        while max_len > 0 {
            let front_idx = *max_q.get_unchecked(max_head);
            if front_idx >= start_idx {
                break;
            }
            max_head += 1;
            if max_head == cap {
                max_head = 0;
            }
            max_len -= 1;
        }
        while min_len > 0 {
            let front_idx = *min_q.get_unchecked(min_head);
            if front_idx >= start_idx {
                break;
            }
            min_head += 1;
            if min_head == cap {
                min_head = 0;
            }
            min_len -= 1;
        }

        while max_len > 0 {
            let back_pos = if max_tail == 0 { cap - 1 } else { max_tail - 1 };
            let back_idx = *max_q.get_unchecked(back_pos);
            if *high.get_unchecked(i) <= *high.get_unchecked(back_idx) {
                break;
            }
            max_tail = back_pos;
            max_len -= 1;
        }
        *max_q.get_unchecked_mut(max_tail) = i;
        max_tail += 1;
        if max_tail == cap {
            max_tail = 0;
        }
        max_len += 1;

        while min_len > 0 {
            let back_pos = if min_tail == 0 { cap - 1 } else { min_tail - 1 };
            let back_idx = *min_q.get_unchecked(back_pos);
            if *low.get_unchecked(i) >= *low.get_unchecked(back_idx) {
                break;
            }
            min_tail = back_pos;
            min_len -= 1;
        }
        *min_q.get_unchecked_mut(min_tail) = i;
        min_tail += 1;
        if min_tail == cap {
            min_tail = 0;
        }
        min_len += 1;

        // Slide the close sums. For Σk·c, every remaining bar's x drops by 1,
        // the evicted bar had x = 0, and the new bar enters at x = n - 1.
        let old = *cbuf.get_unchecked(cpos);
        let new = *close.get_unchecked(i);
        let sum0_old = sum0;
        sum0 += new - old;
        sumsq = new.mul_add(new, sumsq - old * old);
        sum1 = sum1 - sum0_old + old + (n - 1.0) * new;
        *cbuf.get_unchecked_mut(cpos) = new;
        cpos += 1;
        if cpos == length {
            cpos = 0;
        }

        // Slide the true-range sum.
        let old_tr = *trbuf.get_unchecked(trpos);
        let pc = *close.get_unchecked(i - 1);
        let hi_i = *high.get_unchecked(i);
        let lo_i = *low.get_unchecked(i);
        let hl = hi_i - lo_i;
        let hc = (hi_i - pc).abs();
        let lc = (lo_i - pc).abs();
        let tr_new = hl.max(hc).max(lc);
        tr_sum += tr_new - old_tr;
        *trbuf.get_unchecked_mut(trpos) = tr_new;
        trpos += 1;
        if trpos == length {
            trpos = 0;
        }

        let m = sum0 * inv_n;
        let var = (-m).mul_add(m, sumsq * inv_n);
        let var_pos = if var > 0.0 { var } else { 0.0 };
        let dkc = tr_sum * inv_n;
        let dkc2 = dkc * dkc;
        let bbv = bb_sq * var_pos;
        let t_low = kc_low_sq * dkc2;
        let t_mid = kc_mid_sq * dkc2;
        let t_high = kc_high_sq * dkc2;
        *squeeze.get_unchecked_mut(i) = if bbv > t_low {
            0.0
        } else if bbv <= t_high {
            3.0
        } else if bbv <= t_mid {
            2.0
        } else {
            1.0
        };

        let hi_idx = *max_q.get_unchecked(max_head);
        let lo_idx = *min_q.get_unchecked(min_head);
        let highest = *high.get_unchecked(hi_idx);
        let lowest = *low.get_unchecked(lo_idx);

        let midpoint = 0.5 * (highest + lowest);
        let avg = 0.5 * (midpoint + m);
        let sy = sum0 - avg * n;
        let sxy = sum1 - avg * sx;
        let slope = n.mul_add(sxy, -(sx * sy)) * inv_den;
        *momentum.get_unchecked_mut(i) = sy * inv_n + slope * half_nm1;

        i += 1;
    }

    Ok(())
}
1377
1378#[cfg(feature = "python")]
1379use crate::utilities::kernel_validation::validate_kernel;
1380
#[cfg(feature = "python")]
/// Python entry point: TTM Squeeze over `high` / `low` / `close` for a single parameter set.
///
/// Returns `(momentum, squeeze)` as two float64 arrays of the input length. Raises
/// `ValueError` when the three inputs differ in length, when `kernel` is not accepted by
/// `validate_kernel`, or when the computation itself fails.
#[pyfunction(name = "ttm_squeeze")]
#[pyo3(signature = (high, low, close, length=20, bb_mult=2.0, kc_mult_high=1.0, kc_mult_mid=1.5, kc_mult_low=2.0, kernel=None))]
pub fn ttm_squeeze_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length: usize,
    bb_mult: f64,
    kc_mult_high: f64,
    kc_mult_mid: f64,
    kc_mult_low: f64,
    kernel: Option<&str>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
    let (hs, ls, cs) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);
    let n = cs.len();

    // All three series must line up bar-for-bar.
    if hs.len() != n || ls.len() != n {
        return Err(PyValueError::new_err(format!(
            "ttm_squeeze: Inconsistent slice lengths - high={}, low={}, close={}",
            hs.len(),
            ls.len(),
            n
        )));
    }

    let input = TtmSqueezeInput::from_slices(
        hs,
        ls,
        cs,
        TtmSqueezeParams {
            length: Some(length),
            bb_mult: Some(bb_mult),
            kc_mult_high: Some(kc_mult_high),
            kc_mult_mid: Some(kc_mult_mid),
            kc_mult_low: Some(kc_mult_low),
        },
    );
    let chosen = validate_kernel(kernel, false)?;

    // NaN-initialised outputs; the kernel fills them with the GIL released.
    let (mut mom_out, mut sqz_out) = (vec![f64::NAN; n], vec![f64::NAN; n]);
    py.allow_threads(|| ttm_squeeze_into_slices(&mut mom_out, &mut sqz_out, &input, chosen))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    Ok((mom_out.into_pyarray(py), sqz_out.into_pyarray(py)))
}
1428
1429#[cfg(all(feature = "python", feature = "cuda"))]
1430use crate::cuda::{cuda_available, CudaTtmSqueeze};
1431#[cfg(all(feature = "python", feature = "cuda"))]
1432use crate::indicators::moving_averages::alma::DeviceArrayF32Py;
1433#[cfg(all(feature = "python", feature = "cuda"))]
1434use numpy::PyReadonlyArray1;
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python entry point: runs a TTM Squeeze parameter sweep on the GPU and keeps the results
/// on the device.
///
/// The inputs are float32 series and each `*_range` is an inclusive `(start, end, step)` sweep,
/// as in `TtmSqueezeBatchRange`. Returns `(momentum, squeeze)` device arrays. Both hold the CUDA
/// context and the device id, so the context outlives the buffers.
///
/// Raises `ValueError` when CUDA is unavailable, when `CudaTtmSqueeze::new` fails, or when the
/// batch launch fails.
#[pyfunction(name = "ttm_squeeze_cuda_batch_dev")]
#[pyo3(signature = (high_f32, low_f32, close_f32, length_range, bb_mult_range, kc_high_range, kc_mid_range, kc_low_range, device_id=0))]
pub fn ttm_squeeze_cuda_batch_dev_py(
    py: Python<'_>,
    high_f32: PyReadonlyArray1<'_, f32>,
    low_f32: PyReadonlyArray1<'_, f32>,
    close_f32: PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    bb_mult_range: (f64, f64, f64),
    kc_high_range: (f64, f64, f64),
    kc_mid_range: (f64, f64, f64),
    kc_low_range: (f64, f64, f64),
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let h = high_f32.as_slice()?;
    let l = low_f32.as_slice()?;
    let c = close_f32.as_slice()?;
    let sweep = TtmSqueezeBatchRange {
        length: length_range,
        bb_mult: bb_mult_range,
        kc_high: kc_high_range,
        kc_mid: kc_mid_range,
        kc_low: kc_low_range,
    };
    // Device setup and the launch run with the GIL released.
    let (mo, sq, ctx, dev_id_u32) = py.allow_threads(|| {
        let cuda =
            CudaTtmSqueeze::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        let dev_id_u32 = cuda.device_id();
        match cuda.ttm_squeeze_batch_dev(h, l, c, &sweep) {
            Ok((mo, sq)) => Ok((mo, sq, ctx, dev_id_u32)),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    })?;
    Ok((
        DeviceArrayF32Py {
            inner: mo,
            _ctx: Some(ctx.clone()),
            device_id: Some(dev_id_u32),
        },
        DeviceArrayF32Py {
            inner: sq,
            _ctx: Some(ctx),
            device_id: Some(dev_id_u32),
        },
    ))
}
1487
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python entry point: TTM Squeeze for many series sharing one parameter set, on the GPU.
///
/// `*_tm_f32` are flat float32 buffers described by `cols` / `rows`. The name says time-major.
/// NOTE(review): the exact layout contract lives in
/// `CudaTtmSqueeze::ttm_squeeze_many_series_one_param_time_major_dev`; confirm there.
/// Returns `(momentum, squeeze)` device arrays that hold the CUDA context and device id.
/// Raises `ValueError` when CUDA is unavailable or the device call fails.
#[pyfunction(name = "ttm_squeeze_cuda_many_series_one_param_dev")]
#[pyo3(signature = (high_tm_f32, low_tm_f32, close_tm_f32, cols, rows, length, bb_mult, kc_high, kc_mid, kc_low, device_id=0))]
pub fn ttm_squeeze_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    high_tm_f32: PyReadonlyArray1<'_, f32>,
    low_tm_f32: PyReadonlyArray1<'_, f32>,
    close_tm_f32: PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    length: usize,
    bb_mult: f32,
    kc_high: f32,
    kc_mid: f32,
    kc_low: f32,
    device_id: usize,
) -> PyResult<(DeviceArrayF32Py, DeviceArrayF32Py)> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let (hs, ls, cs) = (
        high_tm_f32.as_slice()?,
        low_tm_f32.as_slice()?,
        close_tm_f32.as_slice()?,
    );

    let (mom_dev, sqz_dev, ctx, dev) = py.allow_threads(|| {
        let engine =
            CudaTtmSqueeze::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = engine.context_arc();
        let dev = engine.device_id();
        engine
            .ttm_squeeze_many_series_one_param_time_major_dev(
                hs, ls, cs, cols, rows, length, bb_mult, kc_high, kc_mid, kc_low,
            )
            .map(|(mo, sq)| (mo, sq, ctx, dev))
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    // Each wrapper keeps its own handle on the context so the buffers never outlive it.
    let wrap = |inner, ctx| DeviceArrayF32Py {
        inner,
        _ctx: Some(ctx),
        device_id: Some(dev),
    };
    Ok((wrap(mom_dev, ctx.clone()), wrap(sqz_dev, ctx)))
}
1536
#[cfg(feature = "python")]
/// Python wrapper around `TtmSqueezeStream`: feed one bar at a time and get
/// `(momentum, squeeze)` back whenever the stream produces a value.
#[pyclass(name = "TtmSqueezeStream")]
pub struct TtmSqueezeStreamPy {
    stream: TtmSqueezeStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl TtmSqueezeStreamPy {
    /// Creates a stream. Raises `ValueError` when `TtmSqueezeStream::try_new` rejects the
    /// parameters.
    #[new]
    fn new(
        length: usize,
        bb_mult: f64,
        kc_mult_high: f64,
        kc_mult_mid: f64,
        kc_mult_low: f64,
    ) -> PyResult<Self> {
        TtmSqueezeStream::try_new(TtmSqueezeParams {
            length: Some(length),
            bb_mult: Some(bb_mult),
            kc_mult_high: Some(kc_mult_high),
            kc_mult_mid: Some(kc_mult_mid),
            kc_mult_low: Some(kc_mult_low),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Pushes one bar and returns whatever the underlying stream yields: `None` while it has
    /// no value, otherwise `(momentum, squeeze)`.
    fn update(&mut self, high: f64, low: f64, close: f64) -> Option<(f64, f64)> {
        let Self { stream } = self;
        stream.update(high, low, close)
    }
}
1571
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialised result of `ttm_squeeze_js`: `values` holds `rows` rows of `cols` values each,
/// laid out row after row.
#[derive(Serialize, Deserialize)]
pub struct TtmSqueezeJsResult {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS entry point: computes TTM Squeeze and returns a `TtmSqueezeJsResult` with two rows.
/// Row 0 is momentum and row 1 is the squeeze level. Errors from the computation or from
/// serialisation are returned as JS strings.
#[wasm_bindgen(js_name = ttm_squeeze)]
pub fn ttm_squeeze_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    bb_mult: f64,
    kc_mult_high: f64,
    kc_mult_mid: f64,
    kc_mult_low: f64,
) -> Result<JsValue, JsValue> {
    let input = TtmSqueezeInput::from_slices(
        high,
        low,
        close,
        TtmSqueezeParams {
            length: Some(length),
            bb_mult: Some(bb_mult),
            kc_mult_high: Some(kc_mult_high),
            kc_mult_mid: Some(kc_mult_mid),
            kc_mult_low: Some(kc_mult_low),
        },
    );

    let TtmSqueezeOutput { momentum, squeeze } =
        ttm_squeeze(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = momentum.len();
    // Momentum row first, then the squeeze row.
    let values = [momentum, squeeze].concat();

    serde_wasm_bindgen::to_value(&TtmSqueezeJsResult {
        values,
        rows: 2,
        cols,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1618
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS entry point: computes TTM Squeeze directly into caller-provided output slices, using
/// the kernel chosen by `detect_best_kernel`.
///
/// All three inputs and both outputs must have the same length. Otherwise a
/// "slice length mismatch" / "output length mismatch" error is returned.
#[wasm_bindgen(js_name = ttm_squeeze_into)]
pub fn ttm_squeeze_into_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    length: usize,
    bb_mult: f64,
    kc_mult_high: f64,
    kc_mult_mid: f64,
    kc_mult_low: f64,
    out_momentum: &mut [f64],
    out_squeeze: &mut [f64],
) -> Result<(), JsValue> {
    let n = close.len();
    let inputs_match = high.len() == low.len() && low.len() == n;
    if !inputs_match {
        return Err(JsValue::from_str("slice length mismatch"));
    }
    let outputs_match = out_momentum.len() == n && out_squeeze.len() == n;
    if !outputs_match {
        return Err(JsValue::from_str("output length mismatch"));
    }

    let input = TtmSqueezeInput::from_slices(
        high,
        low,
        close,
        TtmSqueezeParams {
            length: Some(length),
            bb_mult: Some(bb_mult),
            kc_mult_high: Some(kc_mult_high),
            kc_mult_mid: Some(kc_mult_mid),
            kc_mult_low: Some(kc_mult_low),
        },
    );
    let kernel = detect_best_kernel();

    ttm_squeeze_into_slices(out_momentum, out_squeeze, &input, kernel)
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
1653
/// Reserves room for `len` f64 values in wasm memory and returns the pointer, for JS callers
/// of the pointer-based API.
///
/// The memory is NOT initialised. It must be released with `ttm_squeeze_free` using the same
/// `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ttm_squeeze_alloc(len: usize) -> *mut f64 {
    let mut v = Vec::<f64>::with_capacity(len);
    let p = v.as_mut_ptr();
    // Ownership passes to the JS caller; `ttm_squeeze_free` takes it back.
    core::mem::forget(v);
    p
}

/// Releases a buffer obtained from `ttm_squeeze_alloc(len)`.
///
/// A null pointer is ignored, so a JS caller that never allocated (or already cleared its
/// handle) cannot trigger undefined behaviour here.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn ttm_squeeze_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `ttm_squeeze_alloc(len)`, i.e. from a `Vec<f64>` with capacity
    // `len`, and has not been freed yet (caller contract). Length 0 is used because the
    // contents may never have been initialised. Deallocation depends only on pointer and
    // capacity, and f64 has no drop glue.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
1670
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS entry point over raw wasm-memory pointers: computes TTM Squeeze for `len` bars and
/// copies the results into `out_momentum` / `out_squeeze`.
///
/// Every pointer must address at least `len` f64 values (for example buffers from
/// `ttm_squeeze_alloc`). Outputs may alias inputs: the result is computed into owned buffers
/// first and copied afterwards. Returns an error for null pointers, `len == 0`, a `length`
/// outside `1..=len`, or any error from `ttm_squeeze`.
#[wasm_bindgen(js_name = ttm_squeeze_into_ptrs)]
pub fn ttm_squeeze_into_js_ptrs(
    high: *const f64,
    low: *const f64,
    close: *const f64,
    out_momentum: *mut f64,
    out_squeeze: *mut f64,
    len: usize,
    length: usize,
    bb_mult: f64,
    kc_high: f64,
    kc_mid: f64,
    kc_low: f64,
) -> Result<(), JsValue> {
    if high.is_null()
        || low.is_null()
        || close.is_null()
        || out_momentum.is_null()
        || out_squeeze.is_null()
    {
        return Err(JsValue::from_str("null pointer"));
    }

    if len == 0 {
        return Err(JsValue::from_str("ttm_squeeze: Input data slice is empty."));
    }

    if length == 0 || length > len {
        return Err(JsValue::from_str(&format!(
            "ttm_squeeze: Invalid period: period = {}, data length = {}",
            length, len
        )));
    }

    // SAFETY: the pointers are non-null (checked above) and, by the caller contract, each
    // addresses `len` initialised f64 values that stay valid for this call. Nothing writes
    // through the output pointers while these shared slices are in use (see below).
    let (h, l, c) = unsafe {
        (
            core::slice::from_raw_parts(high, len),
            core::slice::from_raw_parts(low, len),
            core::slice::from_raw_parts(close, len),
        )
    };

    let params = TtmSqueezeParams {
        length: Some(length),
        bb_mult: Some(bb_mult),
        kc_mult_high: Some(kc_high),
        kc_mult_mid: Some(kc_mid),
        kc_mult_low: Some(kc_low),
    };

    let input = TtmSqueezeInput::from_slices(h, l, c, params);
    // Compute into owned buffers first. An output may alias an input, so nothing is written
    // until the inputs have been read completely.
    let out = ttm_squeeze(&input).map_err(|e| JsValue::from_str(&e.to_string()))?;

    // SAFETY: `out_momentum` is non-null and addresses `len` writable f64 values (caller
    // contract). The input slices are no longer used. This exclusive slice is dead before the
    // squeeze slice is created, so the two outputs never coexist as `&mut`, even if the caller
    // passed the same buffer twice.
    let dst_momentum = unsafe { core::slice::from_raw_parts_mut(out_momentum, len) };
    dst_momentum.copy_from_slice(&out.momentum);

    // SAFETY: same contract as above, for `out_squeeze`.
    let dst_squeeze = unsafe { core::slice::from_raw_parts_mut(out_squeeze, len) };
    dst_squeeze.copy_from_slice(&out.squeeze);

    Ok(())
}
1730
1731use crate::utilities::helpers::{
1732 detect_best_batch_kernel, init_matrix_prefixes, make_uninit_matrix,
1733};
1734
/// Parameter sweep for the TTM Squeeze batch API.
///
/// Each field is an inclusive `(start, end, step)` range. A zero step or `start == end` pins
/// that axis to `start`.
#[derive(Clone, Debug)]
pub struct TtmSqueezeBatchRange {
    pub length: (usize, usize, usize),
    pub bb_mult: (f64, f64, f64),
    pub kc_high: (f64, f64, f64),
    pub kc_mid: (f64, f64, f64),
    pub kc_low: (f64, f64, f64),
}

impl Default for TtmSqueezeBatchRange {
    /// Sweeps `length` over 20..=269 in steps of 1 and pins every multiplier:
    /// BB 2.0, KC high 1.0, KC mid 1.5, KC low 2.0.
    fn default() -> Self {
        let fixed = |v: f64| (v, v, 0.0);
        Self {
            length: (20, 269, 1),
            bb_mult: fixed(2.0),
            kc_high: fixed(1.0),
            kc_mid: fixed(1.5),
            kc_low: fixed(2.0),
        }
    }
}
1755
1756#[derive(Clone, Debug, Default)]
1757pub struct TtmSqueezeBatchBuilder {
1758 range: TtmSqueezeBatchRange,
1759 kernel: Kernel,
1760}
1761
1762impl TtmSqueezeBatchBuilder {
1763 pub fn new() -> Self {
1764 Self::default()
1765 }
1766
1767 pub fn kernel(mut self, k: Kernel) -> Self {
1768 self.kernel = k;
1769 self
1770 }
1771
1772 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
1773 self.range.length = (start, end, step);
1774 self
1775 }
1776
1777 pub fn bb_mult_range(mut self, start: f64, end: f64, step: f64) -> Self {
1778 self.range.bb_mult = (start, end, step);
1779 self
1780 }
1781
1782 pub fn kc_high_range(mut self, start: f64, end: f64, step: f64) -> Self {
1783 self.range.kc_high = (start, end, step);
1784 self
1785 }
1786
1787 pub fn kc_mid_range(mut self, start: f64, end: f64, step: f64) -> Self {
1788 self.range.kc_mid = (start, end, step);
1789 self
1790 }
1791
1792 pub fn kc_low_range(mut self, start: f64, end: f64, step: f64) -> Self {
1793 self.range.kc_low = (start, end, step);
1794 self
1795 }
1796
1797 pub fn apply_candles(
1798 self,
1799 candles: &Candles,
1800 ) -> Result<TtmSqueezeBatchOutput, TtmSqueezeError> {
1801 ttm_squeeze_batch_with_kernel(
1802 &candles.high,
1803 &candles.low,
1804 &candles.close,
1805 &self.range,
1806 self.kernel,
1807 )
1808 }
1809
1810 pub fn apply_slices(
1811 self,
1812 high: &[f64],
1813 low: &[f64],
1814 close: &[f64],
1815 ) -> Result<TtmSqueezeBatchOutput, TtmSqueezeError> {
1816 ttm_squeeze_batch_with_kernel(high, low, close, &self.range, self.kernel)
1817 }
1818
1819 pub fn with_default_candles(
1820 candles: &Candles,
1821 ) -> Result<TtmSqueezeBatchOutput, TtmSqueezeError> {
1822 TtmSqueezeBatchBuilder::new()
1823 .kernel(Kernel::Auto)
1824 .apply_candles(candles)
1825 }
1826}
1827
1828#[derive(Clone, Debug)]
1829pub struct TtmSqueezeBatchOutput {
1830 pub momentum: Vec<f64>,
1831 pub squeeze: Vec<f64>,
1832 pub combos: Vec<TtmSqueezeParams>,
1833 pub rows: usize,
1834 pub cols: usize,
1835}
1836
1837impl TtmSqueezeBatchOutput {
1838 pub fn row_for_params(&self, p: &TtmSqueezeParams) -> Option<usize> {
1839 self.combos.iter().position(|q| {
1840 q.length.unwrap_or(20) == p.length.unwrap_or(20)
1841 && (q.bb_mult.unwrap_or(2.0) - p.bb_mult.unwrap_or(2.0)).abs() < 1e-12
1842 && (q.kc_mult_high.unwrap_or(1.0) - p.kc_mult_high.unwrap_or(1.0)).abs() < 1e-12
1843 && (q.kc_mult_mid.unwrap_or(1.5) - p.kc_mult_mid.unwrap_or(1.5)).abs() < 1e-12
1844 && (q.kc_mult_low.unwrap_or(2.0) - p.kc_mult_low.unwrap_or(2.0)).abs() < 1e-12
1845 })
1846 }
1847
1848 pub fn momentum_for(&self, p: &TtmSqueezeParams) -> Option<&[f64]> {
1849 self.row_for_params(p).map(|r| {
1850 let s = r * self.cols;
1851 &self.momentum[s..s + self.cols]
1852 })
1853 }
1854
1855 pub fn squeeze_for(&self, p: &TtmSqueezeParams) -> Option<&[f64]> {
1856 self.row_for_params(p).map(|r| {
1857 let s = r * self.cols;
1858 &self.squeeze[s..s + self.cols]
1859 })
1860 }
1861}
1862
1863fn axis_usize(a: (usize, usize, usize)) -> Result<Vec<usize>, TtmSqueezeError> {
1864 let (start, end, step) = a;
1865 if step == 0 || start == end {
1866 return Ok(vec![start]);
1867 }
1868
1869 let mut v = Vec::new();
1870 if start < end {
1871 let mut x = start;
1872 while x <= end {
1873 v.push(x);
1874 match x.checked_add(step) {
1875 Some(next) => {
1876 if next == x {
1877 break;
1878 }
1879 x = next;
1880 }
1881 None => break,
1882 }
1883 }
1884 } else {
1885 let mut x = start;
1886 loop {
1887 if x < end {
1888 break;
1889 }
1890 v.push(x);
1891 match x.checked_sub(step) {
1892 Some(next) => {
1893 if next == x {
1894 break;
1895 }
1896 x = next;
1897 }
1898 None => break,
1899 }
1900 }
1901 }
1902
1903 if v.is_empty() {
1904 return Err(TtmSqueezeError::InvalidRange {
1905 start: start as f64,
1906 end: end as f64,
1907 step: step as f64,
1908 });
1909 }
1910
1911 Ok(v)
1912}
1913
1914fn axis_f64(a: (f64, f64, f64)) -> Result<Vec<f64>, TtmSqueezeError> {
1915 let (start, end, step) = a;
1916 let step_mag = step.abs();
1917 if step_mag < 1e-12 || (start - end).abs() < 1e-12 {
1918 return Ok(vec![start]);
1919 }
1920
1921 let mut v = Vec::new();
1922 let mut x = start;
1923 if start <= end {
1924 while x <= end + 1e-12 {
1925 v.push(x);
1926 x += step_mag;
1927 }
1928 } else {
1929 while x >= end - 1e-12 {
1930 v.push(x);
1931 x -= step_mag;
1932 }
1933 }
1934
1935 if v.is_empty() {
1936 return Err(TtmSqueezeError::InvalidRange { start, end, step });
1937 }
1938
1939 Ok(v)
1940}
1941
1942fn expand_grid_squeeze(r: &TtmSqueezeBatchRange) -> Result<Vec<TtmSqueezeParams>, TtmSqueezeError> {
1943 let lengths = axis_usize(r.length)?;
1944 let bb_mults = axis_f64(r.bb_mult)?;
1945 let kc_highs = axis_f64(r.kc_high)?;
1946 let kc_mids = axis_f64(r.kc_mid)?;
1947 let kc_lows = axis_f64(r.kc_low)?;
1948
1949 let cap = lengths
1950 .len()
1951 .checked_mul(bb_mults.len())
1952 .and_then(|v| v.checked_mul(kc_highs.len()))
1953 .and_then(|v| v.checked_mul(kc_mids.len()))
1954 .and_then(|v| v.checked_mul(kc_lows.len()))
1955 .ok_or(TtmSqueezeError::InvalidRange {
1956 start: r.length.0 as f64,
1957 end: r.length.1 as f64,
1958 step: r.length.2 as f64,
1959 })?;
1960
1961 if cap == 0 {
1962 return Err(TtmSqueezeError::InvalidRange {
1963 start: r.length.0 as f64,
1964 end: r.length.1 as f64,
1965 step: r.length.2 as f64,
1966 });
1967 }
1968
1969 let mut out = Vec::with_capacity(cap);
1970
1971 for &l in &lengths {
1972 for &bb in &bb_mults {
1973 for &h in &kc_highs {
1974 for &m in &kc_mids {
1975 for &lo in &kc_lows {
1976 out.push(TtmSqueezeParams {
1977 length: Some(l),
1978 bb_mult: Some(bb),
1979 kc_mult_high: Some(h),
1980 kc_mult_mid: Some(m),
1981 kc_mult_low: Some(lo),
1982 });
1983 }
1984 }
1985 }
1986 }
1987 }
1988
1989 Ok(out)
1990}
1991
1992pub fn ttm_squeeze_batch_with_kernel(
1993 high: &[f64],
1994 low: &[f64],
1995 close: &[f64],
1996 sweep: &TtmSqueezeBatchRange,
1997 k: Kernel,
1998) -> Result<TtmSqueezeBatchOutput, TtmSqueezeError> {
1999 if high.len() != low.len() || low.len() != close.len() {
2000 return Err(TtmSqueezeError::InconsistentSliceLengths {
2001 high: high.len(),
2002 low: low.len(),
2003 close: close.len(),
2004 });
2005 }
2006
2007 let combos = expand_grid_squeeze(sweep)?;
2008 let rows = combos.len();
2009 let cols = close.len();
2010 let _total = rows
2011 .checked_mul(cols)
2012 .ok_or(TtmSqueezeError::InvalidRange {
2013 start: sweep.length.0 as f64,
2014 end: sweep.length.1 as f64,
2015 step: sweep.length.2 as f64,
2016 })?;
2017
2018 let first = close
2019 .iter()
2020 .position(|x| !x.is_nan())
2021 .ok_or(TtmSqueezeError::AllValuesNaN)?;
2022 let warmup_periods: Vec<usize> = combos
2023 .iter()
2024 .map(|c| first + c.length.unwrap() - 1)
2025 .collect();
2026
2027 let mut mom_mu = make_uninit_matrix(rows, cols);
2028 let mut sqz_mu = make_uninit_matrix(rows, cols);
2029 init_matrix_prefixes(&mut mom_mu, cols, &warmup_periods);
2030 init_matrix_prefixes(&mut sqz_mu, cols, &warmup_periods);
2031
2032 let mut mom_guard = core::mem::ManuallyDrop::new(mom_mu);
2033 let mut sqz_guard = core::mem::ManuallyDrop::new(sqz_mu);
2034
2035 let mom_slice: &mut [f64] = unsafe {
2036 core::slice::from_raw_parts_mut(mom_guard.as_mut_ptr() as *mut f64, mom_guard.len())
2037 };
2038 let sqz_slice: &mut [f64] = unsafe {
2039 core::slice::from_raw_parts_mut(sqz_guard.as_mut_ptr() as *mut f64, sqz_guard.len())
2040 };
2041
2042 let chosen_batch = match k {
2043 Kernel::Auto => detect_best_batch_kernel(),
2044 kb if kb.is_batch() => kb,
2045 other => {
2046 return Err(TtmSqueezeError::InvalidKernelForBatch(other));
2047 }
2048 };
2049
2050 let row_kernel = match chosen_batch {
2051 Kernel::Avx512Batch => Kernel::Avx512,
2052 Kernel::Avx2Batch => Kernel::Avx2,
2053 Kernel::ScalarBatch => Kernel::Scalar,
2054 _ => unreachable!(),
2055 };
2056
2057 if chosen_batch == Kernel::ScalarBatch {
2058 let mut lengths: Vec<usize> = Vec::new();
2059 for p in &combos {
2060 let l = p.length.unwrap_or(20);
2061 if !lengths.contains(&l) {
2062 lengths.push(l);
2063 }
2064 }
2065
2066 for l in lengths {
2067 if l < 2 || first + l > cols {
2068 continue;
2069 }
2070
2071 struct RowCfg {
2072 row: usize,
2073 bb_sq: f64,
2074 kc_low_sq: f64,
2075 kc_mid_sq: f64,
2076 kc_high_sq: f64,
2077 }
2078 let mut group: Vec<RowCfg> = Vec::new();
2079 for (idx, p) in combos.iter().enumerate() {
2080 if p.length.unwrap_or(20) == l {
2081 let bb = p.bb_mult.unwrap_or(2.0);
2082 let kh = p.kc_mult_high.unwrap_or(1.0);
2083 let km = p.kc_mult_mid.unwrap_or(1.5);
2084 let kl = p.kc_mult_low.unwrap_or(2.0);
2085 group.push(RowCfg {
2086 row: idx,
2087 bb_sq: bb * bb,
2088 kc_low_sq: kl * kl,
2089 kc_mid_sq: km * km,
2090 kc_high_sq: kh * kh,
2091 });
2092 }
2093 }
2094 if group.is_empty() {
2095 continue;
2096 }
2097
2098 let n = l as f64;
2099 let sx = 0.5 * n * (n - 1.0);
2100 let sx2 = (n - 1.0) * n * (2.0 * n - 1.0) / 6.0;
2101 let den = n * sx2 - sx * sx;
2102 let inv_den = 1.0 / den;
2103 let inv_n = 1.0 / n;
2104 let half_nm1 = 0.5 * (n - 1.0);
2105
2106 let mut cbuf = vec![0.0f64; l];
2107 let mut trbuf = vec![0.0f64; l];
2108 let (mut cpos, mut trpos) = (0usize, 0usize);
2109 let mut sum0 = 0.0f64;
2110 let mut sum1 = 0.0f64;
2111 let mut sumsq = 0.0f64;
2112 let mut tr_sum = 0.0f64;
2113
2114 let cap = l;
2115 let mut max_q = vec![0usize; cap];
2116 let mut min_q = vec![0usize; cap];
2117 let (mut max_head, mut max_tail, mut max_len) = (0usize, 0usize, 0usize);
2118 let (mut min_head, mut min_tail, mut min_len) = (0usize, 0usize, 0usize);
2119
2120 let warm = first + l - 1;
2121
2122 let mut r = 0usize;
2123 let mut i = first;
2124 while i <= warm {
2125 let c = close[i];
2126 cbuf[cpos] = c;
2127 sum0 += c;
2128 sumsq = c.mul_add(c, sumsq);
2129 sum1 += (r as f64) * c;
2130
2131 let tr_val = if i == first {
2132 high[i] - low[i]
2133 } else {
2134 let pc = close[i - 1];
2135 let hl = high[i] - low[i];
2136 let hc = (high[i] - pc).abs();
2137 let lc = (low[i] - pc).abs();
2138 hl.max(hc).max(lc)
2139 };
2140 trbuf[trpos] = tr_val;
2141 tr_sum += tr_val;
2142
2143 while max_len > 0 {
2144 let back_pos = if max_tail == 0 { cap - 1 } else { max_tail - 1 };
2145 let back_idx = max_q[back_pos];
2146 if high[i] <= high[back_idx] {
2147 break;
2148 }
2149 max_tail = back_pos;
2150 max_len -= 1;
2151 }
2152 max_q[max_tail] = i;
2153 max_tail += 1;
2154 if max_tail == cap {
2155 max_tail = 0;
2156 }
2157 max_len += 1;
2158
2159 while min_len > 0 {
2160 let back_pos = if min_tail == 0 { cap - 1 } else { min_tail - 1 };
2161 let back_idx = min_q[back_pos];
2162 if low[i] >= low[back_idx] {
2163 break;
2164 }
2165 min_tail = back_pos;
2166 min_len -= 1;
2167 }
2168 min_q[min_tail] = i;
2169 min_tail += 1;
2170 if min_tail == cap {
2171 min_tail = 0;
2172 }
2173 min_len += 1;
2174
2175 cpos += 1;
2176 if cpos == l {
2177 cpos = 0;
2178 }
2179 trpos += 1;
2180 if trpos == l {
2181 trpos = 0;
2182 }
2183 r += 1;
2184 i += 1;
2185 }
2186
2187 let m = sum0 * inv_n;
2188 let var = (-m).mul_add(m, sumsq * inv_n);
2189 let var_pos = if var > 0.0 { var } else { 0.0 };
2190 let dkc = tr_sum * inv_n;
2191 let dkc2 = dkc * dkc;
2192
2193 let hi_idx = max_q[max_head];
2194 let lo_idx = min_q[min_head];
2195 let highest = high[hi_idx];
2196 let lowest = low[lo_idx];
2197 let midpoint = 0.5 * (highest + lowest);
2198 let avg = 0.5 * (midpoint + m);
2199 let sy = sum0 - avg * n;
2200 let sxy = sum1 - avg * sx;
2201 let slope = n.mul_add(sxy, -(sx * sy)) * inv_den;
2202 let mom_val = sy * inv_n + slope * half_nm1;
2203
2204 for rc in &group {
2205 let bbv = rc.bb_sq * var_pos;
2206 let t_low = rc.kc_low_sq * dkc2;
2207 let t_mid = rc.kc_mid_sq * dkc2;
2208 let t_high = rc.kc_high_sq * dkc2;
2209 let sqz = if bbv > t_low {
2210 0.0
2211 } else if bbv <= t_high {
2212 3.0
2213 } else if bbv <= t_mid {
2214 2.0
2215 } else {
2216 1.0
2217 };
2218 let s_off = rc.row * cols + warm;
2219 let m_off = rc.row * cols + warm;
2220 sqz_slice[s_off] = sqz;
2221 mom_slice[m_off] = mom_val;
2222 }
2223
2224 let mut i = warm + 1;
2225 while i < cols {
2226 let start_idx = i + 1 - l;
2227
2228 while max_len > 0 {
2229 let front_idx = max_q[max_head];
2230 if front_idx >= start_idx {
2231 break;
2232 }
2233 max_head += 1;
2234 if max_head == cap {
2235 max_head = 0;
2236 }
2237 max_len -= 1;
2238 }
2239 while min_len > 0 {
2240 let front_idx = min_q[min_head];
2241 if front_idx >= start_idx {
2242 break;
2243 }
2244 min_head += 1;
2245 if min_head == cap {
2246 min_head = 0;
2247 }
2248 min_len -= 1;
2249 }
2250
2251 while max_len > 0 {
2252 let back_pos = if max_tail == 0 { cap - 1 } else { max_tail - 1 };
2253 let back_idx = max_q[back_pos];
2254 if high[i] <= high[back_idx] {
2255 break;
2256 }
2257 max_tail = back_pos;
2258 max_len -= 1;
2259 }
2260 max_q[max_tail] = i;
2261 max_tail += 1;
2262 if max_tail == cap {
2263 max_tail = 0;
2264 }
2265 max_len += 1;
2266
2267 while min_len > 0 {
2268 let back_pos = if min_tail == 0 { cap - 1 } else { min_tail - 1 };
2269 let back_idx = min_q[back_pos];
2270 if low[i] >= low[back_idx] {
2271 break;
2272 }
2273 min_tail = back_pos;
2274 min_len -= 1;
2275 }
2276 min_q[min_tail] = i;
2277 min_tail += 1;
2278 if min_tail == cap {
2279 min_tail = 0;
2280 }
2281 min_len += 1;
2282
2283 let old = cbuf[cpos];
2284 let new = close[i];
2285 let sum0_old = sum0;
2286 sum0 += new - old;
2287 sumsq = new.mul_add(new, sumsq - old * old);
2288 sum1 = sum1 - sum0_old + old + (n - 1.0) * new;
2289 cbuf[cpos] = new;
2290 cpos += 1;
2291 if cpos == l {
2292 cpos = 0;
2293 }
2294
2295 let old_tr = trbuf[trpos];
2296 let pc = close[i - 1];
2297 let hi_i = high[i];
2298 let lo_i = low[i];
2299 let hl = hi_i - lo_i;
2300 let hc = (hi_i - pc).abs();
2301 let lc = (lo_i - pc).abs();
2302 let tr_new = hl.max(hc).max(lc);
2303 tr_sum += tr_new - old_tr;
2304 trbuf[trpos] = tr_new;
2305 trpos += 1;
2306 if trpos == l {
2307 trpos = 0;
2308 }
2309
2310 let m = sum0 * inv_n;
2311 let var = (-m).mul_add(m, sumsq * inv_n);
2312 let var_pos = if var > 0.0 { var } else { 0.0 };
2313 let dkc = tr_sum * inv_n;
2314 let dkc2 = dkc * dkc;
2315
2316 let hi_idx = max_q[max_head];
2317 let lo_idx = min_q[min_head];
2318 let highest = high[hi_idx];
2319 let lowest = low[lo_idx];
2320 let midpoint = 0.5 * (highest + lowest);
2321 let avg = 0.5 * (midpoint + m);
2322 let sy = sum0 - avg * n;
2323 let sxy = sum1 - avg * sx;
2324 let slope = n.mul_add(sxy, -(sx * sy)) * inv_den;
2325 let mom_val = sy * inv_n + slope * half_nm1;
2326
2327 for rc in &group {
2328 let bbv = rc.bb_sq * var_pos;
2329 let t_low = rc.kc_low_sq * dkc2;
2330 let t_mid = rc.kc_mid_sq * dkc2;
2331 let t_high = rc.kc_high_sq * dkc2;
2332 let sqz = if bbv > t_low {
2333 0.0
2334 } else if bbv <= t_high {
2335 3.0
2336 } else if bbv <= t_mid {
2337 2.0
2338 } else {
2339 1.0
2340 };
2341 let s_off = rc.row * cols + i;
2342 let m_off = rc.row * cols + i;
2343 sqz_slice[s_off] = sqz;
2344 mom_slice[m_off] = mom_val;
2345 }
2346
2347 i += 1;
2348 }
2349 }
2350 } else {
2351 for (row, p) in combos.iter().enumerate() {
2352 let input = TtmSqueezeInput::from_slices(high, low, close, p.clone());
2353 let dst_m = &mut mom_slice[row * cols..(row + 1) * cols];
2354 let dst_s = &mut sqz_slice[row * cols..(row + 1) * cols];
2355
2356 ttm_squeeze_into_slices(dst_m, dst_s, &input, row_kernel)?;
2357 }
2358 }
2359
2360 let momentum = unsafe {
2361 Vec::from_raw_parts(
2362 mom_guard.as_mut_ptr() as *mut f64,
2363 mom_guard.len(),
2364 mom_guard.capacity(),
2365 )
2366 };
2367
2368 let squeeze = unsafe {
2369 Vec::from_raw_parts(
2370 sqz_guard.as_mut_ptr() as *mut f64,
2371 sqz_guard.len(),
2372 sqz_guard.capacity(),
2373 )
2374 };
2375
2376 Ok(TtmSqueezeBatchOutput {
2377 momentum,
2378 squeeze,
2379 combos,
2380 rows,
2381 cols,
2382 })
2383}
2384
#[cfg(feature = "python")]
/// Python entry point for the TTM Squeeze parameter sweep.
///
/// Each `*_range` is an inclusive `(start, end, step)` tuple. The returned dict has these keys:
/// - `momentum` and `squeeze`: `(rows, cols)` float64 arrays.
/// - `lengths`: per-row window lengths as uint64.
/// - `bb_mults`, `kc_highs`, `kc_mids`, `kc_lows`: per-row multipliers as float64.
///
/// Raises `ValueError` for an invalid kernel name or any batch error.
#[pyfunction(name = "ttm_squeeze_batch")]
#[pyo3(signature = (high, low, close, length_range, bb_mult_range, kc_high_range, kc_mid_range, kc_low_range, kernel=None))]
pub fn ttm_squeeze_batch_py<'py>(
    py: Python<'py>,
    high: numpy::PyReadonlyArray1<'py, f64>,
    low: numpy::PyReadonlyArray1<'py, f64>,
    close: numpy::PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    bb_mult_range: (f64, f64, f64),
    kc_high_range: (f64, f64, f64),
    kc_mid_range: (f64, f64, f64),
    kc_low_range: (f64, f64, f64),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    let (hs, ls, cs) = (high.as_slice()?, low.as_slice()?, close.as_slice()?);

    let sweep = TtmSqueezeBatchRange {
        length: length_range,
        bb_mult: bb_mult_range,
        kc_high: kc_high_range,
        kc_mid: kc_mid_range,
        kc_low: kc_low_range,
    };
    let chosen = validate_kernel(kernel, true)?;

    let TtmSqueezeBatchOutput {
        momentum,
        squeeze,
        combos,
        rows,
        cols,
    } = py
        .allow_threads(|| ttm_squeeze_batch_with_kernel(hs, ls, cs, &sweep, chosen))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = pyo3::types::PyDict::new(py);
    dict.set_item("momentum", momentum.into_pyarray(py).reshape((rows, cols))?)?;
    dict.set_item("squeeze", squeeze.into_pyarray(py).reshape((rows, cols))?)?;

    let lengths: Vec<u64> = combos.iter().map(|p| p.length.unwrap() as u64).collect();
    dict.set_item("lengths", lengths.into_pyarray(py))?;

    // One float64 array per multiplier axis, in row order.
    let per_row = |pick: fn(&TtmSqueezeParams) -> Option<f64>| {
        combos
            .iter()
            .map(|p| pick(p).unwrap())
            .collect::<Vec<f64>>()
            .into_pyarray(py)
    };
    dict.set_item("bb_mults", per_row(|p| p.bb_mult))?;
    dict.set_item("kc_highs", per_row(|p| p.kc_mult_high))?;
    dict.set_item("kc_mids", per_row(|p| p.kc_mult_mid))?;
    dict.set_item("kc_lows", per_row(|p| p.kc_mult_low))?;

    Ok(dict)
}
2472
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS-side sweep configuration. Each range is an inclusive `(start, end, step)` tuple.
#[derive(Serialize, Deserialize)]
pub struct TtmSqueezeBatchConfig {
    pub length_range: (usize, usize, usize),
    pub bb_mult_range: (f64, f64, f64),
    pub kc_high_range: (f64, f64, f64),
    pub kc_mid_range: (f64, f64, f64),
    pub kc_low_range: (f64, f64, f64),
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// Serialised batch result: `values` holds `rows` rows of `cols` values each. For combo `i`,
/// row `2 * i` is momentum and row `2 * i + 1` is squeeze.
#[derive(Serialize, Deserialize)]
pub struct TtmSqueezeBatchJsOutput {
    pub values: Vec<f64>,
    pub rows: usize,
    pub cols: usize,
    pub combos: Vec<TtmSqueezeParams>,
}

#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
/// JS entry point for the TTM Squeeze sweep, using the kernel from
/// `detect_best_batch_kernel`. Config, batch and serialisation errors are returned as JS
/// strings.
#[wasm_bindgen(js_name = "ttm_squeeze_batch")]
pub fn ttm_squeeze_batch_unified_js(
    high: &[f64],
    low: &[f64],
    close: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let TtmSqueezeBatchConfig {
        length_range,
        bb_mult_range,
        kc_high_range,
        kc_mid_range,
        kc_low_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = TtmSqueezeBatchRange {
        length: length_range,
        bb_mult: bb_mult_range,
        kc_high: kc_high_range,
        kc_mid: kc_mid_range,
        kc_low: kc_low_range,
    };

    let out = ttm_squeeze_batch_with_kernel(high, low, close, &sweep, detect_best_batch_kernel())
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let cols = out.cols;
    // Interleave per combo: momentum row, then squeeze row.
    let mut values = Vec::with_capacity(2 * out.rows * cols);
    values.extend(
        (0..out.rows)
            .flat_map(|r| {
                let s = r * cols;
                out.momentum[s..s + cols]
                    .iter()
                    .chain(&out.squeeze[s..s + cols])
            })
            .copied(),
    );

    serde_wasm_bindgen::to_value(&TtmSqueezeBatchJsOutput {
        values,
        rows: out.rows * 2,
        cols,
        combos: out.combos,
    })
    .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
2531
#[cfg(test)]
mod tests {
    //! Tests for the TTM Squeeze indicator. They cover single-series accuracy
    //! and error paths, the builder and streaming entry points, batch sweeps,
    //! and consistency between the single-series and batch code paths.

    use super::*;
    use crate::utilities::data_loader::read_candles_from_csv;
    use std::error::Error;

    // Returns `Ok(())` early from the enclosing check when `$kernel` is a SIMD
    // kernel and the crate was not built with `nightly-avx` on x86_64.
    macro_rules! skip_if_unsupported {
        ($kernel:expr, $test_name:expr) => {
            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
            {
                if matches!(
                    $kernel,
                    Kernel::Avx2 | Kernel::Avx512 | Kernel::Avx2Batch | Kernel::Avx512Batch
                ) {
                    eprintln!("Skipping {} - AVX not supported", $test_name);
                    return Ok(());
                }
            }
        };
    }

    /// Checks default-parameter output on the CSV fixture against reference
    /// momentum/squeeze values, and verifies that the leading warm-up values are NaN.
    fn check_ttm_squeeze_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = TtmSqueezeInput::with_default_candles(&candles);
        let result = ttm_squeeze_with_kernel(&input, kernel)?;

        assert_eq!(result.momentum.len(), candles.close.len());
        assert_eq!(result.squeeze.len(), candles.close.len());

        let expected_momentum = [
            -167.98676428571423,
            -154.99159285714336,
            -148.98427857142892,
            -131.80910714285744,
            -89.35822142857162,
        ];

        let expected_squeeze = [0.0, 0.0, 0.0, 0.0, 1.0];

        // Default `length` (20) minus one: index of the first reference value.
        let warmup_period = 19;

        for (i, &expected) in expected_momentum.iter().enumerate() {
            let actual = result.momentum[warmup_period + i];
            let diff = (actual - expected).abs();
            assert!(
                diff < 0.0001,
                "[{}] Momentum at index {}: expected {}, got {}, diff: {}",
                test_name,
                i,
                expected,
                actual,
                diff
            );
        }

        for (i, &expected) in expected_squeeze.iter().enumerate() {
            let actual = result.squeeze[warmup_period + i];
            assert_eq!(
                actual, expected,
                "[{}] Squeeze mismatch at index {}: expected {}, got {}",
                test_name, i, expected, actual
            );
        }

        let first_valid_momentum = result.momentum.iter().position(|&x| !x.is_nan());
        let first_valid_squeeze = result.squeeze.iter().position(|&x| !x.is_nan());

        assert!(
            first_valid_momentum.is_some(),
            "[{}] No valid momentum values found",
            test_name
        );
        assert!(
            first_valid_squeeze.is_some(),
            "[{}] No valid squeeze values found",
            test_name
        );

        if let Some(first_mom) = first_valid_momentum {
            for i in 0..first_mom.min(10) {
                assert!(
                    result.momentum[i].is_nan(),
                    "[{}] Expected NaN at index {}",
                    test_name,
                    i
                );
            }
        }

        if let Some(first_sqz) = first_valid_squeeze {
            for i in 0..first_sqz.min(10) {
                assert!(
                    result.squeeze[i].is_nan(),
                    "[{}] Expected NaN at index {}",
                    test_name,
                    i
                );
            }
        }

        Ok(())
    }

    /// All-`None` params must still compute, which exercises the default fallbacks.
    fn check_ttm_squeeze_partial_params(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let params = TtmSqueezeParams {
            length: None,
            bb_mult: None,
            kc_mult_high: None,
            kc_mult_mid: None,
            kc_mult_low: None,
        };

        let input = TtmSqueezeInput::from_candles(&candles, params);
        let result = ttm_squeeze_with_kernel(&input, kernel)?;

        assert_eq!(result.momentum.len(), candles.close.len());
        assert_eq!(result.squeeze.len(), candles.close.len());

        Ok(())
    }

    /// `with_default_candles` produces full-length outputs.
    fn check_ttm_squeeze_default_candles(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = TtmSqueezeInput::with_default_candles(&candles);
        let result = ttm_squeeze_with_kernel(&input, kernel)?;

        assert_eq!(result.momentum.len(), candles.close.len());
        assert_eq!(result.squeeze.len(), candles.close.len());

        Ok(())
    }

    /// `length == 0` must be rejected.
    fn check_ttm_squeeze_zero_period(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let params = TtmSqueezeParams {
            length: Some(0),
            bb_mult: None,
            kc_mult_high: None,
            kc_mult_mid: None,
            kc_mult_low: None,
        };

        let input = TtmSqueezeInput::from_slices(&data, &data, &data, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail with zero period",
            test_name
        );
        Ok(())
    }

    /// A `length` longer than the input must be rejected.
    fn check_ttm_squeeze_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = vec![1.0, 2.0, 3.0];
        let params = TtmSqueezeParams {
            length: Some(10),
            bb_mult: None,
            kc_mult_high: None,
            kc_mult_mid: None,
            kc_mult_low: None,
        };

        let input = TtmSqueezeInput::from_slices(&data, &data, &data, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail when period exceeds length",
            test_name
        );
        Ok(())
    }

    /// A single bar with default params must be rejected.
    fn check_ttm_squeeze_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let data = vec![42.0];
        let params = TtmSqueezeParams::default();

        let input = TtmSqueezeInput::from_slices(&data, &data, &data, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail with very small dataset",
            test_name
        );
        Ok(())
    }

    /// Empty slices must be rejected.
    fn check_ttm_squeeze_empty_input(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let empty_data: Vec<f64> = vec![];
        let params = TtmSqueezeParams::default();

        let input = TtmSqueezeInput::from_slices(&empty_data, &empty_data, &empty_data, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail with empty input",
            test_name
        );
        Ok(())
    }

    /// An all-NaN input must be rejected.
    fn check_ttm_squeeze_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let nan_data = vec![f64::NAN; 50];
        let params = TtmSqueezeParams::default();

        let input = TtmSqueezeInput::from_slices(&nan_data, &nan_data, &nan_data, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail with all NaN values",
            test_name
        );
        Ok(())
    }

    /// High/low/close slices of different lengths must be rejected.
    fn check_ttm_squeeze_inconsistent_slices(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let high = vec![1.0; 10];
        let low = vec![0.9; 10];
        let close = vec![0.95; 5];
        let params = TtmSqueezeParams::default();

        let input = TtmSqueezeInput::from_slices(&high, &low, &close, params);
        let result = ttm_squeeze_with_kernel(&input, kernel);

        assert!(
            result.is_err(),
            "[{}] Should fail with inconsistent slice lengths",
            test_name
        );
        Ok(())
    }

    /// On the fixture, no output value from index 40 onward may be NaN.
    fn check_ttm_squeeze_nan_handling(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = TtmSqueezeInput::with_default_candles(&candles);
        let result = ttm_squeeze_with_kernel(&input, kernel)?;

        assert_eq!(result.momentum.len(), candles.close.len());
        assert_eq!(result.squeeze.len(), candles.close.len());

        // The range is simply empty when the series has 40 or fewer bars.
        for i in 40..result.momentum.len() {
            assert!(
                !result.momentum[i].is_nan(),
                "[{}] Unexpected NaN in momentum at {}",
                test_name,
                i
            );
            assert!(
                !result.squeeze[i].is_nan(),
                "[{}] Unexpected NaN in squeeze at {}",
                test_name,
                i
            );
        }

        Ok(())
    }

    /// The builder with explicit params produces full-length outputs.
    fn check_ttm_squeeze_builder(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let result = TtmSqueezeBuilder::new()
            .length(30)
            .bb_mult(2.5)
            .kc_mult_high(1.2)
            .kc_mult_mid(1.8)
            .kc_mult_low(2.5)
            .kernel(kernel)
            .apply(&candles)?;

        assert_eq!(result.momentum.len(), candles.close.len());
        assert_eq!(result.squeeze.len(), candles.close.len());

        Ok(())
    }

    /// Feeding the first 100 bars through `TtmSqueezeStream` yields at least one
    /// value. The batch result is checked for full length only; stream and batch
    /// values are not compared element-wise here.
    fn check_ttm_squeeze_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let params = TtmSqueezeParams::default();
        let mut stream = TtmSqueezeStream::try_new(params.clone())?;

        let input = TtmSqueezeInput::from_candles(&candles, params);
        let batch_result = ttm_squeeze_with_kernel(&input, kernel)?;
        assert_eq!(batch_result.momentum.len(), candles.close.len());
        assert_eq!(batch_result.squeeze.len(), candles.close.len());

        let mut stream_momentum = Vec::new();
        let mut stream_squeeze = Vec::new();

        for i in 0..candles.close.len().min(100) {
            if let Some((mom, sqz)) =
                stream.update(candles.high[i], candles.low[i], candles.close[i])
            {
                stream_momentum.push(mom);
                stream_squeeze.push(sqz);
            }
        }

        assert!(
            !stream_momentum.is_empty(),
            "[{}] Stream should produce values",
            test_name
        );
        assert_eq!(
            stream_momentum.len(),
            stream_squeeze.len(),
            "[{}] Stream momentum/squeeze counts differ",
            test_name
        );

        Ok(())
    }

    /// Debug builds only: no output value may carry one of the poison bit
    /// patterns checked below. NaN values are skipped.
    #[cfg(debug_assertions)]
    fn check_ttm_squeeze_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let test_params = vec![
            TtmSqueezeParams::default(),
            TtmSqueezeParams {
                length: Some(10),
                bb_mult: Some(1.5),
                kc_mult_high: Some(0.8),
                kc_mult_mid: Some(1.2),
                kc_mult_low: Some(1.8),
            },
            TtmSqueezeParams {
                length: Some(30),
                bb_mult: Some(3.0),
                kc_mult_high: Some(1.5),
                kc_mult_mid: Some(2.0),
                kc_mult_low: Some(2.5),
            },
        ];

        for params in test_params {
            let input = TtmSqueezeInput::from_candles(&candles, params);
            let output = ttm_squeeze_with_kernel(&input, kernel)?;

            for (i, &val) in output.momentum.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                assert!(
                    bits != 0x11111111_11111111
                        && bits != 0x22222222_22222222
                        && bits != 0x33333333_33333333,
                    "[{}] Found poison value in momentum at {}: 0x{:016X}",
                    test_name,
                    i,
                    bits
                );
            }

            for (i, &val) in output.squeeze.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                assert!(
                    bits != 0x11111111_11111111
                        && bits != 0x22222222_22222222
                        && bits != 0x33333333_33333333,
                    "[{}] Found poison value in squeeze at {}: 0x{:016X}",
                    test_name,
                    i,
                    bits
                );
            }
        }

        Ok(())
    }

    /// Release-build stand-in for the poison check; always succeeds.
    #[cfg(not(debug_assertions))]
    fn check_ttm_squeeze_no_poison(
        _test_name: &str,
        _kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// The default batch sweep contains a full-length row for the default params.
    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let out = TtmSqueezeBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&candles)?;
        let def = TtmSqueezeParams::default();
        let row_m = out.momentum_for(&def).expect("default row missing");
        let row_s = out.squeeze_for(&def).expect("default row missing");
        assert_eq!(row_m.len(), candles.close.len());
        assert_eq!(row_s.len(), candles.close.len());
        Ok(())
    }

    /// The row count equals the product of the per-parameter sweep sizes.
    fn check_batch_sweep_count(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv("src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv")?;
        let out = TtmSqueezeBatchBuilder::new()
            .kernel(kernel)
            .length_range(20, 24, 1)
            .bb_mult_range(2.0, 2.0, 0.0)
            .kc_high_range(1.0, 1.2, 0.1)
            .kc_mid_range(1.5, 1.7, 0.1)
            .kc_low_range(2.0, 2.2, 0.1)
            .apply_candles(&candles)?;
        assert_eq!(out.rows, 5 * 1 * 3 * 3 * 3);
        assert_eq!(out.cols, candles.close.len());
        Ok(())
    }

    // Generates one `#[test]` per kernel for each single-series check. Each
    // generated test returns the check's `Result`, so an `Err` (a missing
    // fixture, or an indicator error propagated with `?`) fails the test
    // instead of being silently discarded.
    //
    // The AVX `cfg` is repeated on every generated fn. An attribute placed in
    // front of `$( ... )*` would attach only to the first expanded item.
    macro_rules! generate_ttm_squeeze_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
                    }
                )*

                $(
                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_ttm_squeeze_tests!(
        check_ttm_squeeze_accuracy,
        check_ttm_squeeze_partial_params,
        check_ttm_squeeze_default_candles,
        check_ttm_squeeze_zero_period,
        check_ttm_squeeze_period_exceeds_length,
        check_ttm_squeeze_very_small_dataset,
        check_ttm_squeeze_empty_input,
        check_ttm_squeeze_all_nan,
        check_ttm_squeeze_inconsistent_slices,
        check_ttm_squeeze_nan_handling,
        check_ttm_squeeze_builder,
        check_ttm_squeeze_streaming,
        check_ttm_squeeze_no_poison
    );

    // Generates scalar, AVX2, AVX-512 and auto-kernel tests for one batch check.
    // Like the single-series generator, each test propagates the check's `Result`.
    macro_rules! gen_batch_tests {
        ($f:ident) => {
            paste::paste! {
                #[test]
                fn [<$f _scalar>]() -> Result<(), Box<dyn Error>> {
                    $f(stringify!([<$f _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$f _avx2>]() -> Result<(), Box<dyn Error>> {
                    $f(stringify!([<$f _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$f _avx512>]() -> Result<(), Box<dyn Error>> {
                    $f(stringify!([<$f _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$f _auto>]() -> Result<(), Box<dyn Error>> {
                    $f(stringify!([<$f _auto>]), Kernel::Auto)
                }
            }
        };
    }

    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_sweep_count);

    /// True when both values are NaN or they differ by at most `eps`.
    #[inline]
    fn eq_or_both_nan_eps(a: f64, b: f64, eps: f64) -> bool {
        (a.is_nan() && b.is_nan()) || (a - b).abs() <= eps
    }

    /// `ttm_squeeze_into` writes the same values as the allocating `ttm_squeeze`.
    #[test]
    fn test_ttm_squeeze_into_matches_api() -> Result<(), Box<dyn Error>> {
        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
        let candles = read_candles_from_csv(file_path)?;

        let input = TtmSqueezeInput::with_default_candles(&candles);

        let baseline = ttm_squeeze(&input)?;

        let len = candles.close.len();
        let mut mom_out = vec![0.0f64; len];
        let mut sqz_out = vec![0.0f64; len];
        ttm_squeeze_into(&mut mom_out, &mut sqz_out, &input, Kernel::Auto)?;

        assert_eq!(baseline.momentum.len(), len);
        assert_eq!(baseline.squeeze.len(), len);

        for i in 0..len {
            assert!(
                eq_or_both_nan_eps(baseline.momentum[i], mom_out[i], 1e-7),
                "Momentum mismatch at {}: baseline={} into={}",
                i,
                baseline.momentum[i],
                mom_out[i]
            );
            assert!(
                eq_or_both_nan_eps(baseline.squeeze[i], sqz_out[i], 1e-7),
                "Squeeze mismatch at {}: baseline={} into={}",
                i,
                baseline.squeeze[i],
                sqz_out[i]
            );
        }

        Ok(())
    }

    /// On a deterministic synthetic OHLC series, the scalar batch row for the
    /// default params must match the scalar single-series momentum. The NaN
    /// positions must agree, and finite values must differ by less than 1e-6.
    #[test]
    fn ttm_squeeze_scalar_batch_matches_single_scalar_on_dispatch_fixture(
    ) -> Result<(), Box<dyn Error>> {
        let len = 192usize;
        let open: Vec<f64> = (0..len)
            .map(|i| 100.0f64 + (i as f64 * 0.1) + ((i as f64) * 0.03).sin())
            .collect();
        let high: Vec<f64> = open
            .iter()
            .enumerate()
            .map(|(i, v)| v + 0.8 + ((i as f64) * 0.02).cos().abs() * 0.3)
            .collect();
        let low: Vec<f64> = open
            .iter()
            .enumerate()
            .map(|(i, v)| v - 0.8 - ((i as f64) * 0.02).sin().abs() * 0.3)
            .collect();
        let close: Vec<f64> = open
            .iter()
            .enumerate()
            .map(|(i, v)| v + ((i as f64) * 0.05).sin() * 0.4)
            .collect();

        let params = TtmSqueezeParams {
            length: Some(20),
            bb_mult: Some(2.0),
            kc_mult_high: Some(1.0),
            kc_mult_mid: Some(1.5),
            kc_mult_low: Some(2.0),
        };
        let single = ttm_squeeze_with_kernel(
            &TtmSqueezeInput::from_slices(&high, &low, &close, params.clone()),
            Kernel::Scalar,
        )?;
        let batch = ttm_squeeze_batch_with_kernel(
            &high,
            &low,
            &close,
            &TtmSqueezeBatchRange {
                length: (20, 20, 0),
                bb_mult: (2.0, 2.0, 0.0),
                kc_high: (1.0, 1.0, 0.0),
                kc_mid: (1.5, 1.5, 0.0),
                kc_low: (2.0, 2.0, 0.0),
            },
            Kernel::ScalarBatch,
        )?;

        let row = batch
            .momentum_for(&params)
            .expect("default batch row should exist");
        // `zip` below would silently truncate a shorter row.
        assert_eq!(
            row.len(),
            single.momentum.len(),
            "batch row length differs from single-series output"
        );
        let mut max_diff = 0.0f64;
        let mut worst = None;
        for (idx, (&lhs, &rhs)) in single.momentum.iter().zip(row.iter()).enumerate() {
            if lhs.is_nan() && rhs.is_nan() {
                continue;
            }
            // A one-sided NaN yields `diff = NaN`, and `NaN > max_diff` is false,
            // so without this check such a mismatch would go unreported.
            assert!(
                !lhs.is_nan() && !rhs.is_nan(),
                "momentum NaN mismatch at {idx}: lhs={lhs} rhs={rhs}"
            );
            let diff = (lhs - rhs).abs();
            if diff > max_diff {
                max_diff = diff;
                worst = Some((idx, lhs, rhs));
            }
        }
        if let Some((idx, lhs, rhs)) = worst {
            assert!(
                max_diff < 1e-6,
                "worst momentum mismatch at {idx}: lhs={lhs} rhs={rhs} diff={max_diff}"
            );
        }

        Ok(())
    }
}