1#[cfg(feature = "python")]
2use numpy::{IntoPyArray, PyArray1};
3#[cfg(feature = "python")]
4use pyo3::exceptions::PyValueError;
5#[cfg(feature = "python")]
6use pyo3::prelude::*;
7#[cfg(feature = "python")]
8use pyo3::types::{PyDict, PyList};
9
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use serde::{Deserialize, Serialize};
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use wasm_bindgen::prelude::*;
14
15use crate::utilities::data_loader::{source_type, Candles};
16use crate::utilities::enums::Kernel;
17use crate::utilities::helpers::{
18 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
19 make_uninit_matrix,
20};
21#[cfg(feature = "python")]
22use crate::utilities::kernel_validation::validate_kernel;
23use aligned_vec::{AVec, CACHELINE_ALIGN};
24#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
25use core::arch::x86_64::*;
26use paste::paste;
27#[cfg(not(target_arch = "wasm32"))]
28use rayon::prelude::*;
29use std::convert::AsRef;
30use thiserror::Error;
31
32impl<'a> AsRef<[f64]> for AlligatorInput<'a> {
33 #[inline(always)]
34 fn as_ref(&self) -> &[f64] {
35 match &self.data {
36 AlligatorData::Slice(slice) => slice,
37 AlligatorData::Candles { candles, source } => source_type(candles, source),
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
43pub enum AlligatorData<'a> {
44 Candles {
45 candles: &'a Candles,
46 source: &'a str,
47 },
48 Slice(&'a [f64]),
49}
50
/// The three Alligator lines produced by the single-series functions.
///
/// Positions before a line's first (shifted) value hold NaN.
#[derive(Debug, Clone)]
pub struct AlligatorOutput {
    /// Slowest line (default period 13, shifted 8 bars forward).
    pub jaw: Vec<f64>,
    /// Middle line (default period 8, shifted 5 bars forward).
    pub teeth: Vec<f64>,
    /// Fastest line (default period 5, shifted 3 bars forward).
    pub lips: Vec<f64>,
}
57
/// Periods and forward offsets of the three Alligator lines.
///
/// `None` means "use the classic setting": jaw 13/8, teeth 8/5, lips 5/3
/// (period/offset).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct AlligatorParams {
    pub jaw_period: Option<usize>,
    pub jaw_offset: Option<usize>,
    pub teeth_period: Option<usize>,
    pub teeth_offset: Option<usize>,
    pub lips_period: Option<usize>,
    pub lips_offset: Option<usize>,
}

impl Default for AlligatorParams {
    /// The classic Bill Williams settings, with every field explicitly set.
    fn default() -> Self {
        AlligatorParams {
            jaw_period: Some(13),
            teeth_period: Some(8),
            lips_period: Some(5),
            jaw_offset: Some(8),
            teeth_offset: Some(5),
            lips_offset: Some(3),
        }
    }
}
83
84#[derive(Debug, Clone)]
85pub struct AlligatorInput<'a> {
86 pub data: AlligatorData<'a>,
87 pub params: AlligatorParams,
88}
89impl<'a> AlligatorInput<'a> {
90 #[inline]
91 pub fn from_candles(c: &'a Candles, s: &'a str, p: AlligatorParams) -> Self {
92 Self {
93 data: AlligatorData::Candles {
94 candles: c,
95 source: s,
96 },
97 params: p,
98 }
99 }
100 #[inline]
101 pub fn from_slice(sl: &'a [f64], p: AlligatorParams) -> Self {
102 Self {
103 data: AlligatorData::Slice(sl),
104 params: p,
105 }
106 }
107 #[inline]
108 pub fn with_default_candles(c: &'a Candles) -> Self {
109 Self::from_candles(c, "hl2", AlligatorParams::default())
110 }
111 #[inline]
112 pub fn get_jaw_period(&self) -> usize {
113 self.params.jaw_period.unwrap_or(13)
114 }
115 #[inline]
116 pub fn get_jaw_offset(&self) -> usize {
117 self.params.jaw_offset.unwrap_or(8)
118 }
119 #[inline]
120 pub fn get_teeth_period(&self) -> usize {
121 self.params.teeth_period.unwrap_or(8)
122 }
123 #[inline]
124 pub fn get_teeth_offset(&self) -> usize {
125 self.params.teeth_offset.unwrap_or(5)
126 }
127 #[inline]
128 pub fn get_lips_period(&self) -> usize {
129 self.params.lips_period.unwrap_or(5)
130 }
131 #[inline]
132 pub fn get_lips_offset(&self) -> usize {
133 self.params.lips_offset.unwrap_or(3)
134 }
135}
136
137#[derive(Copy, Clone, Debug)]
138pub struct AlligatorBuilder {
139 jaw_period: Option<usize>,
140 jaw_offset: Option<usize>,
141 teeth_period: Option<usize>,
142 teeth_offset: Option<usize>,
143 lips_period: Option<usize>,
144 lips_offset: Option<usize>,
145 kernel: Kernel,
146}
147impl Default for AlligatorBuilder {
148 fn default() -> Self {
149 Self {
150 jaw_period: None,
151 jaw_offset: None,
152 teeth_period: None,
153 teeth_offset: None,
154 lips_period: None,
155 lips_offset: None,
156 kernel: Kernel::Auto,
157 }
158 }
159}
160impl AlligatorBuilder {
161 #[inline(always)]
162 pub fn new() -> Self {
163 Self::default()
164 }
165 #[inline(always)]
166 pub fn jaw_period(mut self, n: usize) -> Self {
167 self.jaw_period = Some(n);
168 self
169 }
170 #[inline(always)]
171 pub fn jaw_offset(mut self, n: usize) -> Self {
172 self.jaw_offset = Some(n);
173 self
174 }
175 #[inline(always)]
176 pub fn teeth_period(mut self, n: usize) -> Self {
177 self.teeth_period = Some(n);
178 self
179 }
180 #[inline(always)]
181 pub fn teeth_offset(mut self, n: usize) -> Self {
182 self.teeth_offset = Some(n);
183 self
184 }
185 #[inline(always)]
186 pub fn lips_period(mut self, n: usize) -> Self {
187 self.lips_period = Some(n);
188 self
189 }
190 #[inline(always)]
191 pub fn lips_offset(mut self, n: usize) -> Self {
192 self.lips_offset = Some(n);
193 self
194 }
195 #[inline(always)]
196 pub fn kernel(mut self, k: Kernel) -> Self {
197 self.kernel = k;
198 self
199 }
200
201 #[inline(always)]
202 pub fn apply(self, c: &Candles) -> Result<AlligatorOutput, AlligatorError> {
203 let p = AlligatorParams {
204 jaw_period: self.jaw_period,
205 jaw_offset: self.jaw_offset,
206 teeth_period: self.teeth_period,
207 teeth_offset: self.teeth_offset,
208 lips_period: self.lips_period,
209 lips_offset: self.lips_offset,
210 };
211 let i = AlligatorInput::from_candles(c, "hl2", p);
212 alligator_with_kernel(&i, self.kernel)
213 }
214 #[inline(always)]
215 pub fn apply_slice(self, d: &[f64]) -> Result<AlligatorOutput, AlligatorError> {
216 let p = AlligatorParams {
217 jaw_period: self.jaw_period,
218 jaw_offset: self.jaw_offset,
219 teeth_period: self.teeth_period,
220 teeth_offset: self.teeth_offset,
221 lips_period: self.lips_period,
222 lips_offset: self.lips_offset,
223 };
224 let i = AlligatorInput::from_slice(d, p);
225 alligator_with_kernel(&i, self.kernel)
226 }
227 #[inline(always)]
228 pub fn into_stream(self) -> Result<AlligatorStream, AlligatorError> {
229 let p = AlligatorParams {
230 jaw_period: self.jaw_period,
231 jaw_offset: self.jaw_offset,
232 teeth_period: self.teeth_period,
233 teeth_offset: self.teeth_offset,
234 lips_period: self.lips_period,
235 lips_offset: self.lips_offset,
236 };
237 AlligatorStream::try_new(p)
238 }
239}
240
241#[derive(Debug, Error)]
242pub enum AlligatorError {
243 #[error("alligator: Input data slice is empty.")]
244 EmptyInputData,
245 #[error("alligator: All values are NaN.")]
246 AllValuesNaN,
247 #[error("alligator: Invalid jaw period: period = {period}, data length = {data_len}")]
248 InvalidJawPeriod { period: usize, data_len: usize },
249 #[error("alligator: Invalid jaw offset: offset = {offset}, data_len = {data_len}")]
250 InvalidJawOffset { offset: usize, data_len: usize },
251 #[error("alligator: Invalid teeth period: period = {period}, data length = {data_len}")]
252 InvalidTeethPeriod { period: usize, data_len: usize },
253 #[error("alligator: Invalid teeth offset: offset = {offset}, data_len = {data_len}")]
254 InvalidTeethOffset { offset: usize, data_len: usize },
255 #[error("alligator: Invalid lips period: period = {period}, data length = {data_len}")]
256 InvalidLipsPeriod { period: usize, data_len: usize },
257 #[error("alligator: Invalid lips offset: offset = {offset}, data_len = {data_len}")]
258 InvalidLipsOffset { offset: usize, data_len: usize },
259 #[error(
260 "alligator: Invalid kernel for batch operation. Expected batch kernel, got: {kernel:?}"
261 )]
262 InvalidKernel { kernel: Kernel },
263 #[error("alligator: invalid kernel for batch: {0:?}")]
264 InvalidKernelForBatch(Kernel),
265 #[error("alligator: Not enough valid data: needed = {needed}, valid = {valid}")]
266 NotEnoughValidData { needed: usize, valid: usize },
267 #[error("alligator: output length mismatch: expected = {expected}, got = {got}")]
268 OutputLengthMismatch { expected: usize, got: usize },
269 #[error("alligator: Invalid range: start={start}, end={end}, step={step}")]
270 InvalidRange { start: i64, end: i64, step: i64 },
271}
272
273#[inline]
274pub fn alligator(input: &AlligatorInput) -> Result<AlligatorOutput, AlligatorError> {
275 alligator_with_kernel(input, Kernel::Auto)
276}
277pub fn alligator_with_kernel(
278 input: &AlligatorInput,
279 kernel: Kernel,
280) -> Result<AlligatorOutput, AlligatorError> {
281 let data: &[f64] = match &input.data {
282 AlligatorData::Candles { candles, source } => source_type(candles, source),
283 AlligatorData::Slice(sl) => sl,
284 };
285 if data.is_empty() {
286 return Err(AlligatorError::EmptyInputData);
287 }
288 let first = data
289 .iter()
290 .position(|x| !x.is_nan())
291 .ok_or(AlligatorError::AllValuesNaN)?;
292 let len = data.len();
293 let jaw_period = input.get_jaw_period();
294 let jaw_offset = input.get_jaw_offset();
295 let teeth_period = input.get_teeth_period();
296 let teeth_offset = input.get_teeth_offset();
297 let lips_period = input.get_lips_period();
298 let lips_offset = input.get_lips_offset();
299 if jaw_period == 0 || jaw_period > len {
300 return Err(AlligatorError::InvalidJawPeriod {
301 period: jaw_period,
302 data_len: len,
303 });
304 }
305 if jaw_offset > len {
306 return Err(AlligatorError::InvalidJawOffset {
307 offset: jaw_offset,
308 data_len: len,
309 });
310 }
311 if teeth_period == 0 || teeth_period > len {
312 return Err(AlligatorError::InvalidTeethPeriod {
313 period: teeth_period,
314 data_len: len,
315 });
316 }
317 if teeth_offset > len {
318 return Err(AlligatorError::InvalidTeethOffset {
319 offset: teeth_offset,
320 data_len: len,
321 });
322 }
323 if lips_period == 0 || lips_period > len {
324 return Err(AlligatorError::InvalidLipsPeriod {
325 period: lips_period,
326 data_len: len,
327 });
328 }
329 if lips_offset > len {
330 return Err(AlligatorError::InvalidLipsOffset {
331 offset: lips_offset,
332 data_len: len,
333 });
334 }
335
336 let needed = jaw_period.max(teeth_period).max(lips_period);
337 let valid = len - first;
338 if valid < needed {
339 return Err(AlligatorError::NotEnoughValidData { needed, valid });
340 }
341
342 let chosen = match kernel {
343 Kernel::Auto => Kernel::Scalar,
344 other => other,
345 };
346 unsafe {
347 match chosen {
348 Kernel::Scalar | Kernel::ScalarBatch => alligator_scalar(
349 data,
350 jaw_period,
351 jaw_offset,
352 teeth_period,
353 teeth_offset,
354 lips_period,
355 lips_offset,
356 first,
357 len,
358 ),
359 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
360 Kernel::Avx2 | Kernel::Avx2Batch => alligator_avx2(
361 data,
362 jaw_period,
363 jaw_offset,
364 teeth_period,
365 teeth_offset,
366 lips_period,
367 lips_offset,
368 first,
369 len,
370 ),
371 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
372 Kernel::Avx512 | Kernel::Avx512Batch => alligator_avx512(
373 data,
374 jaw_period,
375 jaw_offset,
376 teeth_period,
377 teeth_offset,
378 lips_period,
379 lips_offset,
380 first,
381 len,
382 ),
383 _ => unreachable!(),
384 }
385 }
386}
387#[inline]
388pub unsafe fn alligator_scalar(
389 data: &[f64],
390 jaw_period: usize,
391 jaw_offset: usize,
392 teeth_period: usize,
393 teeth_offset: usize,
394 lips_period: usize,
395 lips_offset: usize,
396 first: usize,
397 len: usize,
398) -> Result<AlligatorOutput, AlligatorError> {
399 let jaw_warmup = first + jaw_period - 1 + jaw_offset;
400 let teeth_warmup = first + teeth_period - 1 + teeth_offset;
401 let lips_warmup = first + lips_period - 1 + lips_offset;
402
403 let mut jaw = alloc_with_nan_prefix(len, jaw_warmup);
404 let mut teeth = alloc_with_nan_prefix(len, teeth_warmup);
405 let mut lips = alloc_with_nan_prefix(len, lips_warmup);
406
407 let _ = alligator_smma_scalar(
408 data,
409 jaw_period,
410 jaw_offset,
411 teeth_period,
412 teeth_offset,
413 lips_period,
414 lips_offset,
415 first,
416 len,
417 &mut jaw,
418 &mut teeth,
419 &mut lips,
420 );
421 Ok(AlligatorOutput { jaw, teeth, lips })
422}
423
/// AVX2 entry point used by [`alligator_with_kernel`].
///
/// No AVX2-specific code exists yet. It forwards to [`alligator_scalar`], so its output
/// is identical to the scalar kernel.
///
/// # Safety
/// Callers must uphold the preconditions of [`alligator_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn alligator_avx2(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    first: usize,
    len: usize,
) -> Result<AlligatorOutput, AlligatorError> {
    alligator_scalar(
        data, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset,
        first, len,
    )
}

/// AVX-512 entry point.
///
/// Picks the "short" path when every period is at most 32 bars and the "long" path
/// otherwise. Both paths currently forward to the scalar kernel.
///
/// # Safety
/// Callers must uphold the preconditions of [`alligator_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn alligator_avx512(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    first: usize,
    len: usize,
) -> Result<AlligatorOutput, AlligatorError> {
    let all_short = [jaw_period, teeth_period, lips_period]
        .iter()
        .all(|&p| p <= 32);
    if all_short {
        return alligator_avx512_short(
            data, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset,
            first, len,
        );
    }
    alligator_avx512_long(
        data, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset,
        first, len,
    )
}

/// AVX-512 path for periods of at most 32 bars; currently the scalar kernel.
///
/// # Safety
/// Callers must uphold the preconditions of [`alligator_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn alligator_avx512_short(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    first: usize,
    len: usize,
) -> Result<AlligatorOutput, AlligatorError> {
    alligator_scalar(
        data, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset,
        first, len,
    )
}

/// AVX-512 path for periods longer than 32 bars; currently the scalar kernel.
///
/// # Safety
/// Callers must uphold the preconditions of [`alligator_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn alligator_avx512_long(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    first: usize,
    len: usize,
) -> Result<AlligatorOutput, AlligatorError> {
    alligator_scalar(
        data, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset,
        first, len,
    )
}
540
/// Runs the jaw, teeth and lips smoothed moving averages over `data[first..len]`.
///
/// Each line's value for bar `i` is written to index `i + offset` of its output slice.
/// Values that would land at or beyond `len` are dropped, and positions before a line's
/// first write are left untouched. Each line is seeded with the arithmetic mean of its
/// first `period` values and then updated as
/// `(prev * (period - 1) + x) * (1 / period)`.
///
/// Returns the last value of each line, or `0.0` for a line that never finished seeding.
///
/// # Safety
/// Every period must be non-zero, `len <= data.len()`, and each output slice must hold at
/// least `len` elements.
#[inline(always)]
pub unsafe fn alligator_smma_scalar(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    first: usize,
    len: usize,
    jaw: &mut [f64],
    teeth: &mut [f64],
    lips: &mut [f64],
) -> (f64, f64, f64) {
    /// One SMMA line. `coef` is `(period - 1, 1 / period)`. Returns the last value
    /// computed (0.0 if the line never seeded).
    #[inline(always)]
    fn smma_line(
        data: &[f64],
        period: usize,
        offset: usize,
        coef: (f64, f64),
        first: usize,
        len: usize,
        out: &mut [f64],
    ) -> f64 {
        let (weight, inv) = coef;
        let mut acc = 0.0;
        let mut current = 0.0;
        let mut ready = false;
        for i in first..len {
            let x = data[i];
            if ready {
                current = (current * weight + x) * inv;
            } else if i < first + period {
                acc += x;
                if i != first + period - 1 {
                    continue;
                }
                // Seed: plain mean of the first `period` values.
                current = acc / (period as f64);
                ready = true;
            } else {
                continue;
            }
            let target = i + offset;
            if target < len {
                out[target] = current;
            }
        }
        current
    }

    // Coefficients for all three lines are derived up front, before any output is written.
    let jaw_coef = ((jaw_period - 1) as f64, 1.0 / jaw_period as f64);
    let teeth_coef = ((teeth_period - 1) as f64, 1.0 / teeth_period as f64);
    let lips_coef = ((lips_period - 1) as f64, 1.0 / lips_period as f64);

    // The lines never read each other's state, so they can be run one after another.
    let jaw_last = smma_line(data, jaw_period, jaw_offset, jaw_coef, first, len, jaw);
    let teeth_last = smma_line(data, teeth_period, teeth_offset, teeth_coef, first, len, teeth);
    let lips_last = smma_line(data, lips_period, lips_offset, lips_coef, first, len, lips);
    (jaw_last, teeth_last, lips_last)
}
641
/// State of one streaming SMMA line (jaw, teeth or lips).
///
/// The average is seeded with the mean of the first `period` inputs and then updated as
/// `value + (x - value) / period`. When `offset > 0`, `off_buf` is a ring buffer that
/// delays emitted values by `offset` updates.
#[derive(Debug, Clone)]
struct Smmaline {
    period: usize,
    offset: usize,
    /// `1 / period`, cached for the update step.
    inv: f64,

    /// Whether the seeding mean has been computed.
    seeded: bool,
    /// Number of values accumulated while seeding.
    count: usize,
    /// Running sum while seeding.
    sum: f64,
    /// Current unshifted SMMA value (NaN until seeded).
    value: f64,

    /// Next ring-buffer slot to read and overwrite.
    off_head: usize,
    /// Whether the ring buffer has wrapped at least once.
    off_filled: bool,
    /// Delay line of length `offset` (empty when `offset == 0`).
    off_buf: Vec<f64>,
}

impl Smmaline {
    #[inline(always)]
    fn new(period: usize, offset: usize) -> Self {
        debug_assert!(period > 0);
        Self {
            period,
            offset,
            inv: 1.0 / period as f64,
            seeded: false,
            count: 0,
            sum: 0.0,
            value: f64::NAN,
            off_head: 0,
            off_filled: false,
            // `vec![x; 0]` does not allocate, so offset 0 costs nothing.
            off_buf: vec![0.0_f64; offset],
        }
    }

    /// Feeds one value. Returns the current unshifted SMMA, or `None` while seeding.
    ///
    /// NaNs that arrive before the first real value are ignored, mirroring the batch
    /// functions, which start at the first non-NaN input. NaNs after that point propagate
    /// as usual.
    #[inline(always)]
    fn update_unshifted(&mut self, x: f64) -> Option<f64> {
        if self.seeded {
            // value + (x - value) / period, fused into a single rounding step.
            let delta = x - self.value;
            self.value = delta.mul_add(self.inv, self.value);
            return Some(self.value);
        }
        if self.count == 0 && x.is_nan() {
            return None;
        }
        self.sum += x;
        self.count += 1;
        if self.count == self.period {
            self.value = self.sum * self.inv;
            self.seeded = true;
            Some(self.value)
        } else {
            None
        }
    }

    /// Feeds one value and returns the SMMA from `offset` seeded updates ago.
    ///
    /// Returns `None` until that many values have been produced.
    #[inline(always)]
    fn update_shifted(&mut self, x: f64) -> Option<f64> {
        let y = self.update_unshifted(x)?;
        if self.offset == 0 {
            return Some(y);
        }

        // Read the oldest buffered value (once the ring has wrapped), then overwrite it.
        let out = if self.off_filled {
            Some(self.off_buf[self.off_head])
        } else {
            None
        };
        self.off_buf[self.off_head] = y;
        self.off_head += 1;
        if self.off_head == self.offset {
            self.off_head = 0;
            self.off_filled = true;
        }
        out
    }

    #[inline(always)]
    fn is_seeded(&self) -> bool {
        self.seeded
    }
}
728
729#[derive(Debug, Clone)]
730pub struct AlligatorStream {
731 jaw: Smmaline,
732 teeth: Smmaline,
733 lips: Smmaline,
734}
735
736impl AlligatorStream {
737 pub fn try_new(params: AlligatorParams) -> Result<Self, AlligatorError> {
738 let jaw_period = params.jaw_period.unwrap_or(13);
739 let jaw_offset = params.jaw_offset.unwrap_or(8);
740 let teeth_period = params.teeth_period.unwrap_or(8);
741 let teeth_offset = params.teeth_offset.unwrap_or(5);
742 let lips_period = params.lips_period.unwrap_or(5);
743 let lips_offset = params.lips_offset.unwrap_or(3);
744
745 if jaw_period == 0 {
746 return Err(AlligatorError::InvalidJawPeriod {
747 period: jaw_period,
748 data_len: 0,
749 });
750 }
751 if teeth_period == 0 {
752 return Err(AlligatorError::InvalidTeethPeriod {
753 period: teeth_period,
754 data_len: 0,
755 });
756 }
757 if lips_period == 0 {
758 return Err(AlligatorError::InvalidLipsPeriod {
759 period: lips_period,
760 data_len: 0,
761 });
762 }
763
764 Ok(Self {
765 jaw: Smmaline::new(jaw_period, jaw_offset),
766 teeth: Smmaline::new(teeth_period, teeth_offset),
767 lips: Smmaline::new(lips_period, lips_offset),
768 })
769 }
770
771 #[inline(always)]
772 pub fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
773 let j = self.jaw.update_unshifted(value);
774 let t = self.teeth.update_unshifted(value);
775 let l = self.lips.update_unshifted(value);
776 match (j, t, l) {
777 (Some(jv), Some(tv), Some(lv)) => Some((jv, tv, lv)),
778 _ => None,
779 }
780 }
781
782 #[inline(always)]
783 pub fn update_shifted(&mut self, value: f64) -> Option<(f64, f64, f64)> {
784 let j = self.jaw.update_shifted(value);
785 let t = self.teeth.update_shifted(value);
786 let l = self.lips.update_shifted(value);
787 match (j, t, l) {
788 (Some(jv), Some(tv), Some(lv)) => Some((jv, tv, lv)),
789 _ => None,
790 }
791 }
792}
793
/// Parameter sweep for batch Alligator runs.
///
/// Each axis is `(start, end, step)`. `step == 0` or `start == end` yields just `start`,
/// and `start > end` walks downward.
#[derive(Clone, Debug)]
pub struct AlligatorBatchRange {
    pub jaw_period: (usize, usize, usize),
    pub jaw_offset: (usize, usize, usize),
    pub teeth_period: (usize, usize, usize),
    pub teeth_offset: (usize, usize, usize),
    pub lips_period: (usize, usize, usize),
    pub lips_offset: (usize, usize, usize),
}

impl Default for AlligatorBatchRange {
    /// Sweeps the jaw period from 13 to 262 in steps of 1. Every other axis is pinned to
    /// its classic value.
    fn default() -> Self {
        let pinned = |v: usize| (v, v, 0);
        AlligatorBatchRange {
            jaw_period: (13, 262, 1),
            jaw_offset: pinned(8),
            teeth_period: pinned(8),
            teeth_offset: pinned(5),
            lips_period: pinned(5),
            lips_offset: pinned(3),
        }
    }
}
815#[derive(Clone, Debug, Default)]
816pub struct AlligatorBatchBuilder {
817 range: AlligatorBatchRange,
818 kernel: Kernel,
819}
820impl AlligatorBatchBuilder {
821 pub fn new() -> Self {
822 Self::default()
823 }
824 pub fn kernel(mut self, k: Kernel) -> Self {
825 self.kernel = k;
826 self
827 }
828 pub fn jaw_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
829 self.range.jaw_period = (start, end, step);
830 self
831 }
832 pub fn jaw_offset_range(mut self, start: usize, end: usize, step: usize) -> Self {
833 self.range.jaw_offset = (start, end, step);
834 self
835 }
836 pub fn teeth_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
837 self.range.teeth_period = (start, end, step);
838 self
839 }
840 pub fn teeth_offset_range(mut self, start: usize, end: usize, step: usize) -> Self {
841 self.range.teeth_offset = (start, end, step);
842 self
843 }
844 pub fn lips_period_range(mut self, start: usize, end: usize, step: usize) -> Self {
845 self.range.lips_period = (start, end, step);
846 self
847 }
848 pub fn lips_offset_range(mut self, start: usize, end: usize, step: usize) -> Self {
849 self.range.lips_offset = (start, end, step);
850 self
851 }
852 pub fn apply_slice(self, data: &[f64]) -> Result<AlligatorBatchOutput, AlligatorError> {
853 alligator_batch_with_kernel(data, &self.range, self.kernel)
854 }
855 pub fn with_default_slice(
856 data: &[f64],
857 k: Kernel,
858 ) -> Result<AlligatorBatchOutput, AlligatorError> {
859 AlligatorBatchBuilder::new().kernel(k).apply_slice(data)
860 }
861 pub fn apply_candles(
862 self,
863 c: &Candles,
864 src: &str,
865 ) -> Result<AlligatorBatchOutput, AlligatorError> {
866 let slice = source_type(c, src);
867 self.apply_slice(slice)
868 }
869 pub fn with_default_candles(c: &Candles) -> Result<AlligatorBatchOutput, AlligatorError> {
870 AlligatorBatchBuilder::new()
871 .kernel(Kernel::Auto)
872 .apply_candles(c, "hl2")
873 }
874}
875
876pub fn alligator_batch_with_kernel(
877 data: &[f64],
878 sweep: &AlligatorBatchRange,
879 k: Kernel,
880) -> Result<AlligatorBatchOutput, AlligatorError> {
881 let kernel = match k {
882 Kernel::Auto => detect_best_batch_kernel(),
883 other if other.is_batch() => other,
884 non_batch => return Err(AlligatorError::InvalidKernelForBatch(non_batch)),
885 };
886
887 alligator_batch_par_slice(data, sweep, kernel)
888}
889
890#[derive(Clone, Debug)]
891pub struct AlligatorBatchOutput {
892 pub jaw: Vec<f64>,
893 pub teeth: Vec<f64>,
894 pub lips: Vec<f64>,
895 pub combos: Vec<AlligatorParams>,
896 pub rows: usize,
897 pub cols: usize,
898}
899impl AlligatorBatchOutput {
900 pub fn row_for_params(&self, p: &AlligatorParams) -> Option<usize> {
901 self.combos.iter().position(|c| {
902 c.jaw_period.unwrap_or(13) == p.jaw_period.unwrap_or(13)
903 && c.jaw_offset.unwrap_or(8) == p.jaw_offset.unwrap_or(8)
904 && c.teeth_period.unwrap_or(8) == p.teeth_period.unwrap_or(8)
905 && c.teeth_offset.unwrap_or(5) == p.teeth_offset.unwrap_or(5)
906 && c.lips_period.unwrap_or(5) == p.lips_period.unwrap_or(5)
907 && c.lips_offset.unwrap_or(3) == p.lips_offset.unwrap_or(3)
908 })
909 }
910 pub fn values_for(&self, p: &AlligatorParams) -> Option<(&[f64], &[f64], &[f64])> {
911 self.row_for_params(p).map(|row| {
912 let start = row * self.cols;
913 (
914 &self.jaw[start..start + self.cols],
915 &self.teeth[start..start + self.cols],
916 &self.lips[start..start + self.cols],
917 )
918 })
919 }
920}
921
922#[inline(always)]
923fn expand_grid(r: &AlligatorBatchRange) -> Result<Vec<AlligatorParams>, AlligatorError> {
924 fn axis((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, AlligatorError> {
925 if step == 0 || start == end {
926 return Ok(vec![start]);
927 }
928 if start < end {
929 let v: Vec<usize> = (start..=end).step_by(step).collect();
930 if v.is_empty() {
931 return Err(AlligatorError::InvalidRange {
932 start: start as i64,
933 end: end as i64,
934 step: step as i64,
935 });
936 }
937 Ok(v)
938 } else {
939 let mut v = Vec::new();
940 let mut cur = start;
941 while cur >= end {
942 v.push(cur);
943 if cur - end < step {
944 break;
945 }
946 cur -= step;
947 }
948 if v.is_empty() {
949 return Err(AlligatorError::InvalidRange {
950 start: start as i64,
951 end: end as i64,
952 step: step as i64,
953 });
954 }
955 Ok(v)
956 }
957 }
958 let jaw_periods = axis(r.jaw_period)?;
959 let jaw_offsets = axis(r.jaw_offset)?;
960 let teeth_periods = axis(r.teeth_period)?;
961 let teeth_offsets = axis(r.teeth_offset)?;
962 let lips_periods = axis(r.lips_period)?;
963 let lips_offsets = axis(r.lips_offset)?;
964
965 let cap = jaw_periods
966 .len()
967 .checked_mul(jaw_offsets.len())
968 .and_then(|v| v.checked_mul(teeth_periods.len()))
969 .and_then(|v| v.checked_mul(teeth_offsets.len()))
970 .and_then(|v| v.checked_mul(lips_periods.len()))
971 .and_then(|v| v.checked_mul(lips_offsets.len()))
972 .unwrap_or(0);
973 let mut out = Vec::with_capacity(cap);
974 for &jp in &jaw_periods {
975 for &jo in &jaw_offsets {
976 for &tp in &teeth_periods {
977 for &to in &teeth_offsets {
978 for &lp in &lips_periods {
979 for &lo in &lips_offsets {
980 out.push(AlligatorParams {
981 jaw_period: Some(jp),
982 jaw_offset: Some(jo),
983 teeth_period: Some(tp),
984 teeth_offset: Some(to),
985 lips_period: Some(lp),
986 lips_offset: Some(lo),
987 });
988 }
989 }
990 }
991 }
992 }
993 }
994 if out.is_empty() {
995 return Err(AlligatorError::InvalidRange {
996 start: 0,
997 end: 0,
998 step: 0,
999 });
1000 }
1001 Ok(out)
1002}
1003
1004#[inline(always)]
1005pub fn alligator_batch_slice(
1006 data: &[f64],
1007 sweep: &AlligatorBatchRange,
1008 kern: Kernel,
1009) -> Result<AlligatorBatchOutput, AlligatorError> {
1010 alligator_batch_inner(data, sweep, kern, false)
1011}
1012#[inline(always)]
1013pub fn alligator_batch_par_slice(
1014 data: &[f64],
1015 sweep: &AlligatorBatchRange,
1016 kern: Kernel,
1017) -> Result<AlligatorBatchOutput, AlligatorError> {
1018 alligator_batch_inner(data, sweep, kern, true)
1019}
1020#[inline(always)]
1021fn alligator_batch_inner(
1022 data: &[f64],
1023 sweep: &AlligatorBatchRange,
1024 kern: Kernel,
1025 parallel: bool,
1026) -> Result<AlligatorBatchOutput, AlligatorError> {
1027 let combos = expand_grid(sweep)?;
1028 let first = data
1029 .iter()
1030 .position(|x| !x.is_nan())
1031 .ok_or(AlligatorError::AllValuesNaN)?;
1032 let max_p = combos
1033 .iter()
1034 .map(|c| {
1035 c.jaw_period
1036 .unwrap()
1037 .max(c.teeth_period.unwrap())
1038 .max(c.lips_period.unwrap())
1039 })
1040 .max()
1041 .unwrap();
1042 if data.len() - first < max_p {
1043 return Err(AlligatorError::InvalidJawPeriod {
1044 period: max_p,
1045 data_len: data.len(),
1046 });
1047 }
1048 let rows = combos.len();
1049 let cols = data.len();
1050
1051 let _rc = rows.checked_mul(cols).ok_or(AlligatorError::InvalidRange {
1052 start: rows as i64,
1053 end: cols as i64,
1054 step: 1,
1055 })?;
1056 let mut jaw_mu = make_uninit_matrix(rows, cols);
1057 let mut teeth_mu = make_uninit_matrix(rows, cols);
1058 let mut lips_mu = make_uninit_matrix(rows, cols);
1059
1060 let jaw_warmups: Vec<usize> = combos
1061 .iter()
1062 .map(|c| first + c.jaw_period.unwrap() - 1 + c.jaw_offset.unwrap())
1063 .collect();
1064 let teeth_warmups: Vec<usize> = combos
1065 .iter()
1066 .map(|c| first + c.teeth_period.unwrap() - 1 + c.teeth_offset.unwrap())
1067 .collect();
1068 let lips_warmups: Vec<usize> = combos
1069 .iter()
1070 .map(|c| first + c.lips_period.unwrap() - 1 + c.lips_offset.unwrap())
1071 .collect();
1072
1073 init_matrix_prefixes(&mut jaw_mu, cols, &jaw_warmups);
1074 init_matrix_prefixes(&mut teeth_mu, cols, &teeth_warmups);
1075 init_matrix_prefixes(&mut lips_mu, cols, &lips_warmups);
1076
1077 let mut jaw_guard = std::mem::ManuallyDrop::new(jaw_mu);
1078 let mut teeth_guard = std::mem::ManuallyDrop::new(teeth_mu);
1079 let mut lips_guard = std::mem::ManuallyDrop::new(lips_mu);
1080
1081 let jaw: &mut [f64] = unsafe {
1082 core::slice::from_raw_parts_mut(jaw_guard.as_mut_ptr() as *mut f64, jaw_guard.len())
1083 };
1084 let teeth: &mut [f64] = unsafe {
1085 core::slice::from_raw_parts_mut(teeth_guard.as_mut_ptr() as *mut f64, teeth_guard.len())
1086 };
1087 let lips: &mut [f64] = unsafe {
1088 core::slice::from_raw_parts_mut(lips_guard.as_mut_ptr() as *mut f64, lips_guard.len())
1089 };
1090
1091 let combos = alligator_batch_inner_into(data, sweep, kern, parallel, jaw, teeth, lips)?;
1092
1093 let jaw_vec = unsafe {
1094 Vec::from_raw_parts(
1095 jaw_guard.as_mut_ptr() as *mut f64,
1096 jaw_guard.len(),
1097 jaw_guard.capacity(),
1098 )
1099 };
1100 let teeth_vec = unsafe {
1101 Vec::from_raw_parts(
1102 teeth_guard.as_mut_ptr() as *mut f64,
1103 teeth_guard.len(),
1104 teeth_guard.capacity(),
1105 )
1106 };
1107 let lips_vec = unsafe {
1108 Vec::from_raw_parts(
1109 lips_guard.as_mut_ptr() as *mut f64,
1110 lips_guard.len(),
1111 lips_guard.capacity(),
1112 )
1113 };
1114
1115 Ok(AlligatorBatchOutput {
1116 jaw: jaw_vec,
1117 teeth: teeth_vec,
1118 lips: lips_vec,
1119 combos,
1120 rows,
1121 cols,
1122 })
1123}
1124
1125#[inline]
1126fn alligator_batch_inner_into(
1127 data: &[f64],
1128 sweep: &AlligatorBatchRange,
1129 kern: Kernel,
1130 parallel: bool,
1131 jaw_out: &mut [f64],
1132 teeth_out: &mut [f64],
1133 lips_out: &mut [f64],
1134) -> Result<Vec<AlligatorParams>, AlligatorError> {
1135 let combos = expand_grid(sweep)?;
1136
1137 let cols = data.len();
1138 let rows = combos.len();
1139 let total = rows.checked_mul(cols).ok_or(AlligatorError::InvalidRange {
1140 start: rows as i64,
1141 end: cols as i64,
1142 step: 1,
1143 })?;
1144 if jaw_out.len() != total {
1145 return Err(AlligatorError::OutputLengthMismatch {
1146 expected: total,
1147 got: jaw_out.len(),
1148 });
1149 }
1150 if teeth_out.len() != total {
1151 return Err(AlligatorError::OutputLengthMismatch {
1152 expected: total,
1153 got: teeth_out.len(),
1154 });
1155 }
1156 if lips_out.len() != total {
1157 return Err(AlligatorError::OutputLengthMismatch {
1158 expected: total,
1159 got: lips_out.len(),
1160 });
1161 }
1162
1163 let first = data
1164 .iter()
1165 .position(|x| !x.is_nan())
1166 .ok_or(AlligatorError::AllValuesNaN)?;
1167 let max_p = combos
1168 .iter()
1169 .map(|c| {
1170 c.jaw_period
1171 .unwrap()
1172 .max(c.teeth_period.unwrap())
1173 .max(c.lips_period.unwrap())
1174 })
1175 .max()
1176 .unwrap();
1177
1178 if data.len() - first < max_p {
1179 return Err(AlligatorError::InvalidJawPeriod {
1180 period: max_p,
1181 data_len: data.len(),
1182 });
1183 }
1184
1185 let actual = match kern {
1186 Kernel::Auto => detect_best_batch_kernel(),
1187 k => k,
1188 };
1189 let simd = match actual {
1190 Kernel::Avx512Batch => Kernel::Avx512,
1191 Kernel::Avx2Batch => Kernel::Avx2,
1192 Kernel::ScalarBatch => Kernel::Scalar,
1193 _ => unreachable!(),
1194 };
1195
1196 let do_row = |row: usize, jdst: &mut [f64], tdst: &mut [f64], ldst: &mut [f64]| unsafe {
1197 let p = &combos[row];
1198 match simd {
1199 Kernel::Scalar => {
1200 let _ = alligator_row_scalar(
1201 data,
1202 first,
1203 p.jaw_period.unwrap(),
1204 p.jaw_offset.unwrap(),
1205 p.teeth_period.unwrap(),
1206 p.teeth_offset.unwrap(),
1207 p.lips_period.unwrap(),
1208 p.lips_offset.unwrap(),
1209 cols,
1210 jdst,
1211 tdst,
1212 ldst,
1213 );
1214 }
1215 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1216 Kernel::Avx2 => {
1217 let _ = alligator_row_avx2(
1218 data,
1219 first,
1220 p.jaw_period.unwrap(),
1221 p.jaw_offset.unwrap(),
1222 p.teeth_period.unwrap(),
1223 p.teeth_offset.unwrap(),
1224 p.lips_period.unwrap(),
1225 p.lips_offset.unwrap(),
1226 cols,
1227 jdst,
1228 tdst,
1229 ldst,
1230 );
1231 }
1232 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1233 Kernel::Avx512 => {
1234 let _ = alligator_row_avx512(
1235 data,
1236 first,
1237 p.jaw_period.unwrap(),
1238 p.jaw_offset.unwrap(),
1239 p.teeth_period.unwrap(),
1240 p.teeth_offset.unwrap(),
1241 p.lips_period.unwrap(),
1242 p.lips_offset.unwrap(),
1243 cols,
1244 jdst,
1245 tdst,
1246 ldst,
1247 );
1248 }
1249 _ => unreachable!(),
1250 }
1251 };
1252
1253 if parallel {
1254 #[cfg(not(target_arch = "wasm32"))]
1255 {
1256 use rayon::prelude::*;
1257 jaw_out
1258 .par_chunks_mut(cols)
1259 .zip(teeth_out.par_chunks_mut(cols))
1260 .zip(lips_out.par_chunks_mut(cols))
1261 .enumerate()
1262 .for_each(|(r, ((j, t), l))| do_row(r, j, t, l));
1263 }
1264 #[cfg(target_arch = "wasm32")]
1265 {
1266 for (r, ((j, t), l)) in jaw_out
1267 .chunks_mut(cols)
1268 .zip(teeth_out.chunks_mut(cols))
1269 .zip(lips_out.chunks_mut(cols))
1270 .enumerate()
1271 {
1272 do_row(r, j, t, l);
1273 }
1274 }
1275 } else {
1276 for (r, ((j, t), l)) in jaw_out
1277 .chunks_mut(cols)
1278 .zip(teeth_out.chunks_mut(cols))
1279 .zip(lips_out.chunks_mut(cols))
1280 .enumerate()
1281 {
1282 do_row(r, j, t, l);
1283 }
1284 }
1285
1286 Ok(combos)
1287}
1288
1289#[inline(always)]
1290pub unsafe fn alligator_row_scalar(
1291 data: &[f64],
1292 first: usize,
1293 jaw_period: usize,
1294 jaw_offset: usize,
1295 teeth_period: usize,
1296 teeth_offset: usize,
1297 lips_period: usize,
1298 lips_offset: usize,
1299 cols: usize,
1300 jaw: &mut [f64],
1301 teeth: &mut [f64],
1302 lips: &mut [f64],
1303) -> (f64, f64, f64) {
1304 alligator_smma_scalar(
1305 data,
1306 jaw_period,
1307 jaw_offset,
1308 teeth_period,
1309 teeth_offset,
1310 lips_period,
1311 lips_offset,
1312 first,
1313 cols,
1314 jaw,
1315 teeth,
1316 lips,
1317 )
1318}
1319
/// AVX2 entry point for computing a single batch row.
///
/// No AVX2-specific kernel exists yet: every argument is forwarded unchanged to
/// [`alligator_row_scalar`], so the output is identical to the scalar row.
///
/// # Safety
/// Same requirements as [`alligator_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn alligator_row_avx2(
    data: &[f64],
    first: usize,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    cols: usize,
    jaw: &mut [f64],
    teeth: &mut [f64],
    lips: &mut [f64],
) -> (f64, f64, f64) {
    let (jp, jo) = (jaw_period, jaw_offset);
    let (tp, to) = (teeth_period, teeth_offset);
    let (lp, lo) = (lips_period, lips_offset);
    // SAFETY: this function's contract is exactly the scalar row's, upheld by the caller.
    unsafe { alligator_row_scalar(data, first, jp, jo, tp, to, lp, lo, cols, jaw, teeth, lips) }
}
1351
/// AVX-512 entry point for computing a single batch row.
///
/// No AVX-512-specific kernel exists yet: every argument is forwarded unchanged to
/// [`alligator_row_scalar`], so the output is identical to the scalar row.
///
/// # Safety
/// Same requirements as [`alligator_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn alligator_row_avx512(
    data: &[f64],
    first: usize,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    cols: usize,
    jaw: &mut [f64],
    teeth: &mut [f64],
    lips: &mut [f64],
) -> (f64, f64, f64) {
    let (jp, jo) = (jaw_period, jaw_offset);
    let (tp, to) = (teeth_period, teeth_offset);
    let (lp, lo) = (lips_period, lips_offset);
    // SAFETY: this function's contract is exactly the scalar row's, upheld by the caller.
    unsafe { alligator_row_scalar(data, first, jp, jo, tp, to, lp, lo, cols, jaw, teeth, lips) }
}
1383
/// AVX-512 "short period" row variant.
///
/// Currently a placeholder: every argument is forwarded unchanged to
/// [`alligator_row_scalar`], so the output is identical to the scalar row.
///
/// # Safety
/// Same requirements as [`alligator_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn alligator_row_avx512_short(
    data: &[f64],
    first: usize,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    cols: usize,
    jaw: &mut [f64],
    teeth: &mut [f64],
    lips: &mut [f64],
) -> (f64, f64, f64) {
    let (jp, jo) = (jaw_period, jaw_offset);
    let (tp, to) = (teeth_period, teeth_offset);
    let (lp, lo) = (lips_period, lips_offset);
    // SAFETY: this function's contract is exactly the scalar row's, upheld by the caller.
    unsafe { alligator_row_scalar(data, first, jp, jo, tp, to, lp, lo, cols, jaw, teeth, lips) }
}
/// AVX-512 "long period" row variant.
///
/// Currently a placeholder: every argument is forwarded unchanged to
/// [`alligator_row_scalar`], so the output is identical to the scalar row.
///
/// # Safety
/// Same requirements as [`alligator_row_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn alligator_row_avx512_long(
    data: &[f64],
    first: usize,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    cols: usize,
    jaw: &mut [f64],
    teeth: &mut [f64],
    lips: &mut [f64],
) -> (f64, f64, f64) {
    let (jp, jo) = (jaw_period, jaw_offset);
    let (tp, to) = (teeth_period, teeth_offset);
    let (lp, lo) = (lips_period, lips_offset);
    // SAFETY: this function's contract is exactly the scalar row's, upheld by the caller.
    unsafe { alligator_row_scalar(data, first, jp, jo, tp, to, lp, lo, cols, jaw, teeth, lips) }
}
1446
1447#[inline(always)]
1448fn expand_grid_len(r: &AlligatorBatchRange) -> usize {
1449 fn axis((start, end, step): (usize, usize, usize)) -> usize {
1450 if step == 0 || start == end {
1451 1
1452 } else {
1453 ((end - start) / step + 1)
1454 }
1455 }
1456 axis(r.jaw_period)
1457 * axis(r.jaw_offset)
1458 * axis(r.teeth_period)
1459 * axis(r.teeth_offset)
1460 * axis(r.lips_period)
1461 * axis(r.lips_offset)
1462}
1463
1464#[cfg(test)]
1465mod tests {
1466 use super::*;
1467 use crate::skip_if_unsupported;
1468 use crate::utilities::data_loader::read_candles_from_csv;
1469
1470 #[test]
1471 fn test_alligator_into_matches_api() -> Result<(), Box<dyn std::error::Error>> {
1472 let mut data = Vec::with_capacity(256);
1473 for _ in 0..7 {
1474 data.push(f64::NAN);
1475 }
1476 for i in 0..249 {
1477 let x = i as f64;
1478 data.push((x * 0.01) + (x.sin() * 0.1));
1479 }
1480
1481 let input = AlligatorInput::from_slice(&data, AlligatorParams::default());
1482
1483 let AlligatorOutput {
1484 jaw: bj,
1485 teeth: bt,
1486 lips: bl,
1487 } = alligator(&input)?;
1488
1489 let mut oj = vec![0.0; data.len()];
1490 let mut ot = vec![0.0; data.len()];
1491 let mut ol = vec![0.0; data.len()];
1492 alligator_into(&input, &mut oj, &mut ot, &mut ol)?;
1493
1494 assert_eq!(oj.len(), bj.len());
1495 assert_eq!(ot.len(), bt.len());
1496 assert_eq!(ol.len(), bl.len());
1497
1498 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1499 (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
1500 }
1501
1502 for i in 0..data.len() {
1503 assert!(
1504 eq_or_both_nan(oj[i], bj[i]),
1505 "jaw mismatch at {}: {} vs {}",
1506 i,
1507 oj[i],
1508 bj[i]
1509 );
1510 assert!(
1511 eq_or_both_nan(ot[i], bt[i]),
1512 "teeth mismatch at {}: {} vs {}",
1513 i,
1514 ot[i],
1515 bt[i]
1516 );
1517 assert!(
1518 eq_or_both_nan(ol[i], bl[i]),
1519 "lips mismatch at {}: {} vs {}",
1520 i,
1521 ol[i],
1522 bl[i]
1523 );
1524 }
1525 Ok(())
1526 }
1527 fn check_alligator_partial_params(
1528 test_name: &str,
1529 kernel: Kernel,
1530 ) -> Result<(), Box<dyn std::error::Error>> {
1531 skip_if_unsupported!(kernel, test_name);
1532 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1533 let candles = read_candles_from_csv(file_path)?;
1534 let partial_params = AlligatorParams {
1535 jaw_period: Some(14),
1536 jaw_offset: None,
1537 teeth_period: None,
1538 teeth_offset: None,
1539 lips_period: None,
1540 lips_offset: Some(2),
1541 };
1542 let input = AlligatorInput::from_candles(&candles, "hl2", partial_params);
1543 let result = alligator_with_kernel(&input, kernel)?;
1544 assert_eq!(result.jaw.len(), candles.close.len());
1545 assert_eq!(result.teeth.len(), candles.close.len());
1546 assert_eq!(result.lips.len(), candles.close.len());
1547 Ok(())
1548 }
1549 fn check_alligator_accuracy(
1550 test_name: &str,
1551 kernel: Kernel,
1552 ) -> Result<(), Box<dyn std::error::Error>> {
1553 skip_if_unsupported!(kernel, test_name);
1554 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1555 let candles = read_candles_from_csv(file_path)?;
1556 let hl2_prices = candles.get_calculated_field("hl2").expect("hl2 fail");
1557 let input = AlligatorInput::with_default_candles(&candles);
1558 let result = alligator_with_kernel(&input, kernel)?;
1559 let expected_last_five_jaw_result = [60742.4, 60632.6, 60555.1, 60442.7, 60308.7];
1560 let expected_last_five_teeth_result = [59908.0, 59757.2, 59684.3, 59653.5, 59621.1];
1561 let expected_last_five_lips_result = [59355.2, 59371.7, 59376.2, 59334.1, 59316.2];
1562 let start_index: usize = result.jaw.len() - 5;
1563 let result_last_five_jaws = &result.jaw[start_index..];
1564 let result_last_five_teeth = &result.teeth[start_index..];
1565 let result_last_five_lips = &result.lips[start_index..];
1566 for (i, &value) in result_last_five_jaws.iter().enumerate() {
1567 let expected_value = expected_last_five_jaw_result[i];
1568 assert!(
1569 (value - expected_value).abs() < 1e-1,
1570 "alligator jaw value mismatch at index {}: expected {}, got {}",
1571 i,
1572 expected_value,
1573 value
1574 );
1575 }
1576 for (i, &value) in result_last_five_teeth.iter().enumerate() {
1577 let expected_value = expected_last_five_teeth_result[i];
1578 assert!(
1579 (value - expected_value).abs() < 1e-1,
1580 "alligator teeth value mismatch at index {}: expected {}, got {}",
1581 i,
1582 expected_value,
1583 value
1584 );
1585 }
1586 for (i, &value) in result_last_five_lips.iter().enumerate() {
1587 let expected_value = expected_last_five_lips_result[i];
1588 assert!(
1589 (value - expected_value).abs() < 1e-1,
1590 "alligator lips value mismatch at index {}: expected {}, got {}",
1591 i,
1592 expected_value,
1593 value
1594 );
1595 }
1596 Ok(())
1597 }
1598 fn check_alligator_default_candles(
1599 test_name: &str,
1600 kernel: Kernel,
1601 ) -> Result<(), Box<dyn std::error::Error>> {
1602 skip_if_unsupported!(kernel, test_name);
1603 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1604 let candles = read_candles_from_csv(file_path)?;
1605 let input = AlligatorInput::with_default_candles(&candles);
1606 match input.data {
1607 AlligatorData::Candles { source, .. } => assert_eq!(source, "hl2"),
1608 _ => panic!("Expected AlligatorData::Candles"),
1609 }
1610 let output = alligator_with_kernel(&input, kernel)?;
1611 assert_eq!(output.jaw.len(), candles.close.len());
1612 Ok(())
1613 }
1614 fn check_alligator_with_slice_data_reinput(
1615 test_name: &str,
1616 kernel: Kernel,
1617 ) -> Result<(), Box<dyn std::error::Error>> {
1618 skip_if_unsupported!(kernel, test_name);
1619 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1620 let candles = read_candles_from_csv(file_path)?;
1621 let first_input = AlligatorInput::with_default_candles(&candles);
1622 let first_result = alligator_with_kernel(&first_input, kernel)?;
1623 let second_input =
1624 AlligatorInput::from_slice(&first_result.jaw, AlligatorParams::default());
1625 let second_result = alligator_with_kernel(&second_input, kernel)?;
1626 assert_eq!(second_result.jaw.len(), first_result.jaw.len());
1627 assert_eq!(second_result.teeth.len(), first_result.teeth.len());
1628 assert_eq!(second_result.lips.len(), first_result.lips.len());
1629 Ok(())
1630 }
1631 fn check_alligator_nan_handling(
1632 test_name: &str,
1633 kernel: Kernel,
1634 ) -> Result<(), Box<dyn std::error::Error>> {
1635 skip_if_unsupported!(kernel, test_name);
1636 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1637 let candles = read_candles_from_csv(file_path)?;
1638 let input = AlligatorInput::with_default_candles(&candles);
1639 let result = alligator_with_kernel(&input, kernel)?;
1640 if result.jaw.len() > 50 {
1641 for i in 50..result.jaw.len() {
1642 assert!(!result.jaw[i].is_nan());
1643 assert!(!result.teeth[i].is_nan());
1644 assert!(!result.lips[i].is_nan());
1645 }
1646 }
1647 Ok(())
1648 }
1649 fn check_alligator_zero_jaw_period(
1650 test_name: &str,
1651 kernel: Kernel,
1652 ) -> Result<(), Box<dyn std::error::Error>> {
1653 skip_if_unsupported!(kernel, test_name);
1654 let data = vec![10.0, 20.0, 30.0];
1655 let params = AlligatorParams {
1656 jaw_period: Some(0),
1657 ..AlligatorParams::default()
1658 };
1659 let input = AlligatorInput::from_slice(&data, params);
1660 let res = alligator_with_kernel(&input, kernel);
1661 assert!(
1662 res.is_err(),
1663 "[{}] Alligator should fail with zero jaw period",
1664 test_name
1665 );
1666 Ok(())
1667 }
1668
1669 #[cfg(feature = "proptest")]
1670 #[allow(clippy::float_cmp)]
1671 fn check_alligator_property(
1672 test_name: &str,
1673 kernel: Kernel,
1674 ) -> Result<(), Box<dyn std::error::Error>> {
1675 use proptest::prelude::*;
1676 skip_if_unsupported!(kernel, test_name);
1677
1678 let strat = (6usize..=50).prop_flat_map(|max_period| {
1679 let min_len = max_period + 10;
1680 (
1681 prop::collection::vec(
1682 (-1e6f64..1e6f64).prop_filter("finite", |x| x.is_finite()),
1683 min_len..400,
1684 ),
1685 ((max_period / 2).max(2)..=max_period),
1686 (0usize..=10),
1687 ((max_period / 3).max(2)..=(max_period * 2 / 3).max(2)),
1688 (0usize..=8),
1689 (2usize..=(max_period / 3).max(2)),
1690 (0usize..=5),
1691 )
1692 });
1693
1694 proptest::test_runner::TestRunner::default()
1695 .run(
1696 &strat,
1697 |(
1698 data,
1699 jaw_period,
1700 jaw_offset,
1701 teeth_period,
1702 teeth_offset,
1703 lips_period,
1704 lips_offset,
1705 )| {
1706 let params = AlligatorParams {
1707 jaw_period: Some(jaw_period),
1708 jaw_offset: Some(jaw_offset),
1709 teeth_period: Some(teeth_period),
1710 teeth_offset: Some(teeth_offset),
1711 lips_period: Some(lips_period),
1712 lips_offset: Some(lips_offset),
1713 };
1714 let input = AlligatorInput::from_slice(&data, params);
1715
1716 let AlligatorOutput {
1717 jaw: out_jaw,
1718 teeth: out_teeth,
1719 lips: out_lips,
1720 } = alligator_with_kernel(&input, kernel).unwrap();
1721 let AlligatorOutput {
1722 jaw: ref_jaw,
1723 teeth: ref_teeth,
1724 lips: ref_lips,
1725 } = alligator_with_kernel(&input, Kernel::Scalar).unwrap();
1726
1727 let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
1728
1729 let jaw_warmup = first + jaw_period - 1 + jaw_offset;
1730 let teeth_warmup = first + teeth_period - 1 + teeth_offset;
1731 let lips_warmup = first + lips_period - 1 + lips_offset;
1732
1733 for i in 0..jaw_warmup.min(out_jaw.len()) {
1734 prop_assert!(
1735 out_jaw[i].is_nan(),
1736 "Expected NaN in jaw warmup at index {}",
1737 i
1738 );
1739 }
1740 for i in 0..teeth_warmup.min(out_teeth.len()) {
1741 prop_assert!(
1742 out_teeth[i].is_nan(),
1743 "Expected NaN in teeth warmup at index {}",
1744 i
1745 );
1746 }
1747 for i in 0..lips_warmup.min(out_lips.len()) {
1748 prop_assert!(
1749 out_lips[i].is_nan(),
1750 "Expected NaN in lips warmup at index {}",
1751 i
1752 );
1753 }
1754
1755 if jaw_warmup > 0 && jaw_warmup < data.len() {
1756 prop_assert!(
1757 out_jaw[jaw_warmup].is_finite(),
1758 "Expected first jaw value at index {} after warmup",
1759 jaw_warmup
1760 );
1761 if jaw_warmup > 0 {
1762 prop_assert!(
1763 out_jaw[jaw_warmup - 1].is_nan(),
1764 "Expected NaN before jaw warmup at index {}",
1765 jaw_warmup - 1
1766 );
1767 }
1768 }
1769 if teeth_warmup > 0 && teeth_warmup < data.len() {
1770 prop_assert!(
1771 out_teeth[teeth_warmup].is_finite(),
1772 "Expected first teeth value at index {} after warmup",
1773 teeth_warmup
1774 );
1775 if teeth_warmup > 0 {
1776 prop_assert!(
1777 out_teeth[teeth_warmup - 1].is_nan(),
1778 "Expected NaN before teeth warmup at index {}",
1779 teeth_warmup - 1
1780 );
1781 }
1782 }
1783 if lips_warmup > 0 && lips_warmup < data.len() {
1784 prop_assert!(
1785 out_lips[lips_warmup].is_finite(),
1786 "Expected first lips value at index {} after warmup",
1787 lips_warmup
1788 );
1789 if lips_warmup > 0 {
1790 prop_assert!(
1791 out_lips[lips_warmup - 1].is_nan(),
1792 "Expected NaN before lips warmup at index {}",
1793 lips_warmup - 1
1794 );
1795 }
1796 }
1797
1798 for i in 0..data.len() {
1799 let y_jaw = out_jaw[i];
1800 let r_jaw = ref_jaw[i];
1801 if !y_jaw.is_finite() || !r_jaw.is_finite() {
1802 prop_assert!(
1803 y_jaw.to_bits() == r_jaw.to_bits(),
1804 "jaw finite/NaN mismatch idx {}: {} vs {}",
1805 i,
1806 y_jaw,
1807 r_jaw
1808 );
1809 } else {
1810 let ulp_diff: u64 = y_jaw.to_bits().abs_diff(r_jaw.to_bits());
1811 prop_assert!(
1812 (y_jaw - r_jaw).abs() <= 1e-8 || ulp_diff <= 16,
1813 "jaw mismatch idx {}: {} vs {} (ULP={})",
1814 i,
1815 y_jaw,
1816 r_jaw,
1817 ulp_diff
1818 );
1819 }
1820
1821 let y_teeth = out_teeth[i];
1822 let r_teeth = ref_teeth[i];
1823 if !y_teeth.is_finite() || !r_teeth.is_finite() {
1824 prop_assert!(
1825 y_teeth.to_bits() == r_teeth.to_bits(),
1826 "teeth finite/NaN mismatch idx {}: {} vs {}",
1827 i,
1828 y_teeth,
1829 r_teeth
1830 );
1831 } else {
1832 let ulp_diff: u64 = y_teeth.to_bits().abs_diff(r_teeth.to_bits());
1833 prop_assert!(
1834 (y_teeth - r_teeth).abs() <= 1e-8 || ulp_diff <= 16,
1835 "teeth mismatch idx {}: {} vs {} (ULP={})",
1836 i,
1837 y_teeth,
1838 r_teeth,
1839 ulp_diff
1840 );
1841 }
1842
1843 let y_lips = out_lips[i];
1844 let r_lips = ref_lips[i];
1845 if !y_lips.is_finite() || !r_lips.is_finite() {
1846 prop_assert!(
1847 y_lips.to_bits() == r_lips.to_bits(),
1848 "lips finite/NaN mismatch idx {}: {} vs {}",
1849 i,
1850 y_lips,
1851 r_lips
1852 );
1853 } else {
1854 let ulp_diff: u64 = y_lips.to_bits().abs_diff(r_lips.to_bits());
1855 prop_assert!(
1856 (y_lips - r_lips).abs() <= 1e-8 || ulp_diff <= 16,
1857 "lips mismatch idx {}: {} vs {} (ULP={})",
1858 i,
1859 y_lips,
1860 r_lips,
1861 ulp_diff
1862 );
1863 }
1864 }
1865
1866 if data.len() > jaw_warmup + 10 {
1867 let segment_start = jaw_warmup;
1868 let segment_end = (jaw_warmup + 20).min(data.len());
1869
1870 let input_variance = if segment_end > segment_start + 1 {
1871 let input_segment = &data[segment_start..segment_end];
1872 let input_mean: f64 =
1873 input_segment.iter().sum::<f64>() / input_segment.len() as f64;
1874 let var: f64 = input_segment
1875 .iter()
1876 .map(|x| (x - input_mean).powi(2))
1877 .sum::<f64>()
1878 / input_segment.len() as f64;
1879 var
1880 } else {
1881 0.0
1882 };
1883
1884 let output_variance = if segment_end > segment_start + 1 {
1885 let output_segment = &out_jaw[segment_start..segment_end];
1886 let valid_outputs: Vec<f64> = output_segment
1887 .iter()
1888 .filter(|x| x.is_finite())
1889 .cloned()
1890 .collect();
1891 if valid_outputs.len() > 1 {
1892 let output_mean: f64 =
1893 valid_outputs.iter().sum::<f64>() / valid_outputs.len() as f64;
1894 let var: f64 = valid_outputs
1895 .iter()
1896 .map(|x| (x - output_mean).powi(2))
1897 .sum::<f64>()
1898 / valid_outputs.len() as f64;
1899 var
1900 } else {
1901 0.0
1902 }
1903 } else {
1904 0.0
1905 };
1906
1907 if input_variance > 1e-10 && output_variance > 1e-10 {
1908 prop_assert!(
1909 output_variance <= input_variance * 1.1,
1910 "SMMA should smooth the data: output variance {} > input variance {}",
1911 output_variance, input_variance
1912 );
1913 }
1914 }
1915
1916 if jaw_period == 1 && jaw_offset == 0 {
1917 for i in first..data.len() {
1918 prop_assert!(
1919 (out_jaw[i] - data[i]).abs() <= f64::EPSILON,
1920 "jaw with period=1, offset=0 should match input at idx {}",
1921 i
1922 );
1923 }
1924 }
1925 if teeth_period == 1 && teeth_offset == 0 {
1926 for i in first..data.len() {
1927 prop_assert!(
1928 (out_teeth[i] - data[i]).abs() <= f64::EPSILON,
1929 "teeth with period=1, offset=0 should match input at idx {}",
1930 i
1931 );
1932 }
1933 }
1934 if lips_period == 1 && lips_offset == 0 {
1935 for i in first..data.len() {
1936 prop_assert!(
1937 (out_lips[i] - data[i]).abs() <= f64::EPSILON,
1938 "lips with period=1, offset=0 should match input at idx {}",
1939 i
1940 );
1941 }
1942 }
1943
1944 if data.windows(2).all(|w| (w[0] - w[1]).abs() < f64::EPSILON)
1945 && !data.is_empty()
1946 {
1947 let constant = data[first];
1948
1949 if data.len() >= jaw_warmup + jaw_period * 5 {
1950 let check_start = data.len().saturating_sub(5);
1951 for i in check_start..data.len() {
1952 if i >= jaw_warmup && i < out_jaw.len() {
1953 prop_assert!(
1954 (out_jaw[i] - constant).abs() <= 1e-4,
1955 "jaw should converge to constant {} at idx {}, got {}",
1956 constant,
1957 i,
1958 out_jaw[i]
1959 );
1960 }
1961 }
1962 }
1963 if data.len() >= teeth_warmup + teeth_period * 5 {
1964 let check_start = data.len().saturating_sub(5);
1965 for i in check_start..data.len() {
1966 if i >= teeth_warmup && i < out_teeth.len() {
1967 prop_assert!(
1968 (out_teeth[i] - constant).abs() <= 1e-4,
1969 "teeth should converge to constant {} at idx {}, got {}",
1970 constant,
1971 i,
1972 out_teeth[i]
1973 );
1974 }
1975 }
1976 }
1977 if data.len() >= lips_warmup + lips_period * 5 {
1978 let check_start = data.len().saturating_sub(5);
1979 for i in check_start..data.len() {
1980 if i >= lips_warmup && i < out_lips.len() {
1981 prop_assert!(
1982 (out_lips[i] - constant).abs() <= 1e-4,
1983 "lips should converge to constant {} at idx {}, got {}",
1984 constant,
1985 i,
1986 out_lips[i]
1987 );
1988 }
1989 }
1990 }
1991 }
1992
1993 Ok(())
1994 },
1995 )
1996 .unwrap();
1997
1998 Ok(())
1999 }
2000
2001 #[cfg(not(feature = "proptest"))]
2002 fn check_alligator_property(
2003 _test_name: &str,
2004 _kernel: Kernel,
2005 ) -> Result<(), Box<dyn std::error::Error>> {
2006 Ok(())
2007 }
2008
2009 #[cfg(debug_assertions)]
2010 fn check_alligator_no_poison(
2011 test_name: &str,
2012 kernel: Kernel,
2013 ) -> Result<(), Box<dyn std::error::Error>> {
2014 skip_if_unsupported!(kernel, test_name);
2015
2016 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2017 let candles = read_candles_from_csv(file_path)?;
2018
2019 let test_params = vec![
2020 AlligatorParams::default(),
2021 AlligatorParams {
2022 jaw_period: Some(5),
2023 jaw_offset: Some(3),
2024 teeth_period: Some(3),
2025 teeth_offset: Some(2),
2026 lips_period: Some(2),
2027 lips_offset: Some(1),
2028 },
2029 AlligatorParams {
2030 jaw_period: Some(21),
2031 jaw_offset: Some(13),
2032 teeth_period: Some(13),
2033 teeth_offset: Some(8),
2034 lips_period: Some(8),
2035 lips_offset: Some(5),
2036 },
2037 AlligatorParams {
2038 jaw_period: Some(30),
2039 jaw_offset: Some(15),
2040 teeth_period: Some(20),
2041 teeth_offset: Some(10),
2042 lips_period: Some(10),
2043 lips_offset: Some(5),
2044 },
2045 AlligatorParams {
2046 jaw_period: Some(50),
2047 jaw_offset: Some(25),
2048 teeth_period: Some(30),
2049 teeth_offset: Some(15),
2050 lips_period: Some(20),
2051 lips_offset: Some(10),
2052 },
2053 ];
2054
2055 for (param_idx, params) in test_params.iter().enumerate() {
2056 let input = AlligatorInput::from_candles(&candles, "hl2", params.clone());
2057 let output = alligator_with_kernel(&input, kernel)?;
2058
2059 for (i, &val) in output.jaw.iter().enumerate() {
2060 if val.is_nan() {
2061 continue;
2062 }
2063
2064 let bits = val.to_bits();
2065
2066 if bits == 0x11111111_11111111 {
2067 panic!(
2068 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at jaw index {} \
2069 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2070 test_name,
2071 val,
2072 bits,
2073 i,
2074 params.jaw_period.unwrap_or(13),
2075 params.jaw_offset.unwrap_or(8),
2076 params.teeth_period.unwrap_or(8),
2077 params.teeth_offset.unwrap_or(5),
2078 params.lips_period.unwrap_or(5),
2079 params.lips_offset.unwrap_or(3),
2080 );
2081 }
2082
2083 if bits == 0x22222222_22222222 {
2084 panic!(
2085 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at jaw index {} \
2086 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2087 test_name,
2088 val,
2089 bits,
2090 i,
2091 params.jaw_period.unwrap_or(13),
2092 params.jaw_offset.unwrap_or(8),
2093 params.teeth_period.unwrap_or(8),
2094 params.teeth_offset.unwrap_or(5),
2095 params.lips_period.unwrap_or(5),
2096 params.lips_offset.unwrap_or(3),
2097 );
2098 }
2099
2100 if bits == 0x33333333_33333333 {
2101 panic!(
2102 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at jaw index {} \
2103 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2104 test_name,
2105 val,
2106 bits,
2107 i,
2108 params.jaw_period.unwrap_or(13),
2109 params.jaw_offset.unwrap_or(8),
2110 params.teeth_period.unwrap_or(8),
2111 params.teeth_offset.unwrap_or(5),
2112 params.lips_period.unwrap_or(5),
2113 params.lips_offset.unwrap_or(3),
2114 );
2115 }
2116 }
2117
2118 for (i, &val) in output.teeth.iter().enumerate() {
2119 if val.is_nan() {
2120 continue;
2121 }
2122
2123 let bits = val.to_bits();
2124
2125 if bits == 0x11111111_11111111 {
2126 panic!(
2127 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at teeth index {} \
2128 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2129 test_name,
2130 val,
2131 bits,
2132 i,
2133 params.jaw_period.unwrap_or(13),
2134 params.jaw_offset.unwrap_or(8),
2135 params.teeth_period.unwrap_or(8),
2136 params.teeth_offset.unwrap_or(5),
2137 params.lips_period.unwrap_or(5),
2138 params.lips_offset.unwrap_or(3),
2139 );
2140 }
2141
2142 if bits == 0x22222222_22222222 {
2143 panic!(
2144 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at teeth index {} \
2145 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2146 test_name,
2147 val,
2148 bits,
2149 i,
2150 params.jaw_period.unwrap_or(13),
2151 params.jaw_offset.unwrap_or(8),
2152 params.teeth_period.unwrap_or(8),
2153 params.teeth_offset.unwrap_or(5),
2154 params.lips_period.unwrap_or(5),
2155 params.lips_offset.unwrap_or(3),
2156 );
2157 }
2158
2159 if bits == 0x33333333_33333333 {
2160 panic!(
2161 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at teeth index {} \
2162 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2163 test_name,
2164 val,
2165 bits,
2166 i,
2167 params.jaw_period.unwrap_or(13),
2168 params.jaw_offset.unwrap_or(8),
2169 params.teeth_period.unwrap_or(8),
2170 params.teeth_offset.unwrap_or(5),
2171 params.lips_period.unwrap_or(5),
2172 params.lips_offset.unwrap_or(3),
2173 );
2174 }
2175 }
2176
2177 for (i, &val) in output.lips.iter().enumerate() {
2178 if val.is_nan() {
2179 continue;
2180 }
2181
2182 let bits = val.to_bits();
2183
2184 if bits == 0x11111111_11111111 {
2185 panic!(
2186 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at lips index {} \
2187 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2188 test_name,
2189 val,
2190 bits,
2191 i,
2192 params.jaw_period.unwrap_or(13),
2193 params.jaw_offset.unwrap_or(8),
2194 params.teeth_period.unwrap_or(8),
2195 params.teeth_offset.unwrap_or(5),
2196 params.lips_period.unwrap_or(5),
2197 params.lips_offset.unwrap_or(3),
2198 );
2199 }
2200
2201 if bits == 0x22222222_22222222 {
2202 panic!(
2203 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at lips index {} \
2204 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2205 test_name,
2206 val,
2207 bits,
2208 i,
2209 params.jaw_period.unwrap_or(13),
2210 params.jaw_offset.unwrap_or(8),
2211 params.teeth_period.unwrap_or(8),
2212 params.teeth_offset.unwrap_or(5),
2213 params.lips_period.unwrap_or(5),
2214 params.lips_offset.unwrap_or(3),
2215 );
2216 }
2217
2218 if bits == 0x33333333_33333333 {
2219 panic!(
2220 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at lips index {} \
2221 with params: jaw_period={}, jaw_offset={}, teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2222 test_name,
2223 val,
2224 bits,
2225 i,
2226 params.jaw_period.unwrap_or(13),
2227 params.jaw_offset.unwrap_or(8),
2228 params.teeth_period.unwrap_or(8),
2229 params.teeth_offset.unwrap_or(5),
2230 params.lips_period.unwrap_or(5),
2231 params.lips_offset.unwrap_or(3),
2232 );
2233 }
2234 }
2235 }
2236
2237 Ok(())
2238 }
2239
2240 #[cfg(not(debug_assertions))]
2241 fn check_alligator_no_poison(
2242 _test_name: &str,
2243 _kernel: Kernel,
2244 ) -> Result<(), Box<dyn std::error::Error>> {
2245 Ok(())
2246 }
2247 macro_rules! generate_all_alligator_tests {
2248 ($($test_fn:ident),*) => {
2249 paste! {
2250 $(
2251 #[test]
2252 fn [<$test_fn _scalar_f64>]() {
2253 let _ = $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar);
2254 }
2255 )*
2256 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2257 $(
2258 #[test]
2259 fn [<$test_fn _avx2_f64>]() {
2260 let _ = $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2);
2261 }
2262 #[test]
2263 fn [<$test_fn _avx512_f64>]() {
2264 let _ = $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512);
2265 }
2266 )*
2267 }
2268 }
2269 }
    // Instantiate the per-kernel #[test] wrappers for each check function listed here.
    // Note: the macro's pattern `$($test_fn:ident),*` does not accept a trailing comma.
    generate_all_alligator_tests!(
        check_alligator_partial_params,
        check_alligator_accuracy,
        check_alligator_default_candles,
        check_alligator_with_slice_data_reinput,
        check_alligator_nan_handling,
        check_alligator_zero_jaw_period,
        check_alligator_property,
        check_alligator_no_poison
    );
2280 fn check_batch_default_row(
2281 test: &str,
2282 kernel: Kernel,
2283 ) -> Result<(), Box<dyn std::error::Error>> {
2284 skip_if_unsupported!(kernel, test);
2285 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2286 let c = read_candles_from_csv(file)?;
2287 let output = AlligatorBatchBuilder::new()
2288 .kernel(kernel)
2289 .apply_candles(&c, "hl2")?;
2290 let def = AlligatorParams::default();
2291 let (row_jaw, row_teeth, row_lips) = output.values_for(&def).expect("default row missing");
2292 assert_eq!(row_jaw.len(), c.close.len());
2293 let expected = [60742.4, 60632.6, 60555.1, 60442.7, 60308.7];
2294 let start = row_jaw.len() - 5;
2295 for (i, &v) in row_jaw[start..].iter().enumerate() {
2296 assert!(
2297 (v - expected[i]).abs() < 1e-1,
2298 "[{test}] default-row mismatch at idx {i}: {v} vs {expected:?}"
2299 );
2300 }
2301 Ok(())
2302 }
2303 macro_rules! gen_batch_tests {
2304 ($fn_name:ident) => {
2305 paste! {
2306 #[test] fn [<$fn_name _scalar>]() {
2307 let _ = $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch);
2308 }
2309 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2310 #[test] fn [<$fn_name _avx2>]() {
2311 let _ = $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch);
2312 }
2313 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
2314 #[test] fn [<$fn_name _avx512>]() {
2315 let _ = $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch);
2316 }
2317 #[test] fn [<$fn_name _auto_detect>]() {
2318 let _ = $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto);
2319 }
2320 }
2321 };
2322 }
2323 #[cfg(debug_assertions)]
2324 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn std::error::Error>> {
2325 skip_if_unsupported!(kernel, test);
2326
2327 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2328 let c = read_candles_from_csv(file)?;
2329
2330 let test_configs = vec![
2331 (5, 15, 2, 3, 10, 2, 3, 10, 2, 2, 8, 2, 2, 8, 2, 1, 5, 2),
2332 (10, 20, 5, 5, 10, 5, 8, 15, 5, 3, 8, 5, 3, 8, 3, 1, 5, 2),
2333 (13, 13, 0, 8, 8, 0, 8, 8, 0, 5, 5, 0, 5, 5, 0, 3, 3, 0),
2334 (
2335 20, 30, 10, 10, 15, 5, 15, 20, 5, 8, 10, 2, 10, 12, 2, 5, 6, 1,
2336 ),
2337 ];
2338
2339 for (
2340 cfg_idx,
2341 &(
2342 jp_start,
2343 jp_end,
2344 jp_step,
2345 jo_start,
2346 jo_end,
2347 jo_step,
2348 tp_start,
2349 tp_end,
2350 tp_step,
2351 to_start,
2352 to_end,
2353 to_step,
2354 lp_start,
2355 lp_end,
2356 lp_step,
2357 lo_start,
2358 lo_end,
2359 lo_step,
2360 ),
2361 ) in test_configs.iter().enumerate()
2362 {
2363 let output = AlligatorBatchBuilder::new()
2364 .kernel(kernel)
2365 .jaw_period_range(jp_start, jp_end, jp_step)
2366 .jaw_offset_range(jo_start, jo_end, jo_step)
2367 .teeth_period_range(tp_start, tp_end, tp_step)
2368 .teeth_offset_range(to_start, to_end, to_step)
2369 .lips_period_range(lp_start, lp_end, lp_step)
2370 .lips_offset_range(lo_start, lo_end, lo_step)
2371 .apply_candles(&c, "hl2")?;
2372
2373 for (idx, &val) in output.jaw.iter().enumerate() {
2374 if val.is_nan() {
2375 continue;
2376 }
2377
2378 let bits = val.to_bits();
2379 let row = idx / output.cols;
2380 let col = idx % output.cols;
2381 let combo = &output.combos[row];
2382
2383 if bits == 0x11111111_11111111 {
2384 panic!(
2385 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2386 at jaw row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2387 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2388 test,
2389 cfg_idx,
2390 val,
2391 bits,
2392 row,
2393 col,
2394 idx,
2395 combo.jaw_period.unwrap_or(13),
2396 combo.jaw_offset.unwrap_or(8),
2397 combo.teeth_period.unwrap_or(8),
2398 combo.teeth_offset.unwrap_or(5),
2399 combo.lips_period.unwrap_or(5),
2400 combo.lips_offset.unwrap_or(3),
2401 );
2402 }
2403
2404 if bits == 0x22222222_22222222 {
2405 panic!(
2406 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2407 at jaw row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2408 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2409 test,
2410 cfg_idx,
2411 val,
2412 bits,
2413 row,
2414 col,
2415 idx,
2416 combo.jaw_period.unwrap_or(13),
2417 combo.jaw_offset.unwrap_or(8),
2418 combo.teeth_period.unwrap_or(8),
2419 combo.teeth_offset.unwrap_or(5),
2420 combo.lips_period.unwrap_or(5),
2421 combo.lips_offset.unwrap_or(3),
2422 );
2423 }
2424
2425 if bits == 0x33333333_33333333 {
2426 panic!(
2427 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2428 at jaw row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2429 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2430 test,
2431 cfg_idx,
2432 val,
2433 bits,
2434 row,
2435 col,
2436 idx,
2437 combo.jaw_period.unwrap_or(13),
2438 combo.jaw_offset.unwrap_or(8),
2439 combo.teeth_period.unwrap_or(8),
2440 combo.teeth_offset.unwrap_or(5),
2441 combo.lips_period.unwrap_or(5),
2442 combo.lips_offset.unwrap_or(3),
2443 );
2444 }
2445 }
2446
2447 for (idx, &val) in output.teeth.iter().enumerate() {
2448 if val.is_nan() {
2449 continue;
2450 }
2451
2452 let bits = val.to_bits();
2453 let row = idx / output.cols;
2454 let col = idx % output.cols;
2455 let combo = &output.combos[row];
2456
2457 if bits == 0x11111111_11111111 {
2458 panic!(
2459 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2460 at teeth row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2461 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2462 test,
2463 cfg_idx,
2464 val,
2465 bits,
2466 row,
2467 col,
2468 idx,
2469 combo.jaw_period.unwrap_or(13),
2470 combo.jaw_offset.unwrap_or(8),
2471 combo.teeth_period.unwrap_or(8),
2472 combo.teeth_offset.unwrap_or(5),
2473 combo.lips_period.unwrap_or(5),
2474 combo.lips_offset.unwrap_or(3),
2475 );
2476 }
2477
2478 if bits == 0x22222222_22222222 {
2479 panic!(
2480 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2481 at teeth row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2482 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2483 test,
2484 cfg_idx,
2485 val,
2486 bits,
2487 row,
2488 col,
2489 idx,
2490 combo.jaw_period.unwrap_or(13),
2491 combo.jaw_offset.unwrap_or(8),
2492 combo.teeth_period.unwrap_or(8),
2493 combo.teeth_offset.unwrap_or(5),
2494 combo.lips_period.unwrap_or(5),
2495 combo.lips_offset.unwrap_or(3),
2496 );
2497 }
2498
2499 if bits == 0x33333333_33333333 {
2500 panic!(
2501 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2502 at teeth row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2503 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2504 test,
2505 cfg_idx,
2506 val,
2507 bits,
2508 row,
2509 col,
2510 idx,
2511 combo.jaw_period.unwrap_or(13),
2512 combo.jaw_offset.unwrap_or(8),
2513 combo.teeth_period.unwrap_or(8),
2514 combo.teeth_offset.unwrap_or(5),
2515 combo.lips_period.unwrap_or(5),
2516 combo.lips_offset.unwrap_or(3),
2517 );
2518 }
2519 }
2520
2521 for (idx, &val) in output.lips.iter().enumerate() {
2522 if val.is_nan() {
2523 continue;
2524 }
2525
2526 let bits = val.to_bits();
2527 let row = idx / output.cols;
2528 let col = idx % output.cols;
2529 let combo = &output.combos[row];
2530
2531 if bits == 0x11111111_11111111 {
2532 panic!(
2533 "[{}] Config {}: Found alloc_with_nan_prefix poison value {} (0x{:016X}) \
2534 at lips row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2535 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2536 test,
2537 cfg_idx,
2538 val,
2539 bits,
2540 row,
2541 col,
2542 idx,
2543 combo.jaw_period.unwrap_or(13),
2544 combo.jaw_offset.unwrap_or(8),
2545 combo.teeth_period.unwrap_or(8),
2546 combo.teeth_offset.unwrap_or(5),
2547 combo.lips_period.unwrap_or(5),
2548 combo.lips_offset.unwrap_or(3),
2549 );
2550 }
2551
2552 if bits == 0x22222222_22222222 {
2553 panic!(
2554 "[{}] Config {}: Found init_matrix_prefixes poison value {} (0x{:016X}) \
2555 at lips row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2556 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2557 test,
2558 cfg_idx,
2559 val,
2560 bits,
2561 row,
2562 col,
2563 idx,
2564 combo.jaw_period.unwrap_or(13),
2565 combo.jaw_offset.unwrap_or(8),
2566 combo.teeth_period.unwrap_or(8),
2567 combo.teeth_offset.unwrap_or(5),
2568 combo.lips_period.unwrap_or(5),
2569 combo.lips_offset.unwrap_or(3),
2570 );
2571 }
2572
2573 if bits == 0x33333333_33333333 {
2574 panic!(
2575 "[{}] Config {}: Found make_uninit_matrix poison value {} (0x{:016X}) \
2576 at lips row {} col {} (flat index {}) with params: jaw_period={}, jaw_offset={}, \
2577 teeth_period={}, teeth_offset={}, lips_period={}, lips_offset={}",
2578 test,
2579 cfg_idx,
2580 val,
2581 bits,
2582 row,
2583 col,
2584 idx,
2585 combo.jaw_period.unwrap_or(13),
2586 combo.jaw_offset.unwrap_or(8),
2587 combo.teeth_period.unwrap_or(8),
2588 combo.teeth_offset.unwrap_or(5),
2589 combo.lips_period.unwrap_or(5),
2590 combo.lips_offset.unwrap_or(3),
2591 );
2592 }
2593 }
2594 }
2595
2596 Ok(())
2597 }
2598
2599 #[cfg(not(debug_assertions))]
2600 fn check_batch_no_poison(
2601 _test: &str,
2602 _kernel: Kernel,
2603 ) -> Result<(), Box<dyn std::error::Error>> {
2604 Ok(())
2605 }
2606
2607 gen_batch_tests!(check_batch_default_row);
2608 gen_batch_tests!(check_batch_no_poison);
2609
2610 #[test]
2611 fn test_invalid_kernel_error() {
2612 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2613 let sweep = AlligatorBatchRange {
2614 jaw_period: (5, 5, 0),
2615 jaw_offset: (1, 1, 0),
2616 teeth_period: (3, 3, 0),
2617 teeth_offset: (1, 1, 0),
2618 lips_period: (2, 2, 0),
2619 lips_offset: (1, 1, 0),
2620 };
2621
2622 let result = alligator_batch_with_kernel(&data, &sweep, Kernel::Scalar);
2623 assert!(matches!(
2624 result,
2625 Err(AlligatorError::InvalidKernelForBatch(Kernel::Scalar))
2626 ));
2627
2628 let result = alligator_batch_with_kernel(&data, &sweep, Kernel::Avx2);
2629 assert!(matches!(
2630 result,
2631 Err(AlligatorError::InvalidKernelForBatch(Kernel::Avx2))
2632 ));
2633
2634 let result = alligator_batch_with_kernel(&data, &sweep, Kernel::ScalarBatch);
2635 assert!(result.is_ok());
2636
2637 let result = alligator_batch_with_kernel(&data, &sweep, Kernel::Auto);
2638 assert!(result.is_ok());
2639 }
2640}
2641
// Python entry point: computes the jaw, teeth and lips lines for a 1-D float64 array and
// returns them as a dict of three numpy arrays keyed "jaw", "teeth" and "lips".
// Raises ValueError for an invalid kernel name or any AlligatorError from the computation.
// (Plain `//` comments are used so the Python-visible `__doc__` is unchanged.)
#[cfg(feature = "python")]
#[pyfunction(name = "alligator")]
#[pyo3(signature = (data, jaw_period=13, jaw_offset=8, teeth_period=8, teeth_offset=5, lips_period=5, lips_offset=3, kernel=None))]
pub fn alligator_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // `IntoPyArray` and `PyDict` come from the file-level python imports.
    let slice_in = data.as_slice()?;
    let params = AlligatorParams {
        jaw_period: Some(jaw_period),
        jaw_offset: Some(jaw_offset),
        teeth_period: Some(teeth_period),
        teeth_offset: Some(teeth_offset),
        lips_period: Some(lips_period),
        lips_offset: Some(lips_offset),
    };
    let input = AlligatorInput::from_slice(slice_in, params);
    let kern = validate_kernel(kernel, false)?;

    // The computation does not touch Python objects, so release the GIL while it runs.
    let out = py
        .allow_threads(|| alligator_with_kernel(&input, kern))
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("jaw", out.jaw.into_pyarray(py))?;
    dict.set_item("teeth", out.teeth.into_pyarray(py))?;
    dict.set_item("lips", out.lips.into_pyarray(py))?;
    Ok(dict)
}
2681
// Python class `AlligatorStream`: a thin wrapper that owns an `AlligatorStream` and
// forwards each pushed value to it.
#[cfg(feature = "python")]
#[pyclass(name = "AlligatorStream")]
pub struct AlligatorStreamPy {
    stream: AlligatorStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl AlligatorStreamPy {
    // Builds the underlying stream; parameter errors from `AlligatorStream::try_new` are
    // raised as ValueError.
    #[new]
    #[pyo3(signature = (jaw_period=13, jaw_offset=8, teeth_period=8, teeth_offset=5, lips_period=5, lips_offset=3))]
    fn new(
        jaw_period: usize,
        jaw_offset: usize,
        teeth_period: usize,
        teeth_offset: usize,
        lips_period: usize,
        lips_offset: usize,
    ) -> PyResult<Self> {
        AlligatorStream::try_new(AlligatorParams {
            jaw_period: Some(jaw_period),
            jaw_offset: Some(jaw_offset),
            teeth_period: Some(teeth_period),
            teeth_offset: Some(teeth_offset),
            lips_period: Some(lips_period),
            lips_offset: Some(lips_offset),
        })
        .map(|stream| Self { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    // Pushes one value and returns whatever `AlligatorStream::update` yields
    // (a (jaw, teeth, lips) tuple, or None).
    fn update(&mut self, value: f64) -> Option<(f64, f64, f64)> {
        let stream = &mut self.stream;
        stream.update(value)
    }
}
2718
// Python entry point for a parameter sweep: returns a dict with "jaw", "teeth" and "lips"
// as (rows, cols) float64 matrices (rows = parameter combinations, cols = input length)
// plus one 1-D u64 array per swept parameter ("jaw_periods", ..., "lips_offsets").
// Raises ValueError for an invalid sweep, a rows*cols overflow, an invalid kernel name,
// or any error from the batch computation.
#[cfg(feature = "python")]
#[pyfunction(name = "alligator_batch")]
#[pyo3(signature = (data, jaw_period_range, jaw_offset_range, teeth_period_range, teeth_offset_range, lips_period_range, lips_offset_range, kernel=None))]
pub fn alligator_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    jaw_period_range: (usize, usize, usize),
    jaw_offset_range: (usize, usize, usize),
    teeth_period_range: (usize, usize, usize),
    teeth_offset_range: (usize, usize, usize),
    lips_period_range: (usize, usize, usize),
    lips_offset_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    // `IntoPyArray` / `PyArray1` come from the file-level python imports.
    use numpy::PyArrayMethods;

    let slice_in = data.as_slice()?;
    let sweep = AlligatorBatchRange {
        jaw_period: jaw_period_range,
        jaw_offset: jaw_offset_range,
        teeth_period: teeth_period_range,
        teeth_offset: teeth_offset_range,
        lips_period: lips_period_range,
        lips_offset: lips_offset_range,
    };

    let rows = expand_grid(&sweep)
        .map_err(|e| PyValueError::new_err(e.to_string()))?
        .len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("alligator_batch_py: rows*cols overflow"))?;

    // Validate the kernel before allocating three rows*cols output buffers.
    let kern = validate_kernel(kernel, true)?;

    // The arrays start uninitialized and are handed to `alligator_batch_inner_into` to fill.
    // NOTE(review): this relies on that function writing every cell (warm-up prefixes
    // included); if it errors, the arrays are dropped without reaching Python.
    let jaw_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let teeth_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let lips_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let jaw_out = unsafe { jaw_arr.as_slice_mut()? };
    let teeth_out = unsafe { teeth_arr.as_slice_mut()? };
    let lips_out = unsafe { lips_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_k = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            alligator_batch_inner_into(
                slice_in, &sweep, batch_k, true, jaw_out, teeth_out, lips_out,
            )
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("jaw", jaw_arr.reshape((rows, cols))?)?;
    dict.set_item("teeth", teeth_arr.reshape((rows, cols))?)?;
    dict.set_item("lips", lips_arr.reshape((rows, cols))?)?;

    // Per-row parameter metadata, inserted in the same key order as before.
    let param_columns: [(&str, fn(&AlligatorParams) -> Option<usize>); 6] = [
        ("jaw_periods", |p| p.jaw_period),
        ("jaw_offsets", |p| p.jaw_offset),
        ("teeth_periods", |p| p.teeth_period),
        ("teeth_offsets", |p| p.teeth_offset),
        ("lips_periods", |p| p.lips_period),
        ("lips_offsets", |p| p.lips_offset),
    ];
    for &(key, field) in param_columns.iter() {
        let values: Vec<u64> = combos.iter().map(|p| field(p).unwrap() as u64).collect();
        dict.set_item(key, values.into_pyarray(py))?;
    }

    Ok(dict)
}
2827
2828pub fn alligator_into_slice(
2829 jaw_dst: &mut [f64],
2830 teeth_dst: &mut [f64],
2831 lips_dst: &mut [f64],
2832 input: &AlligatorInput,
2833 kern: Kernel,
2834) -> Result<(), AlligatorError> {
2835 let data: &[f64] = match &input.data {
2836 AlligatorData::Candles { candles, source } => source_type(candles, source),
2837 AlligatorData::Slice(sl) => sl,
2838 };
2839
2840 let first = data
2841 .iter()
2842 .position(|x| !x.is_nan())
2843 .ok_or(AlligatorError::AllValuesNaN)?;
2844
2845 let len = data.len();
2846
2847 if jaw_dst.len() != len {
2848 return Err(AlligatorError::OutputLengthMismatch {
2849 expected: len,
2850 got: jaw_dst.len(),
2851 });
2852 }
2853 if teeth_dst.len() != len {
2854 return Err(AlligatorError::OutputLengthMismatch {
2855 expected: len,
2856 got: teeth_dst.len(),
2857 });
2858 }
2859 if lips_dst.len() != len {
2860 return Err(AlligatorError::OutputLengthMismatch {
2861 expected: len,
2862 got: lips_dst.len(),
2863 });
2864 }
2865
2866 let jaw_period = input.get_jaw_period();
2867 let jaw_offset = input.get_jaw_offset();
2868 let teeth_period = input.get_teeth_period();
2869 let teeth_offset = input.get_teeth_offset();
2870 let lips_period = input.get_lips_period();
2871 let lips_offset = input.get_lips_offset();
2872
2873 if jaw_period == 0 || jaw_period > len {
2874 return Err(AlligatorError::InvalidJawPeriod {
2875 period: jaw_period,
2876 data_len: len,
2877 });
2878 }
2879 if jaw_offset > len {
2880 return Err(AlligatorError::InvalidJawOffset {
2881 offset: jaw_offset,
2882 data_len: len,
2883 });
2884 }
2885 if teeth_period == 0 || teeth_period > len {
2886 return Err(AlligatorError::InvalidTeethPeriod {
2887 period: teeth_period,
2888 data_len: len,
2889 });
2890 }
2891 if teeth_offset > len {
2892 return Err(AlligatorError::InvalidTeethOffset {
2893 offset: teeth_offset,
2894 data_len: len,
2895 });
2896 }
2897 if lips_period == 0 || lips_period > len {
2898 return Err(AlligatorError::InvalidLipsPeriod {
2899 period: lips_period,
2900 data_len: len,
2901 });
2902 }
2903 if lips_offset > len {
2904 return Err(AlligatorError::InvalidLipsOffset {
2905 offset: lips_offset,
2906 data_len: len,
2907 });
2908 }
2909
2910 let jaw_warmup = first + jaw_period - 1 + jaw_offset;
2911 let teeth_warmup = first + teeth_period - 1 + teeth_offset;
2912 let lips_warmup = first + lips_period - 1 + lips_offset;
2913
2914 for v in &mut jaw_dst[..jaw_warmup] {
2915 *v = f64::NAN;
2916 }
2917 for v in &mut teeth_dst[..teeth_warmup] {
2918 *v = f64::NAN;
2919 }
2920 for v in &mut lips_dst[..lips_warmup] {
2921 *v = f64::NAN;
2922 }
2923
2924 unsafe {
2925 alligator_smma_scalar(
2926 data,
2927 jaw_period,
2928 jaw_offset,
2929 teeth_period,
2930 teeth_offset,
2931 lips_period,
2932 lips_offset,
2933 first,
2934 len,
2935 jaw_dst,
2936 teeth_dst,
2937 lips_dst,
2938 );
2939 }
2940
2941 Ok(())
2942}
2943
2944#[cfg(all(feature = "python", feature = "cuda"))]
2945use crate::cuda::{cuda_available, CudaAlligator};
2946#[cfg(all(feature = "python", feature = "cuda"))]
2947use crate::utilities::dlpack_cuda::{make_device_array_py, DeviceArrayF32Py};
2948#[cfg(all(feature = "python", feature = "cuda"))]
2949use cust::context::Context;
2950#[cfg(all(feature = "python", feature = "cuda"))]
2951use std::sync::Arc;
2952
// Python-side handle that keeps a clone of the CUDA context `Arc` alive. The CUDA bindings
// in this file return it under the "context_guard" key next to the device arrays.
// Presumably this keeps the context (and so the device buffers) valid while Python holds
// the object. NOTE(review): confirm against the cuda module's ownership model.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", name = "CudaContextGuard", unsendable)]
struct CudaContextGuardPy {
    // Device ordinal the context belongs to; readable from Python.
    #[pyo3(get)]
    device_id: u32,
    // Strong reference to the context; never read, held only for its lifetime.
    _ctx: Arc<Context>,
}
2960
// Runs an Alligator parameter sweep on the GPU over an f32 price series.
//
// Each range argument is a (start, end, step) triple, as in `AlligatorBatchRange`.
// Returns a dict with:
// - "jaw" / "teeth" / "lips": the outputs wrapped via `make_device_array_py`;
// - "rows" / "cols": the output dimensions;
// - one array per swept parameter ("jaw_periods", ..., "lips_offsets"), one entry per
//   parameter combination;
// - "context_guard": a `CudaContextGuard` holding the CUDA context.
//
// Raises ValueError when CUDA is unavailable or when device setup or the batch run fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alligator_cuda_batch_dev")]
#[pyo3(signature = (prices_f32, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset, device_id=0))]
pub fn alligator_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    prices_f32: numpy::PyReadonlyArray1<'py, f32>,
    jaw_period: (usize, usize, usize),
    jaw_offset: (usize, usize, usize),
    teeth_period: (usize, usize, usize),
    teeth_offset: (usize, usize, usize),
    lips_period: (usize, usize, usize),
    lips_offset: (usize, usize, usize),
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = prices_f32.as_slice()?;
    let sweep = AlligatorBatchRange {
        jaw_period,
        jaw_offset,
        teeth_period,
        teeth_offset,
        lips_period,
        lips_offset,
    };
    // GPU work runs with the GIL released; errors are mapped to ValueError inside the closure.
    let (jaw, teeth, lips, rows, cols, jp, jo, tp, to, lp, lo, guard_dev, guard_ctx) = py
        .allow_threads(|| {
            let cuda =
                CudaAlligator::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
            let res = cuda
                .alligator_batch_dev(slice, &sweep)
                .map_err(|e| PyValueError::new_err(e.to_string()))?;
            let rows = res.outputs.rows();
            let cols = res.outputs.cols();
            // Read the per-combination parameters before the output buffers are moved out of `res`.
            let jp: Vec<usize> = res.combos.iter().map(|c| c.jaw_period.unwrap()).collect();
            let jo: Vec<usize> = res.combos.iter().map(|c| c.jaw_offset.unwrap()).collect();
            let tp: Vec<usize> = res.combos.iter().map(|c| c.teeth_period.unwrap()).collect();
            let to: Vec<usize> = res.combos.iter().map(|c| c.teeth_offset.unwrap()).collect();
            let lp: Vec<usize> = res.combos.iter().map(|c| c.lips_period.unwrap()).collect();
            let lo: Vec<usize> = res.combos.iter().map(|c| c.lips_offset.unwrap()).collect();
            Ok::<_, PyErr>((
                res.outputs.jaw,
                res.outputs.teeth,
                res.outputs.lips,
                rows,
                cols,
                jp,
                jo,
                tp,
                to,
                lp,
                lo,
                res.outputs.device_id,
                // Clone the context Arc so the returned guard keeps it alive after `cuda` drops.
                res.outputs._ctx.clone(),
            ))
        })?;
    use numpy::IntoPyArray;
    let d = PyDict::new(py);
    let jaw_py = make_device_array_py(guard_dev as usize, jaw)?;
    let teeth_py = make_device_array_py(guard_dev as usize, teeth)?;
    let lips_py = make_device_array_py(guard_dev as usize, lips)?;
    d.set_item("jaw", Py::new(py, jaw_py)?)?;
    d.set_item("teeth", Py::new(py, teeth_py)?)?;
    d.set_item("lips", Py::new(py, lips_py)?)?;
    d.set_item("rows", rows)?;
    d.set_item("cols", cols)?;
    d.set_item("jaw_periods", jp.into_pyarray(py))?;
    d.set_item("jaw_offsets", jo.into_pyarray(py))?;
    d.set_item("teeth_periods", tp.into_pyarray(py))?;
    d.set_item("teeth_offsets", to.into_pyarray(py))?;
    d.set_item("lips_periods", lp.into_pyarray(py))?;
    d.set_item("lips_offsets", lo.into_pyarray(py))?;
    d.set_item(
        "context_guard",
        Py::new(
            py,
            CudaContextGuardPy {
                device_id: guard_dev,
                _ctx: guard_ctx,
            },
        )?,
    )?;
    Ok(d)
}
3046
// Runs one Alligator parameter set on the GPU over many series stored in a 2-D f32 array.
// Per the function name the layout is time-major. NOTE(review): this presumably means
// axis 0 = time and axis 1 = series, matching the (cols, rows) arguments passed below;
// confirm against CudaAlligator.
// Returns a dict with the jaw/teeth/lips outputs wrapped via `make_device_array_py`,
// "rows", "cols", and a "context_guard" holding the CUDA context.
// Raises ValueError when CUDA is unavailable or when device setup or the run fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "alligator_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, jaw_period, jaw_offset, teeth_period, teeth_offset, lips_period, lips_offset, device_id=0))]
pub fn alligator_cuda_many_series_one_param_dev_py<'py>(
    py: Python<'py>,
    data_tm_f32: numpy::PyReadonlyArray2<'py, f32>,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
    device_id: usize,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::PyUntypedArrayMethods;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let shape = data_tm_f32.shape();
    if shape.len() != 2 {
        return Err(PyValueError::new_err("expected 2D array"));
    }
    let rows = shape[0];
    let cols = shape[1];
    let flat = data_tm_f32.as_slice()?;
    let params = AlligatorParams {
        jaw_period: Some(jaw_period),
        jaw_offset: Some(jaw_offset),
        teeth_period: Some(teeth_period),
        teeth_offset: Some(teeth_offset),
        lips_period: Some(lips_period),
        lips_offset: Some(lips_offset),
    };
    // GPU work runs with the GIL released; errors are mapped to ValueError inside the closure.
    let (jaw, teeth, lips, guard_dev, guard_ctx) = py.allow_threads(|| {
        let cuda =
            CudaAlligator::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let out = cuda
            .alligator_many_series_one_param_time_major_dev(flat, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((
            out.jaw,
            out.teeth,
            out.lips,
            cuda.device_id(),
            // Keep the context alive past `cuda`'s drop via the returned guard.
            cuda.context_arc(),
        ))
    })?;
    let d = PyDict::new(py);
    let jaw_py = make_device_array_py(guard_dev as usize, jaw)?;
    let teeth_py = make_device_array_py(guard_dev as usize, teeth)?;
    let lips_py = make_device_array_py(guard_dev as usize, lips)?;
    d.set_item("jaw", Py::new(py, jaw_py)?)?;
    d.set_item("teeth", Py::new(py, teeth_py)?)?;
    d.set_item("lips", Py::new(py, lips_py)?)?;
    d.set_item("rows", rows)?;
    d.set_item("cols", cols)?;
    d.set_item(
        "context_guard",
        Py::new(
            py,
            CudaContextGuardPy {
                device_id: guard_dev,
                _ctx: guard_ctx,
            },
        )?,
    )?;
    Ok(d)
}
3115
3116#[inline]
3117pub fn alligator_into_slices(
3118 jaw_out: &mut [f64],
3119 teeth_out: &mut [f64],
3120 lips_out: &mut [f64],
3121 input: &AlligatorInput,
3122 kern: Kernel,
3123) -> Result<(), AlligatorError> {
3124 let data: &[f64] = input.as_ref();
3125 let len = data.len();
3126 if len == 0 {
3127 return Err(AlligatorError::EmptyInputData);
3128 }
3129 if jaw_out.len() != len {
3130 return Err(AlligatorError::OutputLengthMismatch {
3131 expected: len,
3132 got: jaw_out.len(),
3133 });
3134 }
3135 if teeth_out.len() != len {
3136 return Err(AlligatorError::OutputLengthMismatch {
3137 expected: len,
3138 got: teeth_out.len(),
3139 });
3140 }
3141 if lips_out.len() != len {
3142 return Err(AlligatorError::OutputLengthMismatch {
3143 expected: len,
3144 got: lips_out.len(),
3145 });
3146 }
3147
3148 let first = data
3149 .iter()
3150 .position(|x| !x.is_nan())
3151 .ok_or(AlligatorError::AllValuesNaN)?;
3152 let jp = input.get_jaw_period();
3153 let jo = input.get_jaw_offset();
3154 let tp = input.get_teeth_period();
3155 let to = input.get_teeth_offset();
3156 let lp = input.get_lips_period();
3157 let lo = input.get_lips_offset();
3158
3159 if jp == 0 || jp > len {
3160 return Err(AlligatorError::InvalidJawPeriod {
3161 period: jp,
3162 data_len: len,
3163 });
3164 }
3165 if tp == 0 || tp > len {
3166 return Err(AlligatorError::InvalidTeethPeriod {
3167 period: tp,
3168 data_len: len,
3169 });
3170 }
3171 if lp == 0 || lp > len {
3172 return Err(AlligatorError::InvalidLipsPeriod {
3173 period: lp,
3174 data_len: len,
3175 });
3176 }
3177 if jo > len {
3178 return Err(AlligatorError::InvalidJawOffset {
3179 offset: jo,
3180 data_len: len,
3181 });
3182 }
3183 if to > len {
3184 return Err(AlligatorError::InvalidTeethOffset {
3185 offset: to,
3186 data_len: len,
3187 });
3188 }
3189 if lo > len {
3190 return Err(AlligatorError::InvalidLipsOffset {
3191 offset: lo,
3192 data_len: len,
3193 });
3194 }
3195
3196 let jw = first + jp - 1 + jo;
3197 let tw = first + tp - 1 + to;
3198 let lw = first + lp - 1 + lo;
3199 for v in &mut jaw_out[..jw.min(len)] {
3200 *v = f64::NAN;
3201 }
3202 for v in &mut teeth_out[..tw.min(len)] {
3203 *v = f64::NAN;
3204 }
3205 for v in &mut lips_out[..lw.min(len)] {
3206 *v = f64::NAN;
3207 }
3208
3209 let chosen = match kern {
3210 Kernel::Auto => Kernel::Scalar,
3211 k => k,
3212 };
3213
3214 unsafe {
3215 match chosen {
3216 Kernel::Scalar | Kernel::ScalarBatch => {
3217 let _ = alligator_smma_scalar(
3218 data, jp, jo, tp, to, lp, lo, first, len, jaw_out, teeth_out, lips_out,
3219 );
3220 }
3221 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3222 Kernel::Avx2 | Kernel::Avx2Batch => {
3223 let _ = alligator_row_avx2(
3224 data, first, jp, jo, tp, to, lp, lo, len, jaw_out, teeth_out, lips_out,
3225 );
3226 }
3227 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
3228 Kernel::Avx512 | Kernel::Avx512Batch => {
3229 let _ = alligator_row_avx512(
3230 data, first, jp, jo, tp, to, lp, lo, len, jaw_out, teeth_out, lips_out,
3231 );
3232 }
3233 _ => unreachable!(),
3234 }
3235 }
3236 Ok(())
3237}
3238
3239#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
3240#[inline]
3241pub fn alligator_into(
3242 input: &AlligatorInput,
3243 jaw_out: &mut [f64],
3244 teeth_out: &mut [f64],
3245 lips_out: &mut [f64],
3246) -> Result<(), AlligatorError> {
3247 alligator_into_slices(jaw_out, teeth_out, lips_out, input, Kernel::Auto)
3248}
3249
// JS entry point: computes the three Alligator lines and returns them packed into one flat
// array laid out as [jaw..., teeth..., lips...]. Errors are returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_js(
    data: &[f64],
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = AlligatorInput::from_slice(
        data,
        AlligatorParams {
            jaw_period: Some(jaw_period),
            jaw_offset: Some(jaw_offset),
            teeth_period: Some(teeth_period),
            teeth_offset: Some(teeth_offset),
            lips_period: Some(lips_period),
            lips_offset: Some(lips_offset),
        },
    );
    let lines = match alligator_with_kernel(&input, Kernel::Auto) {
        Ok(lines) => lines,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let capacity = match data.len().checked_mul(3) {
        Some(n) => n,
        None => return Err(JsValue::from_str("alligator_js: data length overflow")),
    };
    let mut packed = Vec::with_capacity(capacity);
    for line in [&lines.jaw[..], &lines.teeth[..], &lines.lips[..]].iter() {
        packed.extend_from_slice(line);
    }
    Ok(packed)
}
3283
// Allocates an uninitialized buffer with room for `len` f64 values and returns its pointer,
// transferring ownership to the caller. Release it with `alligator_free(ptr, len)` using
// the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_alloc(len: usize) -> *mut f64 {
    // ManuallyDrop leaks the Vec, exactly like `mem::forget`, so the allocation outlives this call.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
3292
// Releases a buffer obtained from `alligator_alloc(len)`; `len` must be the value passed to
// `alligator_alloc`. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `alligator_alloc(len)`, i.e. `Vec::<f64>::with_capacity(len)`,
    // which allocates exactly `len` capacity. The length is 0 because the caller may never
    // have written the elements, and `from_raw_parts` requires the first `length` elements
    // to be initialized. `f64` has no drop glue, so only the allocation itself is released.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
3302
// JS entry point: reads `len` f64 values at `in_ptr` and writes the jaw, teeth and lips lines
// (each `len` values) to the three output pointers, using `Kernel::Auto`.
//
// NOTE(review): the pointers are presumably buffers from `alligator_alloc`; nothing here
// enforces that.
//
// Aliasing: if the input pointer equals an output pointer, or two output pointers are equal,
// results are computed into temporary vectors and then copied out one buffer at a time
// through raw pointers, so no two `&mut` slices to the same memory are ever alive together.
// Only exact pointer equality is detected; partially overlapping buffers are not supported.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_into(
    in_ptr: *const f64,
    jaw_ptr: *mut f64,
    teeth_ptr: *mut f64,
    lips_ptr: *mut f64,
    len: usize,
    jaw_period: usize,
    jaw_offset: usize,
    teeth_period: usize,
    teeth_offset: usize,
    lips_period: usize,
    lips_offset: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || jaw_ptr.is_null() || teeth_ptr.is_null() || lips_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let params = AlligatorParams {
            jaw_period: Some(jaw_period),
            jaw_offset: Some(jaw_offset),
            teeth_period: Some(teeth_period),
            teeth_offset: Some(teeth_offset),
            lips_period: Some(lips_period),
            lips_offset: Some(lips_offset),
        };
        let input = AlligatorInput::from_slice(data, params);

        let aliased = in_ptr == jaw_ptr as *const f64
            || in_ptr == teeth_ptr as *const f64
            || in_ptr == lips_ptr as *const f64
            || jaw_ptr == teeth_ptr
            || jaw_ptr == lips_ptr
            || teeth_ptr == lips_ptr;

        if aliased {
            let mut temp_jaw = vec![0.0; len];
            let mut temp_teeth = vec![0.0; len];
            let mut temp_lips = vec![0.0; len];

            alligator_into_slices(
                &mut temp_jaw,
                &mut temp_teeth,
                &mut temp_lips,
                &input,
                Kernel::Auto,
            )
            .map_err(|e| JsValue::from_str(&e.to_string()))?;

            // Output pointers may coincide, so copy each result separately through raw
            // pointers instead of materialising several `&mut` slices at once. The temporaries
            // are fresh heap allocations and cannot overlap the caller's buffers.
            std::ptr::copy_nonoverlapping(temp_jaw.as_ptr(), jaw_ptr, len);
            std::ptr::copy_nonoverlapping(temp_teeth.as_ptr(), teeth_ptr, len);
            std::ptr::copy_nonoverlapping(temp_lips.as_ptr(), lips_ptr, len);
        } else {
            // All four buffers are pairwise distinct pointers, so the slices do not alias
            // (assuming the caller's buffers do not partially overlap).
            let jaw_out = std::slice::from_raw_parts_mut(jaw_ptr, len);
            let teeth_out = std::slice::from_raw_parts_mut(teeth_ptr, len);
            let lips_out = std::slice::from_raw_parts_mut(lips_ptr, len);

            alligator_into_slices(jaw_out, teeth_out, lips_out, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
3374
// JS entry point for a parameter sweep (scalar batch kernel). Each range is given as
// start/end/step scalars. Returns one flat array laid out as [all jaw rows..., all teeth
// rows..., all lips rows...]; errors are returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_batch_js(
    data: &[f64],
    jaw_period_start: usize,
    jaw_period_end: usize,
    jaw_period_step: usize,
    jaw_offset_start: usize,
    jaw_offset_end: usize,
    jaw_offset_step: usize,
    teeth_period_start: usize,
    teeth_period_end: usize,
    teeth_period_step: usize,
    teeth_offset_start: usize,
    teeth_offset_end: usize,
    teeth_offset_step: usize,
    lips_period_start: usize,
    lips_period_end: usize,
    lips_period_step: usize,
    lips_offset_start: usize,
    lips_offset_end: usize,
    lips_offset_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AlligatorBatchRange {
        jaw_period: (jaw_period_start, jaw_period_end, jaw_period_step),
        jaw_offset: (jaw_offset_start, jaw_offset_end, jaw_offset_step),
        teeth_period: (teeth_period_start, teeth_period_end, teeth_period_step),
        teeth_offset: (teeth_offset_start, teeth_offset_end, teeth_offset_step),
        lips_period: (lips_period_start, lips_period_end, lips_period_step),
        lips_offset: (lips_offset_start, lips_offset_end, lips_offset_step),
    };

    // `concat` allocates exactly the combined length once and copies jaw, teeth, lips in order.
    alligator_batch_inner(data, &sweep, Kernel::ScalarBatch, false)
        .map(|output| [&output.jaw[..], &output.teeth[..], &output.lips[..]].concat())
        .map_err(|e| JsValue::from_str(&e.to_string()))
}
3418
// JS helper: expands the same parameter grid as `alligator_batch_js` and returns it flattened,
// six values per combination in the order jaw_period, jaw_offset, teeth_period,
// teeth_offset, lips_period, lips_offset. Errors are returned as JS string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_batch_metadata_js(
    jaw_period_start: usize,
    jaw_period_end: usize,
    jaw_period_step: usize,
    jaw_offset_start: usize,
    jaw_offset_end: usize,
    jaw_offset_step: usize,
    teeth_period_start: usize,
    teeth_period_end: usize,
    teeth_period_step: usize,
    teeth_offset_start: usize,
    teeth_offset_end: usize,
    teeth_offset_step: usize,
    lips_period_start: usize,
    lips_period_end: usize,
    lips_period_step: usize,
    lips_offset_start: usize,
    lips_offset_end: usize,
    lips_offset_step: usize,
) -> Result<Vec<f64>, JsValue> {
    let sweep = AlligatorBatchRange {
        jaw_period: (jaw_period_start, jaw_period_end, jaw_period_step),
        jaw_offset: (jaw_offset_start, jaw_offset_end, jaw_offset_step),
        teeth_period: (teeth_period_start, teeth_period_end, teeth_period_step),
        teeth_offset: (teeth_offset_start, teeth_offset_end, teeth_offset_step),
        lips_period: (lips_period_start, lips_period_end, lips_period_step),
        lips_offset: (lips_offset_start, lips_offset_end, lips_offset_step),
    };

    let grid = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let flattened: Vec<f64> = grid
        .iter()
        .flat_map(|p| {
            [
                p.jaw_period.unwrap() as f64,
                p.jaw_offset.unwrap() as f64,
                p.teeth_period.unwrap() as f64,
                p.teeth_offset.unwrap() as f64,
                p.lips_period.unwrap() as f64,
                p.lips_offset.unwrap() as f64,
            ]
        })
        .collect();

    Ok(flattened)
}
3464
/// Serializable result of the unified JS batch entry point (`alligator_batch`).
///
/// As filled by `alligator_batch_unified_js`, each of `jaw`, `teeth` and `lips` holds
/// `rows * cols` values. The layout is presumably row-major, one parameter combination
/// per row — TODO confirm against `alligator_batch_inner_into`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlligatorBatchJsOutput {
    /// Jaw line values for every combination.
    pub jaw: Vec<f64>,
    /// Teeth line values for every combination.
    pub teeth: Vec<f64>,
    /// Lips line values for every combination.
    pub lips: Vec<f64>,
    /// Parameter combinations reported by the batch computation (presumably one per row).
    pub combos: Vec<AlligatorParams>,
    /// Number of parameter combinations produced by the sweep.
    pub rows: usize,
    /// Input length, i.e. number of values per combination.
    pub cols: usize,
}
3475
/// Sweep configuration accepted by the unified JS batch entry point (`alligator_batch`).
///
/// Every field is a `(start, end, step)` range for one Alligator parameter and is copied
/// straight into the matching axis of `AlligatorBatchRange`. The struct is deserialized
/// from a JS value with `serde_wasm_bindgen`. Each tuple presumably arrives as a
/// 3-element JS array (confirm with serde_wasm_bindgen's tuple mapping).
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct AlligatorBatchConfig {
    /// `(start, end, step)` for the jaw period.
    pub jaw_period_range: (usize, usize, usize),
    /// `(start, end, step)` for the jaw offset.
    pub jaw_offset_range: (usize, usize, usize),
    /// `(start, end, step)` for the teeth period.
    pub teeth_period_range: (usize, usize, usize),
    /// `(start, end, step)` for the teeth offset.
    pub teeth_offset_range: (usize, usize, usize),
    /// `(start, end, step)` for the lips period.
    pub lips_period_range: (usize, usize, usize),
    /// `(start, end, step)` for the lips offset.
    pub lips_offset_range: (usize, usize, usize),
}
3486
/// Unified batch entry point, exported to JS as `alligator_batch`.
///
/// `config` is deserialized into an `AlligatorBatchConfig`. Each of its six
/// `(start, end, step)` ranges becomes the matching axis of the parameter sweep.
/// The sweep is evaluated with the scalar batch kernel into three freshly allocated,
/// NaN-initialised buffers of `rows * cols` values each. Here `rows` is the number of
/// combinations and `cols` is `data.len()`. The buffers, the combinations and both
/// dimensions are serialized back as an `AlligatorBatchJsOutput`.
///
/// # Errors
/// Returns a string `JsValue` in any of these cases:
/// - the config cannot be deserialized;
/// - the sweep is rejected;
/// - `rows * cols` overflows;
/// - the batch computation fails;
/// - the result cannot be serialized.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = alligator_batch)]
pub fn alligator_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    // Converts any of the crate's printable errors into a JS string value.
    fn js_err<E: ToString>(err: E) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    let AlligatorBatchConfig {
        jaw_period_range,
        jaw_offset_range,
        teeth_period_range,
        teeth_offset_range,
        lips_period_range,
        lips_offset_range,
    } = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = AlligatorBatchRange {
        jaw_period: jaw_period_range,
        jaw_offset: jaw_offset_range,
        teeth_period: teeth_period_range,
        teeth_offset: teeth_offset_range,
        lips_period: lips_period_range,
        lips_offset: lips_offset_range,
    };

    // Only the combination count is needed to size the buffers; the batch call
    // below reports the combinations themselves.
    let row_count = expand_grid(&sweep).map_err(js_err)?.len();
    let col_count = data.len();
    let cells = match row_count.checked_mul(col_count) {
        Some(n) => n,
        None => {
            return Err(JsValue::from_str(
                "alligator_batch_unified_js: rows*cols overflow",
            ))
        }
    };

    let (mut jaw, mut teeth, mut lips) = (
        vec![f64::NAN; cells],
        vec![f64::NAN; cells],
        vec![f64::NAN; cells],
    );

    let combos = alligator_batch_inner_into(
        data,
        &sweep,
        Kernel::ScalarBatch,
        false,
        &mut jaw,
        &mut teeth,
        &mut lips,
    )
    .map_err(js_err)?;

    let payload = AlligatorBatchJsOutput {
        jaw,
        teeth,
        lips,
        combos,
        rows: row_count,
        cols: col_count,
    };
    serde_wasm_bindgen::to_value(&payload).map_err(|e| JsValue::from_str(&format!("serde: {e}")))
}
3533
/// Batch entry point for JS callers that manage WASM linear memory themselves.
///
/// Reads `len` samples from `in_ptr`. It evaluates every combination of the
/// `(start, end, step)` ranges with the scalar batch kernel:
/// - jaw period: `jp_*`
/// - jaw offset: `jo_*`
/// - teeth period: `tp_*`
/// - teeth offset: `to_*`
/// - lips period: `lp_*`
/// - lips offset: `lo_*`
///
/// The jaw, teeth and lips results are written through the three output pointers.
/// Each output must have room for `rows * len` `f64`s, where `rows` is the number of
/// combinations. `rows` is returned on success.
///
/// The input and output buffers are allowed to overlap, e.g. when a caller reuses one
/// allocation. In that case the results are first computed into temporary buffers and
/// then copied out, so Rust never holds overlapping `&`/`&mut` slices. If outputs
/// overlap each other, later copies (teeth, then lips) overwrite earlier ones.
///
/// # Errors
/// Returns a string `JsValue` in any of these cases:
/// - any pointer is null;
/// - the sweep is rejected;
/// - the buffer sizes overflow;
/// - the batch computation fails.
///
/// # Safety (caller contract)
/// This function is exported as safe for `wasm_bindgen`, but the caller must still ensure:
/// - `in_ptr` is valid for `len` aligned `f64` reads;
/// - each output pointer is valid for `rows * len` aligned `f64` writes.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn alligator_batch_into(
    in_ptr: *const f64,
    jaw_out_ptr: *mut f64,
    teeth_out_ptr: *mut f64,
    lips_out_ptr: *mut f64,
    len: usize,

    jp_s: usize,
    jp_e: usize,
    jp_step: usize,
    jo_s: usize,
    jo_e: usize,
    jo_step: usize,
    tp_s: usize,
    tp_e: usize,
    tp_step: usize,
    to_s: usize,
    to_e: usize,
    to_step: usize,
    lp_s: usize,
    lp_e: usize,
    lp_step: usize,
    lo_s: usize,
    lo_e: usize,
    lo_step: usize,
) -> Result<usize, JsValue> {
    // `true` when the byte ranges covered by two `f64` buffers intersect.
    fn overlaps(a: *const f64, a_len: usize, b: *const f64, b_len: usize) -> bool {
        if a_len == 0 || b_len == 0 {
            return false;
        }
        let elem = std::mem::size_of::<f64>();
        let a_start = a as usize;
        let b_start = b as usize;
        let a_end = a_start.saturating_add(a_len.saturating_mul(elem));
        let b_end = b_start.saturating_add(b_len.saturating_mul(elem));
        a_start < b_end && b_start < a_end
    }

    if in_ptr.is_null()
        || jaw_out_ptr.is_null()
        || teeth_out_ptr.is_null()
        || lips_out_ptr.is_null()
    {
        return Err(JsValue::from_str(
            "null pointer passed to alligator_batch_into",
        ));
    }

    let sweep = AlligatorBatchRange {
        jaw_period: (jp_s, jp_e, jp_step),
        jaw_offset: (jo_s, jo_e, jo_step),
        teeth_period: (tp_s, tp_e, tp_step),
        teeth_offset: (to_s, to_e, to_step),
        lips_period: (lp_s, lp_e, lp_step),
        lips_offset: (lo_s, lo_e, lo_step),
    };
    let rows = expand_grid(&sweep)
        .map_err(|e| JsValue::from_str(&e.to_string()))?
        .len();
    let cols = len;
    // Slices may not span more than `isize::MAX` bytes, so bound both element counts.
    let max_elems = isize::MAX as usize / std::mem::size_of::<f64>();
    let total = rows
        .checked_mul(cols)
        .filter(|&n| n.max(cols) <= max_elems)
        .ok_or_else(|| JsValue::from_str("alligator_batch_into: rows*cols overflow"))?;

    // Any overlap between the input and an output, or between two outputs, would
    // make the slices created below alias, which is undefined behavior.
    let aliased = overlaps(in_ptr, len, jaw_out_ptr, total)
        || overlaps(in_ptr, len, teeth_out_ptr, total)
        || overlaps(in_ptr, len, lips_out_ptr, total)
        || overlaps(jaw_out_ptr, total, teeth_out_ptr, total)
        || overlaps(jaw_out_ptr, total, lips_out_ptr, total)
        || overlaps(teeth_out_ptr, total, lips_out_ptr, total);

    // SAFETY:
    // - `in_ptr` is non-null and, per the caller contract, valid for `len` reads.
    // - Its byte size fits in `isize` (checked above).
    // - Nothing writes to this memory while `data` is in use. The aliased path below only
    //   writes after the batch computation has finished reading.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };

    if aliased {
        let mut jaw = vec![f64::NAN; total];
        let mut teeth = vec![f64::NAN; total];
        let mut lips = vec![f64::NAN; total];
        alligator_batch_inner_into(
            data,
            &sweep,
            Kernel::ScalarBatch,
            false,
            &mut jaw,
            &mut teeth,
            &mut lips,
        )
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY:
        // - Each output pointer is non-null and, per the caller contract, valid for
        //   `total` writes.
        // - The sources are freshly allocated Vecs of length `total`, so they cannot
        //   overlap caller memory.
        // - `data` is not read after this point.
        unsafe {
            std::ptr::copy_nonoverlapping(jaw.as_ptr(), jaw_out_ptr, total);
            std::ptr::copy_nonoverlapping(teeth.as_ptr(), teeth_out_ptr, total);
            std::ptr::copy_nonoverlapping(lips.as_ptr(), lips_out_ptr, total);
        }
    } else {
        // SAFETY:
        // - Each pointer is non-null and, per the caller contract, valid for `total`
        //   writes.
        // - The byte size is bounded above.
        // - The `overlaps` checks proved the outputs are pairwise disjoint and disjoint
        //   from `data`, so these `&mut` slices never alias anything.
        let (jaw, teeth, lips) = unsafe {
            (
                std::slice::from_raw_parts_mut(jaw_out_ptr, total),
                std::slice::from_raw_parts_mut(teeth_out_ptr, total),
                std::slice::from_raw_parts_mut(lips_out_ptr, total),
            )
        };
        alligator_batch_inner_into(data, &sweep, Kernel::ScalarBatch, false, jaw, teeth, lips)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}