1#[cfg(feature = "python")]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::moving_averages::sama_wrapper::DeviceArrayF32Sama;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::{cuda_available, moving_averages::CudaSama};
21use crate::utilities::data_loader::{source_type, Candles};
22#[cfg(all(feature = "python", feature = "cuda"))]
23use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
24use crate::utilities::enums::Kernel;
25use crate::utilities::helpers::{
26 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
27 make_uninit_matrix,
28};
29#[cfg(feature = "python")]
30use crate::utilities::kernel_validation::validate_kernel;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use cust::memory::DeviceBuffer;
33
34#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
35use core::arch::x86_64::*;
36
37#[cfg(not(target_arch = "wasm32"))]
38use rayon::prelude::*;
39
40use std::convert::AsRef;
41use std::error::Error;
42use std::mem::MaybeUninit;
43use thiserror::Error;
44
45impl<'a> AsRef<[f64]> for SamaInput<'a> {
46 #[inline(always)]
47 fn as_ref(&self) -> &[f64] {
48 match &self.data {
49 SamaData::Slice(slice) => slice,
50 SamaData::Candles { candles, source } => source_type(candles, source),
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
56pub enum SamaData<'a> {
57 Candles {
58 candles: &'a Candles,
59 source: &'a str,
60 },
61 Slice(&'a [f64]),
62}
63
/// Result of a single SAMA run.
#[derive(Clone, Debug)]
pub struct SamaOutput {
    /// One output value per input sample.
    pub values: Vec<f64>,
}
68
/// Tunable SAMA parameters; `None` means "use the default" (200 / 14 / 6, see `Default`).
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
#[derive(Clone, Debug)]
pub struct SamaParams {
    /// Look-back used for the rolling highest/lowest window.
    pub length: Option<usize>,
    /// Period behind the `2 / (maj_length + 1)` smoothing constant.
    pub maj_length: Option<usize>,
    /// Period behind the `2 / (min_length + 1)` smoothing constant.
    pub min_length: Option<usize>,
}
79
80impl Default for SamaParams {
81 fn default() -> Self {
82 Self {
83 length: Some(200),
84 maj_length: Some(14),
85 min_length: Some(6),
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
91pub struct SamaInput<'a> {
92 pub data: SamaData<'a>,
93 pub params: SamaParams,
94}
95
96impl<'a> SamaInput<'a> {
97 #[inline]
98 pub fn from_candles(c: &'a Candles, s: &'a str, p: SamaParams) -> Self {
99 Self {
100 data: SamaData::Candles {
101 candles: c,
102 source: s,
103 },
104 params: p,
105 }
106 }
107
108 #[inline]
109 pub fn from_slice(sl: &'a [f64], p: SamaParams) -> Self {
110 Self {
111 data: SamaData::Slice(sl),
112 params: p,
113 }
114 }
115
116 #[inline]
117 pub fn with_default_candles(c: &'a Candles) -> Self {
118 Self::from_candles(c, "close", SamaParams::default())
119 }
120
121 #[inline]
122 pub fn get_length(&self) -> usize {
123 self.params.length.unwrap_or(200)
124 }
125
126 #[inline]
127 pub fn get_maj_length(&self) -> usize {
128 self.params.maj_length.unwrap_or(14)
129 }
130
131 #[inline]
132 pub fn get_min_length(&self) -> usize {
133 self.params.min_length.unwrap_or(6)
134 }
135}
136
137#[derive(Clone, Debug)]
138pub struct SamaBuilder {
139 length: Option<usize>,
140 maj_length: Option<usize>,
141 min_length: Option<usize>,
142 kernel: Kernel,
143}
144
145impl Default for SamaBuilder {
146 fn default() -> Self {
147 Self {
148 length: None,
149 maj_length: None,
150 min_length: None,
151 kernel: Kernel::Auto,
152 }
153 }
154}
155
156impl SamaBuilder {
157 #[inline(always)]
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 #[inline(always)]
163 pub fn length(mut self, val: usize) -> Self {
164 self.length = Some(val);
165 self
166 }
167
168 #[inline(always)]
169 pub fn maj_length(mut self, val: usize) -> Self {
170 self.maj_length = Some(val);
171 self
172 }
173
174 #[inline(always)]
175 pub fn min_length(mut self, val: usize) -> Self {
176 self.min_length = Some(val);
177 self
178 }
179
180 #[inline(always)]
181 pub fn kernel(mut self, k: Kernel) -> Self {
182 self.kernel = k;
183 self
184 }
185
186 #[inline(always)]
187 pub fn apply(self, c: &Candles) -> Result<SamaOutput, SamaError> {
188 let p = SamaParams {
189 length: self.length,
190 maj_length: self.maj_length,
191 min_length: self.min_length,
192 };
193 let i = SamaInput::from_candles(c, "close", p);
194 sama_with_kernel(&i, self.kernel)
195 }
196
197 #[inline(always)]
198 pub fn apply_slice(self, d: &[f64]) -> Result<SamaOutput, SamaError> {
199 let p = SamaParams {
200 length: self.length,
201 maj_length: self.maj_length,
202 min_length: self.min_length,
203 };
204 let i = SamaInput::from_slice(d, p);
205 sama_with_kernel(&i, self.kernel)
206 }
207
208 #[inline(always)]
209 pub fn into_stream(self) -> Result<SamaStream, SamaError> {
210 let p = SamaParams {
211 length: self.length,
212 maj_length: self.maj_length,
213 min_length: self.min_length,
214 };
215 SamaStream::try_new(p)
216 }
217}
218
219#[derive(Debug, Error)]
220pub enum SamaError {
221 #[error("sama: Input data slice is empty.")]
222 EmptyInputData,
223
224 #[error("sama: All values are NaN.")]
225 AllValuesNaN,
226
227 #[error("sama: Invalid period: period = {period}, data length = {data_len}")]
228 InvalidPeriod { period: usize, data_len: usize },
229
230 #[error("sama: Not enough valid data: needed = {needed}, valid = {valid}")]
231 NotEnoughValidData { needed: usize, valid: usize },
232
233 #[error("sama: Output length mismatch: expected = {expected}, got = {got}")]
234 OutputLengthMismatch { expected: usize, got: usize },
235
236 #[error("sama: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
237 InvalidRange {
238 start: usize,
239 end: usize,
240 step: usize,
241 },
242
243 #[error("sama: Invalid kernel for batch path: {0:?}")]
244 InvalidKernelForBatch(Kernel),
245}
246
247#[inline(always)]
248pub fn sama(input: &SamaInput) -> Result<SamaOutput, SamaError> {
249 sama_with_kernel(input, Kernel::Auto)
250}
251
252#[inline(always)]
253pub fn sama_with_kernel(input: &SamaInput, kernel: Kernel) -> Result<SamaOutput, SamaError> {
254 let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, kernel)?;
255
256 let mut out = alloc_with_nan_prefix(data.len(), first);
257 sama_compute_into(
258 data, length, maj_length, min_length, first, chosen, &mut out,
259 );
260 Ok(SamaOutput { values: out })
261}
262
263#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
264pub fn sama_into(input: &SamaInput, out: &mut [f64]) -> Result<(), SamaError> {
265 let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, Kernel::Auto)?;
266
267 if out.len() != data.len() {
268 return Err(SamaError::OutputLengthMismatch {
269 expected: data.len(),
270 got: out.len(),
271 });
272 }
273
274 let warm = first.min(out.len());
275 for i in 0..warm {
276 out[i] = f64::from_bits(0x7ff8_0000_0000_0000);
277 }
278
279 sama_compute_into(data, length, maj_length, min_length, first, chosen, out);
280
281 Ok(())
282}
283
284#[inline(always)]
285pub fn sama_into_slice(dst: &mut [f64], input: &SamaInput, kern: Kernel) -> Result<(), SamaError> {
286 let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, kern)?;
287
288 if dst.len() != data.len() {
289 return Err(SamaError::OutputLengthMismatch {
290 expected: data.len(),
291 got: dst.len(),
292 });
293 }
294
295 sama_compute_into(data, length, maj_length, min_length, first, chosen, dst);
296
297 for v in &mut dst[..first] {
298 *v = f64::NAN;
299 }
300 Ok(())
301}
302
303#[inline(always)]
304fn sama_prepare<'a>(
305 input: &'a SamaInput,
306 kernel: Kernel,
307) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), SamaError> {
308 let data: &[f64] = input.as_ref();
309 let n = data.len();
310
311 if n == 0 {
312 return Err(SamaError::EmptyInputData);
313 }
314
315 let first = data
316 .iter()
317 .position(|x| !x.is_nan())
318 .ok_or(SamaError::AllValuesNaN)?;
319
320 let length = input.get_length();
321 let maj_length = input.get_maj_length();
322 let min_length = input.get_min_length();
323
324 if length + 1 > n || length == 0 {
325 return Err(SamaError::InvalidPeriod {
326 period: length,
327 data_len: n,
328 });
329 }
330
331 if maj_length == 0 || min_length == 0 {
332 return Err(SamaError::InvalidPeriod {
333 period: 0,
334 data_len: n,
335 });
336 }
337
338 let valid = n - first;
339
340 if valid < 1 {
341 return Err(SamaError::NotEnoughValidData { needed: 1, valid });
342 }
343
344 let chosen = match kernel {
345 Kernel::Auto => detect_best_kernel(),
346 k => k,
347 };
348
349 Ok((data, length, maj_length, min_length, first, chosen))
350}
351
352fn sama_compute_into(
353 data: &[f64],
354 length: usize,
355 maj_length: usize,
356 min_length: usize,
357 first: usize,
358 kernel: Kernel,
359 out: &mut [f64],
360) {
361 unsafe {
362 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
363 {
364 if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
365 sama_simd128(data, length, maj_length, min_length, first, out);
366 return;
367 }
368 }
369
370 match kernel {
371 Kernel::Scalar | Kernel::ScalarBatch => {
372 sama_scalar(data, length, maj_length, min_length, first, out)
373 }
374 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
375 Kernel::Avx2 | Kernel::Avx2Batch => {
376 sama_avx2(data, length, maj_length, min_length, first, out)
377 }
378 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
379 Kernel::Avx512 | Kernel::Avx512Batch => {
380 sama_avx512(data, length, maj_length, min_length, first, out)
381 }
382 #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
383 Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
384 sama_scalar(data, length, maj_length, min_length, first, out)
385 }
386 _ => unreachable!(),
387 }
388 }
389}
390
/// Scalar SAMA kernel.
///
/// For every non-NaN sample `p` at index `i >= first`:
/// * `hh` / `ll` are the highest / lowest non-NaN samples among indices
///   `i - length ..= i` (clamped at 0) that this function has visited,
/// * `mult = |2p - hh - ll| / (hh - ll)` (0 when the range is not positive),
/// * `alpha = (mult * (min_alpha - maj_alpha) + maj_alpha)^2` with
///   `maj_alpha = 2 / (maj_length + 1)` and `min_alpha = 2 / (min_length + 1)`,
/// * the average is seeded with the first sample and then moves by
///   `alpha * (p - previous)`.
///
/// NaN samples produce NaN in `out` and leave the running average untouched.
/// `out[..first]` is not written.
///
/// # Panics
/// Panics if `out` is too short to index every processed position.
#[inline]
pub fn sama_scalar(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    use std::collections::VecDeque;

    let n = data.len();
    if n == 0 {
        return;
    }

    let maj_alpha = 2.0 / (maj_length as f64 + 1.0);
    let min_alpha = 2.0 / (min_length as f64 + 1.0);
    let delta = min_alpha - maj_alpha;

    // The window spans at most `length + 1` indices and never more than `n`.
    // Clamping keeps `length == usize::MAX` from overflowing and avoids scratch
    // allocations proportional to `length` on short inputs.
    let cap = length.saturating_add(1).min(n);
    // Monotonic index queues: the front is the window maximum / minimum.
    let mut highs: VecDeque<usize> = VecDeque::with_capacity(cap);
    let mut lows: VecDeque<usize> = VecDeque::with_capacity(cap);

    let mut sama_val = f64::NAN;

    for i in first..n {
        let p = data[i];
        if p.is_nan() {
            out[i] = f64::NAN;
            continue;
        }

        let wstart = i.saturating_sub(length);

        // Drop indices that slid out of the window.
        while matches!(highs.front(), Some(&idx) if idx < wstart) {
            highs.pop_front();
        }
        while matches!(lows.front(), Some(&idx) if idx < wstart) {
            lows.pop_front();
        }

        // Restore monotonicity, then admit the new sample.
        while matches!(highs.back(), Some(&idx) if data[idx] <= p) {
            highs.pop_back();
        }
        highs.push_back(i);

        while matches!(lows.back(), Some(&idx) if data[idx] >= p) {
            lows.pop_back();
        }
        lows.push_back(i);

        // Both queues hold at least the index just pushed.
        let hh = data[highs[0]];
        let ll = data[lows[0]];

        let denom = hh - ll;
        let c = (p + p) - (hh + ll);
        let mult = if denom > 0.0 { c.abs() / denom } else { 0.0 };

        let a = mult.mul_add(delta, maj_alpha);
        let alpha = a * a;

        if sama_val.is_nan() {
            sama_val = p;
        } else {
            let diff = p - sama_val;
            sama_val = diff.mul_add(alpha, sama_val);
        }

        out[i] = sama_val;
    }
}
512
/// AVX2 entry point; currently a straight delegate to [`sama_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sama_avx2(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body yet: same arguments, same results as the scalar kernel.
    sama_scalar(data, length, maj_length, min_length, first, out)
}
525
/// AVX-512 entry point; currently a straight delegate to [`sama_scalar`].
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sama_avx512(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body yet: same arguments, same results as the scalar kernel.
    sama_scalar(data, length, maj_length, min_length, first, out)
}
538
/// wasm `simd128` entry point; currently a straight delegate to [`sama_scalar`].
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
pub fn sama_simd128(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    // No vectorised body yet: same arguments, same results as the scalar kernel.
    sama_scalar(data, length, maj_length, min_length, first, out)
}
551
/// Parameter sweep for the batch API; each axis is `(start, end, step)`.
#[derive(Clone, Debug)]
pub struct SamaBatchRange {
    /// Sweep over `length`.
    pub length: (usize, usize, usize),
    /// Sweep over `maj_length`.
    pub maj_length: (usize, usize, usize),
    /// Sweep over `min_length`.
    pub min_length: (usize, usize, usize),
}
558
559#[derive(Debug, Clone)]
560pub struct SamaBatchOutput {
561 pub values: Vec<f64>,
562 pub combos: Vec<SamaParams>,
563 pub rows: usize,
564 pub cols: usize,
565}
566
567impl SamaBatchOutput {
568 #[inline]
569 pub fn row_for_params(&self, p: &SamaParams) -> Option<usize> {
570 self.combos.iter().position(|c| {
571 c.length.unwrap_or(200) == p.length.unwrap_or(200)
572 && c.maj_length.unwrap_or(14) == p.maj_length.unwrap_or(14)
573 && c.min_length.unwrap_or(6) == p.min_length.unwrap_or(6)
574 })
575 }
576
577 #[inline]
578 pub fn values_for(&self, p: &SamaParams) -> Option<&[f64]> {
579 self.row_for_params(p).map(|row| {
580 let start = row * self.cols;
581 &self.values[start..start + self.cols]
582 })
583 }
584}
585
586#[derive(Clone, Debug, Default)]
587pub struct SamaBatchBuilder {
588 range: SamaBatchRange,
589 kernel: Kernel,
590}
591
592impl SamaBatchBuilder {
593 pub fn new() -> Self {
594 Self::default()
595 }
596
597 pub fn kernel(mut self, k: Kernel) -> Self {
598 self.kernel = k;
599 self
600 }
601
602 #[inline]
603 pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
604 self.range.length = (start, end, step);
605 self
606 }
607
608 #[inline]
609 pub fn maj_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
610 self.range.maj_length = (start, end, step);
611 self
612 }
613
614 #[inline]
615 pub fn min_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
616 self.range.min_length = (start, end, step);
617 self
618 }
619
620 #[inline]
621 pub fn length_static(mut self, v: usize) -> Self {
622 self.range.length = (v, v, 0);
623 self
624 }
625
626 #[inline]
627 pub fn maj_length_static(mut self, v: usize) -> Self {
628 self.range.maj_length = (v, v, 0);
629 self
630 }
631
632 #[inline]
633 pub fn min_length_static(mut self, v: usize) -> Self {
634 self.range.min_length = (v, v, 0);
635 self
636 }
637
638 pub fn apply_slice(self, data: &[f64]) -> Result<SamaBatchOutput, SamaError> {
639 sama_batch_with_kernel(data, &self.range, self.kernel)
640 }
641
642 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SamaBatchOutput, SamaError> {
643 self.apply_slice(source_type(c, src))
644 }
645
646 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SamaBatchOutput, SamaError> {
647 SamaBatchBuilder::new().kernel(k).apply_slice(data)
648 }
649
650 pub fn with_default_candles(c: &Candles) -> Result<SamaBatchOutput, SamaError> {
651 SamaBatchBuilder::new()
652 .kernel(Kernel::Auto)
653 .length_static(200)
654 .maj_length_static(14)
655 .min_length_static(6)
656 .apply_candles(c, "close")
657 }
658}
659
660impl Default for SamaBatchRange {
661 fn default() -> Self {
662 Self {
663 length: (200, 449, 1),
664 maj_length: (14, 14, 0),
665 min_length: (6, 6, 0),
666 }
667 }
668}
669
670#[inline(always)]
671fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SamaError> {
672 if start == end || step == 0 {
673 return Ok(vec![start]);
674 }
675 if start < end {
676 if step == 0 {
677 return Ok(vec![start]);
678 }
679 let mut v = Vec::new();
680 let mut x = start;
681 while x <= end {
682 v.push(x);
683 match x.checked_add(step) {
684 Some(nx) => x = nx,
685 None => break,
686 }
687 }
688 if v.is_empty() {
689 return Err(SamaError::InvalidRange { start, end, step });
690 }
691 Ok(v)
692 } else {
693 if step == 0 {
694 return Ok(vec![start]);
695 }
696 let mut v = Vec::new();
697 let mut x = start;
698 loop {
699 v.push(x);
700 if x <= end {
701 break;
702 }
703 x = x.saturating_sub(step);
704 if x == 0 && end > 0 && x < end {
705 break;
706 }
707 if v.len() > 1 && *v.last().unwrap() == x {
708 break;
709 }
710 if x == 0 && end == 0 {
711 break;
712 }
713 if x < end {
714 break;
715 }
716 }
717 if v.is_empty() {
718 return Err(SamaError::InvalidRange { start, end, step });
719 }
720 Ok(v)
721 }
722}
723
724#[inline(always)]
725fn expand_grid_sama(r: &SamaBatchRange) -> Result<Vec<SamaParams>, SamaError> {
726 let lens = axis_usize(r.length)?;
727 let maj = axis_usize(r.maj_length)?;
728 let min = axis_usize(r.min_length)?;
729 if lens.is_empty() || maj.is_empty() || min.is_empty() {
730 return Err(SamaError::InvalidRange {
731 start: 0,
732 end: 0,
733 step: 0,
734 });
735 }
736 let mut out = Vec::with_capacity(
737 lens.len()
738 .saturating_mul(maj.len())
739 .saturating_mul(min.len()),
740 );
741 for &l in &lens {
742 for &j in &maj {
743 for &m in &min {
744 out.push(SamaParams {
745 length: Some(l),
746 maj_length: Some(j),
747 min_length: Some(m),
748 });
749 }
750 }
751 }
752 Ok(out)
753}
754
755pub fn sama_batch_with_kernel(
756 data: &[f64],
757 sweep: &SamaBatchRange,
758 k: Kernel,
759) -> Result<SamaBatchOutput, SamaError> {
760 let kernel = match k {
761 Kernel::Auto => detect_best_batch_kernel(),
762 other if other.is_batch() => other,
763 other => return Err(SamaError::InvalidKernelForBatch(other)),
764 };
765
766 let simd = match kernel {
767 Kernel::Avx512Batch => Kernel::Avx512,
768 Kernel::Avx2Batch => Kernel::Avx2,
769 Kernel::ScalarBatch => Kernel::Scalar,
770 _ => unreachable!(),
771 };
772
773 sama_batch_par_slice(data, sweep, simd)
774}
775
776#[inline(always)]
777pub fn sama_batch_par_slice(
778 data: &[f64],
779 sweep: &SamaBatchRange,
780 kern: Kernel,
781) -> Result<SamaBatchOutput, SamaError> {
782 sama_batch_inner(data, sweep, kern, true)
783}
784
785#[inline(always)]
786pub fn sama_batch_slice(
787 data: &[f64],
788 sweep: &SamaBatchRange,
789 kern: Kernel,
790) -> Result<SamaBatchOutput, SamaError> {
791 sama_batch_inner(data, sweep, kern, false)
792}
793
794#[inline(always)]
795fn sama_batch_inner(
796 data: &[f64],
797 sweep: &SamaBatchRange,
798 kern: Kernel,
799 parallel: bool,
800) -> Result<SamaBatchOutput, SamaError> {
801 let combos = expand_grid_sama(sweep)?;
802 let cols = data.len();
803 let rows = combos.len();
804 if cols == 0 {
805 return Err(SamaError::EmptyInputData);
806 }
807
808 let total = rows.checked_mul(cols).ok_or(SamaError::InvalidRange {
809 start: rows,
810 end: cols,
811 step: 0,
812 })?;
813
814 let mut buf_mu = make_uninit_matrix(rows, cols);
815
816 let first = data
817 .iter()
818 .position(|x| !x.is_nan())
819 .ok_or(SamaError::AllValuesNaN)?;
820 let warm: Vec<usize> = combos.iter().map(|_| first).collect();
821 init_matrix_prefixes(&mut buf_mu, cols, &warm);
822
823 let mut guard = core::mem::ManuallyDrop::new(buf_mu);
824 let out: &mut [f64] =
825 unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
826
827 sama_batch_inner_into(data, &combos, first, kern, parallel, out)?;
828
829 let values = unsafe {
830 Vec::from_raw_parts(
831 guard.as_mut_ptr() as *mut f64,
832 guard.len(),
833 guard.capacity(),
834 )
835 };
836
837 Ok(SamaBatchOutput {
838 values,
839 combos,
840 rows,
841 cols,
842 })
843}
844
845#[inline(always)]
846fn sama_batch_inner_into(
847 data: &[f64],
848 combos: &[SamaParams],
849 first: usize,
850 kern: Kernel,
851 parallel: bool,
852 out_flat: &mut [f64],
853) -> Result<(), SamaError> {
854 if combos.is_empty() {
855 return Ok(());
856 }
857
858 let rows = combos.len();
859 let cols = data.len();
860
861 if cols - first < 1 {
862 return Err(SamaError::NotEnoughValidData {
863 needed: 1,
864 valid: cols - first,
865 });
866 }
867
868 use std::collections::HashMap;
869 let mut uniq_lengths: Vec<usize> = Vec::new();
870 uniq_lengths.reserve(combos.len());
871 for prm in combos.iter() {
872 let l = prm.length.unwrap_or(200);
873 if !uniq_lengths.contains(&l) {
874 uniq_lengths.push(l);
875 }
876 }
877
878 fn build_rolling_extrema(data: &[f64], length: usize) -> (Vec<f64>, Vec<f64>) {
879 let n = data.len();
880 let cap = length + 1;
881 let mut max_idx = vec![0usize; cap];
882 let mut min_idx = vec![0usize; cap];
883 let mut max_head = 0usize;
884 let mut min_head = 0usize;
885 let mut max_len = 0usize;
886 let mut min_len = 0usize;
887 let mut hh = vec![f64::NAN; n];
888 let mut ll = vec![f64::NAN; n];
889
890 for i in 0..n {
891 let p = data[i];
892 let wstart = i.saturating_sub(length);
893
894 while max_len > 0 {
895 let idx = max_idx[max_head];
896 if idx >= wstart {
897 break;
898 }
899 max_head += 1;
900 if max_head == cap {
901 max_head = 0;
902 }
903 max_len -= 1;
904 }
905 while min_len > 0 {
906 let idx = min_idx[min_head];
907 if idx >= wstart {
908 break;
909 }
910 min_head += 1;
911 if min_head == cap {
912 min_head = 0;
913 }
914 min_len -= 1;
915 }
916
917 if !p.is_nan() {
918 while max_len > 0 {
919 let last_pos = (max_head + max_len - 1) % cap;
920 let last_idx = max_idx[last_pos];
921 if data[last_idx] <= p {
922 max_len -= 1;
923 } else {
924 break;
925 }
926 }
927 let ins_pos_max = (max_head + max_len) % cap;
928 max_idx[ins_pos_max] = i;
929 max_len += 1;
930
931 while min_len > 0 {
932 let last_pos = (min_head + min_len - 1) % cap;
933 let last_idx = min_idx[last_pos];
934 if data[last_idx] >= p {
935 min_len -= 1;
936 } else {
937 break;
938 }
939 }
940 let ins_pos_min = (min_head + min_len) % cap;
941 min_idx[ins_pos_min] = i;
942 min_len += 1;
943 }
944
945 if max_len > 0 && min_len > 0 {
946 hh[i] = data[max_idx[max_head]];
947 ll[i] = data[min_idx[min_head]];
948 }
949 }
950 (hh, ll)
951 }
952
953 let mut hh_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(uniq_lengths.len());
954 let mut ll_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(uniq_lengths.len());
955 for &l in &uniq_lengths {
956 let (hh, ll) = build_rolling_extrema(data, l);
957 hh_map.insert(l, hh);
958 ll_map.insert(l, ll);
959 }
960
961 let do_row = |row: usize, row_dst: &mut [f64]| {
962 let prm = &combos[row];
963 let length = prm.length.unwrap_or(200);
964 let maj_length = prm.maj_length.unwrap_or(14);
965 let min_length = prm.min_length.unwrap_or(6);
966
967 let maj_alpha = 2.0 / (maj_length as f64 + 1.0);
968 let min_alpha = 2.0 / (min_length as f64 + 1.0);
969 let delta = min_alpha - maj_alpha;
970
971 let hh = &hh_map[&length];
972 let ll = &ll_map[&length];
973
974 let mut sama_val = f64::NAN;
975 for i in first..cols {
976 let p = data[i];
977 if p.is_nan() {
978 row_dst[i] = f64::NAN;
979 continue;
980 }
981
982 let h = hh[i];
983 let l = ll[i];
984 let denom = h - l;
985 let c = (p + p) - (h + l);
986 let mult = if denom > 0.0 { c.abs() / denom } else { 0.0 };
987 let a = mult.mul_add(delta, maj_alpha);
988 let alpha = a * a;
989
990 if sama_val.is_nan() {
991 sama_val = p;
992 } else {
993 let diff = p - sama_val;
994 sama_val = diff.mul_add(alpha, sama_val);
995 }
996 row_dst[i] = sama_val;
997 }
998 };
999
1000 if parallel {
1001 #[cfg(not(target_arch = "wasm32"))]
1002 {
1003 use rayon::prelude::*;
1004 out_flat
1005 .par_chunks_mut(cols)
1006 .enumerate()
1007 .for_each(|(row, dst)| do_row(row, dst));
1008 }
1009 #[cfg(target_arch = "wasm32")]
1010 {
1011 for (row, dst) in out_flat.chunks_mut(cols).enumerate() {
1012 do_row(row, dst);
1013 }
1014 }
1015 } else {
1016 for (row, dst) in out_flat.chunks_mut(cols).enumerate() {
1017 do_row(row, dst);
1018 }
1019 }
1020
1021 Ok(())
1022}
1023
/// Incremental SAMA calculator fed one sample at a time.
#[derive(Clone, Debug)]
pub struct SamaStream {
    /// Rolling-window look-back.
    length: usize,
    /// Major smoothing period.
    maj_length: usize,
    /// Minor smoothing period.
    min_length: usize,

    /// `2 / (maj_length + 1)`.
    maj_alpha: f64,
    /// `2 / (min_length + 1)`.
    min_alpha: f64,
    /// `min_alpha - maj_alpha`.
    delta: f64,

    /// Ring size, `length + 1`.
    cap: usize,
    /// Ring of the most recent `cap` samples.
    buf: Vec<f64>,
    /// Ring slot the next sample is written to.
    head: usize,
    /// Number of samples seen so far (index of the next sample).
    tick: usize,

    /// Ring of sample indices tracking the window maximum.
    max_idx: Vec<usize>,
    /// Ring of sample indices tracking the window minimum.
    min_idx: Vec<usize>,
    /// Front position inside `max_idx`.
    max_head: usize,
    /// Front position inside `min_idx`.
    min_head: usize,
    /// Number of live entries in `max_idx`.
    max_len: usize,
    /// Number of live entries in `min_idx`.
    min_len: usize,

    /// Current average; NaN until the first non-NaN sample arrives.
    sama_val: f64,
}
1048
1049impl SamaStream {
1050 pub fn try_new(params: SamaParams) -> Result<Self, SamaError> {
1051 let length = params.length.unwrap_or(200);
1052 let maj = params.maj_length.unwrap_or(14);
1053 let min = params.min_length.unwrap_or(6);
1054 if length == 0 || maj == 0 || min == 0 {
1055 return Err(SamaError::InvalidPeriod {
1056 period: 0,
1057 data_len: 0,
1058 });
1059 }
1060
1061 let cap = length + 1;
1062 let maj_alpha = 2.0 / (maj as f64 + 1.0);
1063 let min_alpha = 2.0 / (min as f64 + 1.0);
1064
1065 Ok(Self {
1066 length,
1067 maj_length: maj,
1068 min_length: min,
1069
1070 maj_alpha,
1071 min_alpha,
1072 delta: min_alpha - maj_alpha,
1073
1074 cap,
1075 buf: vec![f64::NAN; cap],
1076 head: 0,
1077 tick: 0,
1078
1079 max_idx: vec![0usize; cap],
1080 min_idx: vec![0usize; cap],
1081 max_head: 0,
1082 min_head: 0,
1083 max_len: 0,
1084 min_len: 0,
1085
1086 sama_val: f64::NAN,
1087 })
1088 }
1089
1090 #[inline]
1091 pub fn update(&mut self, value: f64) -> Option<f64> {
1092 let t = self.tick;
1093 let cap = self.cap;
1094
1095 let oldest = t.saturating_sub(self.length);
1096
1097 while self.max_len > 0 {
1098 let idx = self.max_idx[self.max_head];
1099 if idx >= oldest {
1100 break;
1101 }
1102 self.max_head += 1;
1103 if self.max_head == cap {
1104 self.max_head = 0;
1105 }
1106 self.max_len -= 1;
1107 }
1108 while self.min_len > 0 {
1109 let idx = self.min_idx[self.min_head];
1110 if idx >= oldest {
1111 break;
1112 }
1113 self.min_head += 1;
1114 if self.min_head == cap {
1115 self.min_head = 0;
1116 }
1117 self.min_len -= 1;
1118 }
1119
1120 self.buf[self.head] = value;
1121
1122 if value.is_nan() {
1123 self.head += 1;
1124 if self.head == cap {
1125 self.head = 0;
1126 }
1127 self.tick = t.wrapping_add(1);
1128 return Some(f64::NAN);
1129 }
1130
1131 while self.max_len > 0 {
1132 let back_pos = (self.max_head + self.max_len - 1) % cap;
1133 let back_idx = self.max_idx[back_pos];
1134 let back_val = self.buf[back_idx % cap];
1135 if back_val <= value {
1136 self.max_len -= 1;
1137 } else {
1138 break;
1139 }
1140 }
1141 let ins_pos_max = (self.max_head + self.max_len) % cap;
1142 self.max_idx[ins_pos_max] = t;
1143 self.max_len += 1;
1144
1145 while self.min_len > 0 {
1146 let back_pos = (self.min_head + self.min_len - 1) % cap;
1147 let back_idx = self.min_idx[back_pos];
1148 let back_val = self.buf[back_idx % cap];
1149 if back_val >= value {
1150 self.min_len -= 1;
1151 } else {
1152 break;
1153 }
1154 }
1155 let ins_pos_min = (self.min_head + self.min_len) % cap;
1156 self.min_idx[ins_pos_min] = t;
1157 self.min_len += 1;
1158
1159 let hh = self.buf[self.max_idx[self.max_head] % cap];
1160 let ll = self.buf[self.min_idx[self.min_head] % cap];
1161
1162 let denom = hh - ll;
1163 let c = (value + value) - (hh + ll);
1164 let mult = if denom > 0.0 { c.abs() / denom } else { 0.0 };
1165
1166 let a = mult.mul_add(self.delta, self.maj_alpha);
1167 let alpha = a * a;
1168
1169 if self.sama_val.is_nan() {
1170 self.sama_val = value;
1171 } else {
1172 let diff = value - self.sama_val;
1173 self.sama_val = diff.mul_add(alpha, self.sama_val);
1174 }
1175
1176 self.head += 1;
1177 if self.head == cap {
1178 self.head = 0;
1179 }
1180 self.tick = t.wrapping_add(1);
1181
1182 Some(self.sama_val)
1183 }
1184
1185 #[inline]
1186 pub fn next(&mut self, value: f64) -> f64 {
1187 self.update(value).unwrap_or(f64::NAN)
1188 }
1189
1190 #[inline]
1191 pub fn reset(&mut self) {
1192 self.buf.fill(f64::NAN);
1193 self.head = 0;
1194 self.tick = 0;
1195 self.max_len = 0;
1196 self.min_len = 0;
1197 self.max_head = 0;
1198 self.min_head = 0;
1199 self.sama_val = f64::NAN;
1200 }
1201}
1202
/// Python binding `sama(data, length, maj_length, min_length, kernel=None)`.
///
/// Runs the single-series computation with the GIL released and returns the
/// result as a 1-D NumPy array. Computation errors surface as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "sama")]
#[pyo3(signature = (data, length, maj_length, min_length, kernel=None))]
pub fn sama_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length: usize,
    maj_length: usize,
    min_length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let kern = validate_kernel(kernel, false)?;

    let input = SamaInput::from_slice(
        prices,
        SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        },
    );

    let computed = py.allow_threads(|| sama_with_kernel(&input, kern));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1227
/// Python binding `sama_batch(data, length_range, maj_length_range, min_length_range, kernel=None)`.
///
/// Returns a dict with
/// * `"values"` – a `(rows, cols)` array, one row per parameter combination,
/// * `"lengths"`, `"maj_lengths"`, `"min_lengths"` – the parameters behind each row.
///
/// Sweep expansion errors, an output size that overflows `usize`, and
/// computation errors surface as `ValueError`; the kernel string is checked by
/// `validate_kernel`.
#[cfg(feature = "python")]
#[pyfunction(name = "sama_batch")]
#[pyo3(signature = (data, length_range, maj_length_range, min_length_range, kernel=None))]
pub fn sama_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    maj_length_range: (usize, usize, usize),
    min_length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let sweep = SamaBatchRange {
        length: length_range,
        maj_length: maj_length_range,
        min_length: min_length_range,
    };

    let combos = expand_grid_sama(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let rows = combos.len();
    let cols = slice_in.len();

    // The array below is allocated uninitialised, so its size must not wrap.
    let total = rows.checked_mul(cols).ok_or_else(|| {
        PyValueError::new_err(format!(
            "sama: output size overflow: rows = {rows}, cols = {cols}"
        ))
    })?;

    // Reject a bad kernel name before allocating the output.
    let kern = validate_kernel(kernel, true)?;

    // SAFETY: `out_arr` is freshly created and exclusively owned here; every
    // element is written below before the array is handed to Python.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let mapped = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            let simd = match mapped {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => Kernel::Scalar,
            };

            let first = slice_in
                .iter()
                .position(|x| !x.is_nan())
                .ok_or(SamaError::AllValuesNaN)?;

            // Make sure the warm-up prefix of every row is NaN rather than
            // whatever the uninitialised allocation happened to contain.
            // (`cols > 0` here because a non-NaN sample was found.)
            for row in slice_out.chunks_mut(cols) {
                row[..first].fill(f64::NAN);
            }

            sama_batch_inner_into(slice_in, &combos, first, simd, true, slice_out).map(|_| combos)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|p| p.length.unwrap_or(200) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "maj_lengths",
        combos
            .iter()
            .map(|p| p.maj_length.unwrap_or(14) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "min_lengths",
        combos
            .iter()
            .map(|p| p.min_length.unwrap_or(6) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict.into())
}
1307
// Python entry point `sama_cuda_batch_dev`: runs a SAMA parameter sweep on a
// CUDA device and returns a device-resident array wrapper.
//
// Fails with `ValueError` when CUDA is unavailable, when `data_f32` cannot be
// viewed as a slice, or when the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sama_cuda_batch_dev")]
#[pyo3(signature = (data_f32, length_range=(200, 200, 0), maj_length_range=(14, 14, 0), min_length_range=(6, 6, 0), device_id=0))]
pub fn sama_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    maj_length_range: (usize, usize, usize),
    min_length_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SamaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let prices = data_f32.as_slice()?;
    let grid = SamaBatchRange {
        length: length_range,
        maj_length: maj_length_range,
        min_length: min_length_range,
    };

    // The GIL is released for the device work; errors travel out as plain
    // strings and are turned into `ValueError` once the GIL is held again.
    let inner = py
        .allow_threads(|| -> Result<_, String> {
            let engine = CudaSama::new(device_id).map_err(|e| e.to_string())?;
            engine
                .sama_batch_dev(prices, &grid)
                .map_err(|e| e.to_string())
        })
        .map_err(PyValueError::new_err)?;

    Ok(DeviceArrayF32SamaPy { inner })
}
1338
// Python entry point `sama_cuda_many_series_one_param_dev`: applies a single
// SAMA parameter set to many series stored in one 2-D `f32` array.
//
// The first array dimension is taken as the series length and the second as
// the number of series (time-major layout, as the parameter name suggests).
//
// Fails with `ValueError` when CUDA is unavailable, when any of the three
// lengths is zero, when the array cannot be viewed as a flat slice, or when
// the CUDA wrapper reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sama_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, length, maj_length, min_length, device_id=0))]
pub fn sama_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    length: usize,
    maj_length: usize,
    min_length: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SamaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if length == 0 || maj_length == 0 || min_length == 0 {
        return Err(PyValueError::new_err(
            "length, maj_length, and min_length must be positive",
        ));
    }

    let flat = data_tm_f32.as_slice()?;
    let shape = data_tm_f32.shape();
    let series_len = shape[0];
    let num_series = shape[1];
    let params = SamaParams {
        length: Some(length),
        maj_length: Some(maj_length),
        min_length: Some(min_length),
    };

    // Device work runs with the GIL released.
    let inner = py.allow_threads(|| {
        let cuda = CudaSama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // The parameter set is passed by reference (`&params`).
        cuda.sama_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))
    })?;

    Ok(DeviceArrayF32SamaPy { inner })
}
1377
// Python-visible class `SamaStream`: a thin wrapper that owns the native
// streaming SAMA state so it can be driven one value at a time from Python.
#[cfg(feature = "python")]
#[pyclass(name = "SamaStream")]
pub struct SamaStreamPy {
    // Native streaming state that every method call is forwarded to.
    stream: SamaStream,
}
1383
#[cfg(feature = "python")]
#[pymethods]
impl SamaStreamPy {
    // Builds a stream from explicit lengths; a rejected parameter set is
    // reported to Python as `ValueError` carrying the native error text.
    #[new]
    fn new(length: usize, maj_length: usize, min_length: usize) -> PyResult<Self> {
        let requested = SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        };
        match SamaStream::try_new(requested) {
            Ok(stream) => Ok(SamaStreamPy { stream }),
            Err(e) => Err(PyValueError::new_err(e.to_string())),
        }
    }

    // Feeds one value to the native stream and returns whatever it yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1403
// JavaScript entry point: computes SAMA over `data` with explicit lengths and
// returns a freshly allocated vector of the same length as the input.
// Native errors are surfaced as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_js(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = SamaInput::from_slice(
        data,
        SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        },
    );

    let mut values = vec![0.0; data.len()];
    match sama_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1423
/// Sweep configuration deserialized from the JavaScript `config` object passed
/// to the `sama_batch` entry point.
///
/// Each tuple is forwarded unchanged into the matching `SamaBatchRange` field;
/// presumably the layout is `(start, end, step)` — confirm against
/// `expand_grid_sama`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SamaBatchConfig {
    /// Range for the `length` parameter.
    pub length_range: (usize, usize, usize),
    /// Range for the `maj_length` parameter.
    pub maj_length_range: (usize, usize, usize),
    /// Range for the `min_length` parameter.
    pub min_length_range: (usize, usize, usize),
}
1431
/// Result object serialized back to JavaScript by the `sama_batch` entry point.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SamaBatchJsOutput {
    /// Flattened output values copied from the native batch result.
    pub values: Vec<f64>,
    /// Parameter combinations copied from the native batch result.
    pub combos: Vec<SamaParams>,
    /// Row count reported by the native batch result.
    pub rows: usize,
    /// Column count reported by the native batch result.
    pub cols: usize,
}
1440
// JavaScript entry point `sama_batch`: deserializes a sweep configuration,
// runs the native batch computation and serializes the result object.
// Every failure is returned as a string `JsValue`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sama_batch)]
pub fn sama_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: SamaBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let grid = SamaBatchRange {
        length: parsed.length_range,
        maj_length: parsed.maj_length_range,
        min_length: parsed.min_length_range,
    };

    let batch = match sama_batch_inner(data, &grid, detect_best_kernel(), false) {
        Ok(result) => result,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    let payload = SamaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1461
// Allocates room for `len` `f64` values and hands the raw pointer to the
// caller. The memory is not initialised and is never freed by Rust on its
// own; it must be given back through `sama_free` with the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec destructor, leaking the buffer on purpose.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1470
// Releases a buffer previously obtained from `sama_alloc`.
//
// `ptr` must come from `sama_alloc(len)` with the same `len` and must not be
// freed twice. A null pointer is ignored instead of being passed to
// `Vec::from_raw_parts`, which requires a non-null pointer.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` was produced by a
    // `Vec::<f64>::with_capacity(len)` allocation. The Vec is rebuilt with
    // length 0 so no possibly-uninitialised element is ever treated as live;
    // dropping it only returns the allocation of capacity `len`.
    unsafe {
        drop(Vec::from_raw_parts(ptr, 0, len));
    }
}
1478
// Pointer-based JavaScript entry point: reads `len` values at `in_ptr`,
// computes SAMA and writes `len` values at `out_ptr`.
//
// When both pointers are equal the result is computed into a scratch buffer
// first and copied over the input afterwards. Null pointers and native errors
// are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    maj_length: usize,
    min_length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // NOTE(review): the caller is trusted to pass pointers valid for `len`
    // elements; only exact aliasing (`in_ptr == out_ptr`) is detected here.
    unsafe {
        let source = std::slice::from_raw_parts(in_ptr, len);
        let input = SamaInput::from_slice(
            source,
            SamaParams {
                length: Some(length),
                maj_length: Some(maj_length),
                min_length: Some(min_length),
            },
        );

        if in_ptr == out_ptr {
            // In-place request: compute into scratch, then overwrite the input.
            let mut scratch = vec![0.0; len];
            sama_into_slice(&mut scratch, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
            return Ok(());
        }

        let dest = std::slice::from_raw_parts_mut(out_ptr, len);
        sama_into_slice(dest, &input, detect_best_kernel())
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
1517
// Pointer-based JavaScript batch entry point.
//
// Reads `len` values at `in_ptr`, expands the three `(start, end, step)`
// ranges into parameter combinations and writes one row of `len` values per
// combination at `out_ptr`. Returns the number of rows written.
//
// The caller must provide an output buffer of at least `rows * len` elements.
// Errors (null pointer, invalid range, size overflow, all-NaN input, native
// failure) are reported as string `JsValue`s.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    maj_start: usize,
    maj_end: usize,
    maj_step: usize,
    min_start: usize,
    min_end: usize,
    min_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = SamaBatchRange {
        length: (length_start, length_end, length_step),
        maj_length: (maj_start, maj_end, maj_step),
        min_length: (min_start, min_end, min_step),
    };
    let combos = expand_grid_sama(&sweep).map_err(|_| JsValue::from_str("Invalid range"))?;
    let rows = combos.len();
    let cols = len;

    // `usize` is 32 bits on wasm32, so the product can wrap; a wrapped value
    // would describe an output slice smaller than the data to be written.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflows usize"))?;

    // SAFETY: both pointers are non-null (checked above); the caller promises
    // `in_ptr` is valid for `len` reads and `out_ptr` for `total` writes.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);
        let first = data
            .iter()
            .position(|x| !x.is_nan())
            .ok_or_else(|| JsValue::from_str("All NaN"))?;
        sama_batch_inner_into(data, &combos, first, detect_best_kernel(), false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1557
1558#[cfg(test)]
1559mod tests {
1560 use super::*;
1561 use crate::skip_if_unsupported;
1562 use crate::utilities::data_loader::read_candles_from_csv;
1563 use std::error::Error;
1564
1565 fn check_sama_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1566 skip_if_unsupported!(kernel, test_name);
1567 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1568 let candles = read_candles_from_csv(file_path)?;
1569
1570 let input = SamaInput::from_candles(&candles, "close", SamaParams::default());
1571 let result = sama_with_kernel(&input, kernel)?;
1572
1573 assert_eq!(result.values.len(), candles.close.len());
1574
1575 let valid_count = result.values.iter().filter(|&&v| !v.is_nan()).count();
1576 assert!(
1577 valid_count > 0,
1578 "[{}] SAMA should produce valid values",
1579 test_name
1580 );
1581
1582 Ok(())
1583 }
1584
1585 #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1586 #[test]
1587 fn test_sama_into_matches_api() -> Result<(), Box<dyn Error>> {
1588 let mut data = Vec::with_capacity(256);
1589 for _ in 0..5 {
1590 data.push(f64::NAN);
1591 }
1592 for i in 0..251 {
1593 let x = (i as f64).sin() * 0.12345 + 50.0 + ((i % 11) as f64) * 0.001;
1594 data.push(x);
1595 }
1596
1597 let input = SamaInput::from_slice(&data, SamaParams::default());
1598 let baseline = sama(&input)?.values;
1599
1600 let mut out = vec![0.0; data.len()];
1601 sama_into(&input, &mut out)?;
1602
1603 assert_eq!(baseline.len(), out.len());
1604 fn eq_or_both_nan(a: f64, b: f64) -> bool {
1605 (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
1606 }
1607 for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
1608 assert!(
1609 eq_or_both_nan(a, b),
1610 "mismatch at index {}: api={} into={}",
1611 i,
1612 a,
1613 b
1614 );
1615 }
1616
1617 Ok(())
1618 }
1619
1620 fn check_sama_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1621 skip_if_unsupported!(kernel, test_name);
1622 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1623 let candles = read_candles_from_csv(file_path)?;
1624
1625 let default_params = SamaParams {
1626 length: None,
1627 maj_length: None,
1628 min_length: None,
1629 };
1630 let input = SamaInput::from_candles(&candles, "close", default_params);
1631 let output = sama_with_kernel(&input, kernel)?;
1632 assert_eq!(output.values.len(), candles.close.len());
1633
1634 Ok(())
1635 }
1636
1637 fn check_sama_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1638 skip_if_unsupported!(kernel, test_name);
1639 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1640 let candles = read_candles_from_csv(file_path)?;
1641
1642 let input = SamaInput::with_default_candles(&candles);
1643 match input.data {
1644 SamaData::Candles { source, .. } => assert_eq!(source, "close"),
1645 _ => panic!("Expected SamaData::Candles"),
1646 }
1647 let output = sama_with_kernel(&input, kernel)?;
1648 assert_eq!(output.values.len(), candles.close.len());
1649
1650 Ok(())
1651 }
1652
1653 fn check_sama_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1654 skip_if_unsupported!(kernel, test_name);
1655 let input_data = [10.0, 20.0, 30.0];
1656 let params = SamaParams {
1657 length: Some(0),
1658 maj_length: None,
1659 min_length: None,
1660 };
1661 let input = SamaInput::from_slice(&input_data, params);
1662 let res = sama_with_kernel(&input, kernel);
1663 assert!(
1664 res.is_err(),
1665 "[{}] SAMA should fail with zero period",
1666 test_name
1667 );
1668 Ok(())
1669 }
1670
1671 fn check_sama_period_exceeds_length(
1672 test_name: &str,
1673 kernel: Kernel,
1674 ) -> Result<(), Box<dyn Error>> {
1675 skip_if_unsupported!(kernel, test_name);
1676 let data_small = [10.0, 20.0, 30.0];
1677 let params = SamaParams {
1678 length: Some(10),
1679 maj_length: None,
1680 min_length: None,
1681 };
1682 let input = SamaInput::from_slice(&data_small, params);
1683 let res = sama_with_kernel(&input, kernel);
1684 assert!(
1685 res.is_err(),
1686 "[{}] SAMA should fail with period exceeding length",
1687 test_name
1688 );
1689 Ok(())
1690 }
1691
1692 fn check_sama_very_small_dataset(
1693 test_name: &str,
1694 kernel: Kernel,
1695 ) -> Result<(), Box<dyn Error>> {
1696 skip_if_unsupported!(kernel, test_name);
1697 let single_point = [42.0];
1698 let params = SamaParams::default();
1699 let input = SamaInput::from_slice(&single_point, params);
1700 let res = sama_with_kernel(&input, kernel);
1701 assert!(
1702 res.is_err(),
1703 "[{}] SAMA should fail with insufficient data",
1704 test_name
1705 );
1706 Ok(())
1707 }
1708
1709 fn check_sama_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1710 skip_if_unsupported!(kernel, test_name);
1711 let empty: [f64; 0] = [];
1712 let params = SamaParams::default();
1713 let input = SamaInput::from_slice(&empty, params);
1714 let res = sama_with_kernel(&input, kernel);
1715 assert!(
1716 matches!(res, Err(SamaError::EmptyInputData)),
1717 "[{}] SAMA should fail with empty input",
1718 test_name
1719 );
1720 Ok(())
1721 }
1722
1723 fn check_sama_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1724 skip_if_unsupported!(kernel, test_name);
1725 let nan_data = [f64::NAN, f64::NAN, f64::NAN];
1726 let params = SamaParams::default();
1727 let input = SamaInput::from_slice(&nan_data, params);
1728 let res = sama_with_kernel(&input, kernel);
1729 assert!(
1730 matches!(res, Err(SamaError::AllValuesNaN)),
1731 "[{}] SAMA should fail with all NaN values",
1732 test_name
1733 );
1734 Ok(())
1735 }
1736
1737 fn check_sama_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1738 skip_if_unsupported!(kernel, test_name);
1739 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1740 let candles = read_candles_from_csv(file_path)?;
1741
1742 let first_params = SamaParams {
1743 length: Some(50),
1744 maj_length: Some(14),
1745 min_length: Some(6),
1746 };
1747 let first_input = SamaInput::from_candles(&candles, "close", first_params);
1748 let first_result = sama_with_kernel(&first_input, kernel)?;
1749
1750 let second_params = SamaParams {
1751 length: Some(50),
1752 maj_length: Some(14),
1753 min_length: Some(6),
1754 };
1755 let second_input = SamaInput::from_slice(&first_result.values, second_params);
1756 let second_result = sama_with_kernel(&second_input, kernel)?;
1757
1758 assert_eq!(second_result.values.len(), first_result.values.len());
1759 Ok(())
1760 }
1761
1762 fn check_sama_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1763 skip_if_unsupported!(kernel, test_name);
1764 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1765 let candles = read_candles_from_csv(file_path)?;
1766
1767 let params = SamaParams::default();
1768 let input = SamaInput::from_candles(&candles, "close", params);
1769 let output = sama_with_kernel(&input, kernel)?;
1770
1771 let first_valid = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
1772 let warmup = first_valid + input.get_length();
1773
1774 for (i, &val) in output
1775 .values
1776 .iter()
1777 .enumerate()
1778 .skip(warmup.min(output.values.len()))
1779 {
1780 assert!(
1781 !val.is_nan(),
1782 "[{}] Unexpected NaN at index {}",
1783 test_name,
1784 i
1785 );
1786 }
1787 Ok(())
1788 }
1789
1790 fn check_sama_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1791 skip_if_unsupported!(kernel, test_name);
1792 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1793 let candles = read_candles_from_csv(file_path)?;
1794
1795 let params = SamaParams::default();
1796
1797 let batch_input = SamaInput::from_candles(&candles, "close", params.clone());
1798 let batch_result = sama_with_kernel(&batch_input, kernel)?;
1799
1800 let mut stream = SamaStream::try_new(params)?;
1801 let mut stream_results = Vec::with_capacity(candles.close.len());
1802 for &price in &candles.close {
1803 stream_results.push(stream.next(price));
1804 }
1805
1806 assert_eq!(batch_result.values.len(), stream_results.len());
1807
1808 for (i, (&batch_val, &stream_val)) in batch_result
1809 .values
1810 .iter()
1811 .zip(stream_results.iter())
1812 .enumerate()
1813 {
1814 if batch_val.is_nan() && stream_val.is_nan() {
1815 continue;
1816 }
1817 if !batch_val.is_nan() && !stream_val.is_nan() {
1818 let diff = (batch_val - stream_val).abs();
1819 assert!(
1820 diff < 1e-9,
1821 "[{}] Stream mismatch at index {}: batch={}, stream={}, diff={}",
1822 test_name,
1823 i,
1824 batch_val,
1825 stream_val,
1826 diff
1827 );
1828 }
1829 }
1830 Ok(())
1831 }
1832
1833 fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1834 skip_if_unsupported!(kernel, test_name);
1835 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1836 let candles = read_candles_from_csv(file_path)?;
1837
1838 let range = SamaBatchRange {
1839 length: (190, 210, 1),
1840 maj_length: (12, 16, 1),
1841 min_length: (4, 8, 1),
1842 };
1843
1844 let output = sama_batch_with_kernel(&candles.close, &range, kernel)?;
1845
1846 let expected_count = 21 * 5 * 5;
1847 assert_eq!(
1848 output.rows, expected_count,
1849 "[{}] Expected {} batch results",
1850 test_name, expected_count
1851 );
1852 assert_eq!(
1853 output.values.len(),
1854 expected_count * candles.close.len(),
1855 "[{}] Expected flattened array size",
1856 test_name
1857 );
1858
1859 Ok(())
1860 }
1861
1862 fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1863 skip_if_unsupported!(kernel, test_name);
1864 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1865 let candles = read_candles_from_csv(file_path)?;
1866
1867 let range = SamaBatchRange {
1868 length: (200, 200, 0),
1869 maj_length: (14, 14, 0),
1870 min_length: (6, 6, 0),
1871 };
1872
1873 let output = sama_batch_with_kernel(&candles.close, &range, kernel)?;
1874
1875 assert_eq!(
1876 output.rows, 1,
1877 "[{}] Should have 1 result for default params",
1878 test_name
1879 );
1880 assert_eq!(output.cols, candles.close.len());
1881
1882 Ok(())
1883 }
1884
1885 macro_rules! generate_all_sama_tests {
1886 ($($test_fn:ident),*) => {
1887 paste::paste! {
1888 $(
1889 #[test]
1890 fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
1891 $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
1892 }
1893 )*
1894
1895 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1896 $(
1897 #[test]
1898 fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
1899 $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
1900 }
1901
1902 #[test]
1903 fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
1904 $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
1905 }
1906 )*
1907
1908 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1909 $(
1910 #[test]
1911 fn [<$test_fn _simd128>]() -> Result<(), Box<dyn Error>> {
1912 $test_fn(stringify!([<$test_fn _simd128>]), Kernel::Scalar)
1913 }
1914 )*
1915 }
1916 }
1917 }
1918
1919 macro_rules! gen_batch_tests {
1920 ($fn_name:ident) => {
1921 paste::paste! {
1922 #[test]
1923 fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
1924 $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
1925 }
1926
1927 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1928 #[test]
1929 fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
1930 $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
1931 }
1932
1933 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934 #[test]
1935 fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
1936 $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
1937 }
1938
1939 #[test]
1940 fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
1941 $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
1942 }
1943 }
1944 };
1945 }
1946
    // Expand every single-series check into its per-kernel `#[test]` wrappers.
    generate_all_sama_tests!(
        check_sama_accuracy,
        check_sama_partial_params,
        check_sama_default_candles,
        check_sama_zero_period,
        check_sama_period_exceeds_length,
        check_sama_very_small_dataset,
        check_sama_empty_input,
        check_sama_all_nan,
        check_sama_reinput,
        check_sama_nan_handling,
        check_sama_streaming
    );

    // Expand the batch checks into their batch-kernel `#[test]` wrappers.
    gen_batch_tests!(check_batch_sweep);
    gen_batch_tests!(check_batch_default_row);
1963
1964 #[cfg(debug_assertions)]
1965 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1966 skip_if_unsupported!(kernel, test);
1967
1968 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1969 let c = read_candles_from_csv(file)?;
1970
1971 let cfgs = vec![
1972 (190, 200, 1, 12, 14, 1, 4, 8, 1),
1973 (200, 200, 0, 14, 14, 0, 6, 6, 0),
1974 (195, 205, 5, 10, 16, 2, 3, 9, 2),
1975 ];
1976
1977 for (ls, le, lstep, js, je, jstep, ms, me, mstep) in cfgs {
1978 let out = SamaBatchRange {
1979 length: (ls, le, lstep),
1980 maj_length: (js, je, jstep),
1981 min_length: (ms, me, mstep),
1982 };
1983 let res = sama_batch_with_kernel(&c.close, &out, kernel)?;
1984 for &v in &res.values {
1985 if v.is_nan() {
1986 continue;
1987 }
1988 let b = v.to_bits();
1989 assert_ne!(b, 0x1111_1111_1111_1111);
1990 assert_ne!(b, 0x2222_2222_2222_2222);
1991 assert_ne!(b, 0x3333_3333_3333_3333);
1992 }
1993 }
1994 Ok(())
1995 }
1996
1997 #[cfg(debug_assertions)]
1998 gen_batch_tests!(check_batch_no_poison);
1999
    // Non-debug stand-in for the poison check: always succeeds.
    // NOTE(review): this function carries no `#[test]` attribute, so it is not
    // run by the test harness — confirm whether that is intended.
    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison_scalar() -> Result<(), Box<dyn Error>> {
        Ok(())
    }
2004
2005 #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2006 #[test]
2007 fn test_sama_simd128_correctness() {
2008 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2009 let params = SamaParams::default();
2010 let input = SamaInput::from_slice(&data, params);
2011 let scalar = sama_with_kernel(&input, Kernel::Scalar).unwrap();
2012 let simd = sama_with_kernel(&input, Kernel::Scalar).unwrap();
2013 assert_eq!(scalar.values.len(), simd.values.len());
2014 for (a, b) in scalar.values.iter().zip(simd.values.iter()) {
2015 assert!((a - b).abs() < 1e-10);
2016 }
2017 }
2018
2019 #[cfg(debug_assertions)]
2020 #[test]
2021 fn test_sama_no_poison_values() -> Result<(), Box<dyn Error>> {
2022 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2023 let candles = read_candles_from_csv(file_path)?;
2024
2025 let input = SamaInput::from_candles(&candles, "close", SamaParams::default());
2026 let output = sama(&input)?;
2027
2028 for &v in &output.values {
2029 if v.is_nan() {
2030 continue;
2031 }
2032 let b = v.to_bits();
2033
2034 assert_ne!(
2035 b, 0x11111111_11111111,
2036 "Found poison value 0x11111111_11111111"
2037 );
2038 assert_ne!(
2039 b, 0x22222222_22222222,
2040 "Found poison value 0x22222222_22222222"
2041 );
2042 assert_ne!(
2043 b, 0x33333333_33333333,
2044 "Found poison value 0x33333333_33333333"
2045 );
2046 assert_ne!(
2047 b, 0xDEADBEEF_DEADBEEF,
2048 "Found poison value 0xDEADBEEF_DEADBEEF"
2049 );
2050 assert_ne!(
2051 b, 0xFEEEFEEE_FEEEFEEE,
2052 "Found poison value 0xFEEEFEEE_FEEEFEEE"
2053 );
2054 }
2055 Ok(())
2056 }
2057
2058 #[test]
2059 fn test_sama_stream_incremental() -> Result<(), Box<dyn Error>> {
2060 let params = SamaParams {
2061 length: Some(10),
2062 maj_length: Some(5),
2063 min_length: Some(3),
2064 };
2065
2066 let mut stream = SamaStream::try_new(params)?;
2067 let data = vec![
2068 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
2069 ];
2070
2071 let mut results = Vec::new();
2072 for &val in &data {
2073 let result = stream.next(val);
2074 if !result.is_nan() {
2075 results.push(result);
2076 }
2077 }
2078
2079 assert!(
2080 !results.is_empty(),
2081 "Stream should produce results immediately"
2082 );
2083
2084 Ok(())
2085 }
2086
2087 #[test]
2088 fn sama_pine_parity_head_start() -> Result<(), Box<dyn Error>> {
2089 let mut long = vec![0.0; 5000];
2090 for i in 0..long.len() {
2091 long[i] = (i as f64).sin() + (i as f64 * 0.01).cos();
2092 }
2093
2094 let pine_params = SamaParams::default();
2095 let pine_out = SamaInput::from_slice(&long, pine_params.clone());
2096 let a = sama_with_kernel(&pine_out, Kernel::Scalar)?.values;
2097
2098 let tail = &long[2000..];
2099 let pine_like_tail = SamaInput::from_slice(tail, pine_params);
2100 let b = sama_with_kernel(&pine_like_tail, Kernel::Scalar)?.values;
2101
2102 let tol = 0.1;
2103 for (i, (&x, &y)) in a[2000..].iter().zip(b.iter()).enumerate().skip(100) {
2104 if x.is_finite() && y.is_finite() {
2105 assert!((x - y).abs() < tol, "i={}, |Δ|={}", i, (x - y).abs());
2106 }
2107 }
2108
2109 Ok(())
2110 }
2111}
2112
/// Property tests for SAMA (enabled with the `proptest` feature, never on wasm32).
#[cfg(all(feature = "proptest", not(target_arch = "wasm32")))]
mod prop_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // For random finite data and lengths in 2..64, every output at or past
        // `first_valid + len` whose trailing window is finite must be finite.
        #[test]
        fn sama_properties(data in prop::collection::vec(-1e6f64..1e6, 5..400),
                           len in 2usize..64,
                           maj in 2usize..64,
                           min in 2usize..64) {
            // Cases where the series is not longer than `len` are skipped.
            if data.len() > len {
                let input = SamaInput::from_slice(
                    &data,
                    SamaParams {
                        length: Some(len),
                        maj_length: Some(maj),
                        min_length: Some(min),
                    },
                );
                let SamaOutput { values: out } =
                    sama_with_kernel(&input, Kernel::Scalar).unwrap();

                let first_valid = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
                let warm = first_valid + len;

                for i in warm..data.len() {
                    let window = &data[(i - len)..=i];
                    if !window.iter().all(|v| v.is_finite()) {
                        continue;
                    }
                    let y = out[i];
                    prop_assert!(
                        y.is_finite(),
                        "Output {} at index {} is not finite",
                        y, i
                    );
                }
            }
        }
    }
}
2159
// Python-visible handle (module `ta_indicators.cuda`) around a device-resident
// f32 array produced by the CUDA SAMA routines. Declared `unsendable`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32SamaPy {
    // Native device array; replaced by an empty placeholder on DLPack export.
    pub(crate) inner: DeviceArrayF32Sama,
}
2165
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SamaPy {
    // Builds the `__cuda_array_interface__` dict (version 3): 2-D shape,
    // little-endian f32 typestr, row-major byte strides and the device
    // pointer (reported as 0 for an empty array).
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let elem_bytes = std::mem::size_of::<f32>();
        let rows = self.inner.rows;
        let cols = self.inner.cols;

        let d = PyDict::new(py);
        d.set_item("shape", (rows, cols))?;
        d.set_item("typestr", "<f4")?;
        d.set_item("strides", (cols * elem_bytes, elem_bytes))?;

        // An overflowing element count is treated like an empty array.
        let ptr_usize = match rows.checked_mul(cols).unwrap_or(0) {
            0 => 0usize,
            _ => self.inner.device_ptr() as usize,
        };
        d.set_item("data", (ptr_usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    // Returns the DLPack device tuple: device type 2 and the owning device id.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    // Exports the buffer through DLPack. The device buffer is moved out of
    // `self`, which is left holding an empty placeholder array.
    //
    // A `dl_device` that parses as `(type, id)` and differs from this array's
    // device is rejected with `ValueError`; `stream` is accepted but unused.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that cannot be read as an (i32, i32) pair is ignored.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let reason = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(reason));
            }
        }

        let _ = stream;

        // Build the empty placeholder first so a failed allocation leaves
        // `self` untouched, then swap it in and take the real array out.
        let placeholder = DeviceArrayF32Sama {
            buf: DeviceBuffer::from_slice(&[])
                .map_err(|e| PyValueError::new_err(e.to_string()))?,
            rows: 0,
            cols: 0,
            ctx: self.inner.ctx.clone(),
            device_id: self.inner.device_id,
        };
        let taken = std::mem::replace(&mut self.inner, placeholder);

        let (rows, cols) = (taken.rows, taken.cols);
        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, taken.buf, rows, cols, alloc_dev, max_version_bound)
    }
}