Skip to main content

vector_ta/indicators/moving_averages/
sama.rs

1#[cfg(feature = "python")]
2use numpy::PyUntypedArrayMethods;
3#[cfg(feature = "python")]
4use numpy::{IntoPyArray, PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
5#[cfg(feature = "python")]
6use pyo3::exceptions::PyValueError;
7#[cfg(feature = "python")]
8use pyo3::prelude::*;
9#[cfg(feature = "python")]
10use pyo3::types::PyDict;
11
12#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
13use serde::{Deserialize, Serialize};
14#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
15use wasm_bindgen::prelude::*;
16
17#[cfg(all(feature = "python", feature = "cuda"))]
18use crate::cuda::moving_averages::sama_wrapper::DeviceArrayF32Sama;
19#[cfg(all(feature = "python", feature = "cuda"))]
20use crate::cuda::{cuda_available, moving_averages::CudaSama};
21use crate::utilities::data_loader::{source_type, Candles};
22#[cfg(all(feature = "python", feature = "cuda"))]
23use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
24use crate::utilities::enums::Kernel;
25use crate::utilities::helpers::{
26    alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
27    make_uninit_matrix,
28};
29#[cfg(feature = "python")]
30use crate::utilities::kernel_validation::validate_kernel;
31#[cfg(all(feature = "python", feature = "cuda"))]
32use cust::memory::DeviceBuffer;
33
34#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
35use core::arch::x86_64::*;
36
37#[cfg(not(target_arch = "wasm32"))]
38use rayon::prelude::*;
39
40use std::convert::AsRef;
41use std::error::Error;
42use std::mem::MaybeUninit;
43use thiserror::Error;
44
45impl<'a> AsRef<[f64]> for SamaInput<'a> {
46    #[inline(always)]
47    fn as_ref(&self) -> &[f64] {
48        match &self.data {
49            SamaData::Slice(slice) => slice,
50            SamaData::Candles { candles, source } => source_type(candles, source),
51        }
52    }
53}
54
/// Where a SAMA computation reads its price series from.
#[derive(Debug, Clone)]
pub enum SamaData<'a> {
    /// A named column (e.g. `"close"`) of a candle set, resolved through `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided price slice, used as-is.
    Slice(&'a [f64]),
}

/// Result of a single-parameter SAMA run: one value per input sample.
#[derive(Debug, Clone)]
pub struct SamaOutput {
    pub values: Vec<f64>,
}

/// Parameters of the Slope Adaptive Moving Average.
///
/// `None` fields fall back to the defaults used by the `SamaInput::get_*`
/// accessors (200 / 14 / 6).
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(Serialize, Deserialize)
)]
pub struct SamaParams {
    // Lookback for the rolling highest/lowest; the window is `[i - length, i]`,
    // i.e. `length + 1` samples.
    pub length: Option<usize>,
    // Period of the slow smoothing factor, alpha = 2 / (maj_length + 1).
    pub maj_length: Option<usize>,
    // Period of the fast smoothing factor, alpha = 2 / (min_length + 1).
    pub min_length: Option<usize>,
}

impl Default for SamaParams {
    /// length = 200, maj_length = 14, min_length = 6.
    fn default() -> Self {
        Self {
            length: Some(200),
            maj_length: Some(14),
            min_length: Some(6),
        }
    }
}

/// A data source paired with the parameters to run SAMA with.
#[derive(Debug, Clone)]
pub struct SamaInput<'a> {
    pub data: SamaData<'a>,
    pub params: SamaParams,
}
95
96impl<'a> SamaInput<'a> {
97    #[inline]
98    pub fn from_candles(c: &'a Candles, s: &'a str, p: SamaParams) -> Self {
99        Self {
100            data: SamaData::Candles {
101                candles: c,
102                source: s,
103            },
104            params: p,
105        }
106    }
107
108    #[inline]
109    pub fn from_slice(sl: &'a [f64], p: SamaParams) -> Self {
110        Self {
111            data: SamaData::Slice(sl),
112            params: p,
113        }
114    }
115
116    #[inline]
117    pub fn with_default_candles(c: &'a Candles) -> Self {
118        Self::from_candles(c, "close", SamaParams::default())
119    }
120
121    #[inline]
122    pub fn get_length(&self) -> usize {
123        self.params.length.unwrap_or(200)
124    }
125
126    #[inline]
127    pub fn get_maj_length(&self) -> usize {
128        self.params.maj_length.unwrap_or(14)
129    }
130
131    #[inline]
132    pub fn get_min_length(&self) -> usize {
133        self.params.min_length.unwrap_or(6)
134    }
135}
136
137#[derive(Clone, Debug)]
138pub struct SamaBuilder {
139    length: Option<usize>,
140    maj_length: Option<usize>,
141    min_length: Option<usize>,
142    kernel: Kernel,
143}
144
145impl Default for SamaBuilder {
146    fn default() -> Self {
147        Self {
148            length: None,
149            maj_length: None,
150            min_length: None,
151            kernel: Kernel::Auto,
152        }
153    }
154}
155
156impl SamaBuilder {
157    #[inline(always)]
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    #[inline(always)]
163    pub fn length(mut self, val: usize) -> Self {
164        self.length = Some(val);
165        self
166    }
167
168    #[inline(always)]
169    pub fn maj_length(mut self, val: usize) -> Self {
170        self.maj_length = Some(val);
171        self
172    }
173
174    #[inline(always)]
175    pub fn min_length(mut self, val: usize) -> Self {
176        self.min_length = Some(val);
177        self
178    }
179
180    #[inline(always)]
181    pub fn kernel(mut self, k: Kernel) -> Self {
182        self.kernel = k;
183        self
184    }
185
186    #[inline(always)]
187    pub fn apply(self, c: &Candles) -> Result<SamaOutput, SamaError> {
188        let p = SamaParams {
189            length: self.length,
190            maj_length: self.maj_length,
191            min_length: self.min_length,
192        };
193        let i = SamaInput::from_candles(c, "close", p);
194        sama_with_kernel(&i, self.kernel)
195    }
196
197    #[inline(always)]
198    pub fn apply_slice(self, d: &[f64]) -> Result<SamaOutput, SamaError> {
199        let p = SamaParams {
200            length: self.length,
201            maj_length: self.maj_length,
202            min_length: self.min_length,
203        };
204        let i = SamaInput::from_slice(d, p);
205        sama_with_kernel(&i, self.kernel)
206    }
207
208    #[inline(always)]
209    pub fn into_stream(self) -> Result<SamaStream, SamaError> {
210        let p = SamaParams {
211            length: self.length,
212            maj_length: self.maj_length,
213            min_length: self.min_length,
214        };
215        SamaStream::try_new(p)
216    }
217}
218
/// Errors reported by the SAMA single-run, batch and streaming entry points.
#[derive(Debug, Error)]
pub enum SamaError {
    /// The input series has no elements.
    #[error("sama: Input data slice is empty.")]
    EmptyInputData,

    /// Every input element is NaN, so there is no first valid sample.
    #[error("sama: All values are NaN.")]
    AllValuesNaN,

    /// A length parameter is unusable. `period` carries the offending `length`,
    /// or 0 when `maj_length` / `min_length` is zero.
    #[error("sama: Invalid period: period = {period}, data length = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },

    /// Fewer valid (post-NaN-prefix) samples than required.
    #[error("sama: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },

    /// A caller-supplied output buffer does not match the input length.
    #[error("sama: Output length mismatch: expected = {expected}, got = {got}")]
    OutputLengthMismatch { expected: usize, got: usize },

    /// A sweep range could not be expanded. The batch path also reuses this
    /// variant, as `{start: rows, end: cols}`, when `rows * cols` overflows.
    #[error("sama: Invalid range expansion: start = {start}, end = {end}, step = {step}")]
    InvalidRange {
        start: usize,
        end: usize,
        step: usize,
    },

    /// A non-batch kernel was passed to the batch entry point.
    #[error("sama: Invalid kernel for batch path: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
246
247#[inline(always)]
248pub fn sama(input: &SamaInput) -> Result<SamaOutput, SamaError> {
249    sama_with_kernel(input, Kernel::Auto)
250}
251
252#[inline(always)]
253pub fn sama_with_kernel(input: &SamaInput, kernel: Kernel) -> Result<SamaOutput, SamaError> {
254    let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, kernel)?;
255
256    let mut out = alloc_with_nan_prefix(data.len(), first);
257    sama_compute_into(
258        data, length, maj_length, min_length, first, chosen, &mut out,
259    );
260    Ok(SamaOutput { values: out })
261}
262
263#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
264pub fn sama_into(input: &SamaInput, out: &mut [f64]) -> Result<(), SamaError> {
265    let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, Kernel::Auto)?;
266
267    if out.len() != data.len() {
268        return Err(SamaError::OutputLengthMismatch {
269            expected: data.len(),
270            got: out.len(),
271        });
272    }
273
274    let warm = first.min(out.len());
275    for i in 0..warm {
276        out[i] = f64::from_bits(0x7ff8_0000_0000_0000);
277    }
278
279    sama_compute_into(data, length, maj_length, min_length, first, chosen, out);
280
281    Ok(())
282}
283
284#[inline(always)]
285pub fn sama_into_slice(dst: &mut [f64], input: &SamaInput, kern: Kernel) -> Result<(), SamaError> {
286    let (data, length, maj_length, min_length, first, chosen) = sama_prepare(input, kern)?;
287
288    if dst.len() != data.len() {
289        return Err(SamaError::OutputLengthMismatch {
290            expected: data.len(),
291            got: dst.len(),
292        });
293    }
294
295    sama_compute_into(data, length, maj_length, min_length, first, chosen, dst);
296
297    for v in &mut dst[..first] {
298        *v = f64::NAN;
299    }
300    Ok(())
301}
302
303#[inline(always)]
304fn sama_prepare<'a>(
305    input: &'a SamaInput,
306    kernel: Kernel,
307) -> Result<(&'a [f64], usize, usize, usize, usize, Kernel), SamaError> {
308    let data: &[f64] = input.as_ref();
309    let n = data.len();
310
311    if n == 0 {
312        return Err(SamaError::EmptyInputData);
313    }
314
315    let first = data
316        .iter()
317        .position(|x| !x.is_nan())
318        .ok_or(SamaError::AllValuesNaN)?;
319
320    let length = input.get_length();
321    let maj_length = input.get_maj_length();
322    let min_length = input.get_min_length();
323
324    if length + 1 > n || length == 0 {
325        return Err(SamaError::InvalidPeriod {
326            period: length,
327            data_len: n,
328        });
329    }
330
331    if maj_length == 0 || min_length == 0 {
332        return Err(SamaError::InvalidPeriod {
333            period: 0,
334            data_len: n,
335        });
336    }
337
338    let valid = n - first;
339
340    if valid < 1 {
341        return Err(SamaError::NotEnoughValidData { needed: 1, valid });
342    }
343
344    let chosen = match kernel {
345        Kernel::Auto => detect_best_kernel(),
346        k => k,
347    };
348
349    Ok((data, length, maj_length, min_length, first, chosen))
350}
351
352fn sama_compute_into(
353    data: &[f64],
354    length: usize,
355    maj_length: usize,
356    min_length: usize,
357    first: usize,
358    kernel: Kernel,
359    out: &mut [f64],
360) {
361    unsafe {
362        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
363        {
364            if matches!(kernel, Kernel::Scalar | Kernel::ScalarBatch) {
365                sama_simd128(data, length, maj_length, min_length, first, out);
366                return;
367            }
368        }
369
370        match kernel {
371            Kernel::Scalar | Kernel::ScalarBatch => {
372                sama_scalar(data, length, maj_length, min_length, first, out)
373            }
374            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
375            Kernel::Avx2 | Kernel::Avx2Batch => {
376                sama_avx2(data, length, maj_length, min_length, first, out)
377            }
378            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
379            Kernel::Avx512 | Kernel::Avx512Batch => {
380                sama_avx512(data, length, maj_length, min_length, first, out)
381            }
382            #[cfg(not(all(feature = "nightly-avx", target_arch = "x86_64")))]
383            Kernel::Avx2 | Kernel::Avx2Batch | Kernel::Avx512 | Kernel::Avx512Batch => {
384                sama_scalar(data, length, maj_length, min_length, first, out)
385            }
386            _ => unreachable!(),
387        }
388    }
389}
390
/// Scalar SAMA (Slope Adaptive Moving Average) kernel.
///
/// For each non-NaN sample `i >= first` it:
/// * tracks the highest/lowest price over the inclusive window
///   `[i - length, i]`, using two monotonic index deques stored in fixed-size
///   rings of `length + 1` slots;
/// * maps the price's distance from the window midpoint to
///   `mult = |2p - (hh + ll)| / (hh - ll)`, taken as 0 for a flat window;
/// * interpolates `a = slow + mult * (fast - slow)`, with alphas
///   `2 / (period + 1)`, and smooths with `alpha = a^2`.
///
/// The first valid sample seeds the average. NaN samples produce NaN and
/// leave the state untouched. Slots before `first` are not written.
#[inline]
pub fn sama_scalar(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    let len = data.len();
    if len == 0 {
        return;
    }

    let slow = 2.0 / (maj_length as f64 + 1.0);
    let fast = 2.0 / (min_length as f64 + 1.0);
    let spread = fast - slow;

    // Ring storage for the two deques; `wrap` folds a position back into the ring.
    // Positions never exceed `2 * ring - 1`, so one subtraction is enough.
    let ring = length + 1;
    let wrap = |k: usize| if k >= ring { k - ring } else { k };

    let mut hi_ring = vec![0usize; ring];
    let mut lo_ring = vec![0usize; ring];
    let (mut hi_front, mut hi_count) = (0usize, 0usize);
    let (mut lo_front, mut lo_count) = (0usize, 0usize);

    let mut ma = f64::NAN;

    for i in first..len {
        let price = data[i];
        if price.is_nan() {
            out[i] = f64::NAN;
            continue;
        }

        let oldest = i.saturating_sub(length);

        // Evict indices that have slid out of the window.
        while hi_count > 0 && hi_ring[hi_front] < oldest {
            hi_front = wrap(hi_front + 1);
            hi_count -= 1;
        }
        while lo_count > 0 && lo_ring[lo_front] < oldest {
            lo_front = wrap(lo_front + 1);
            lo_count -= 1;
        }

        // Drop tail entries dominated by the new price, then append `i`.
        while hi_count > 0 && data[hi_ring[wrap(hi_front + hi_count - 1)]] <= price {
            hi_count -= 1;
        }
        hi_ring[wrap(hi_front + hi_count)] = i;
        hi_count += 1;

        while lo_count > 0 && data[lo_ring[wrap(lo_front + lo_count - 1)]] >= price {
            lo_count -= 1;
        }
        lo_ring[wrap(lo_front + lo_count)] = i;
        lo_count += 1;

        let highest = data[hi_ring[hi_front]];
        let lowest = data[lo_ring[lo_front]];

        let range = highest - lowest;
        let distance = (price + price) - (highest + lowest);
        let mult = if range > 0.0 { distance.abs() / range } else { 0.0 };

        let adaptive = mult.mul_add(spread, slow);
        let alpha = adaptive * adaptive;

        ma = if ma.is_nan() {
            price
        } else {
            (price - ma).mul_add(alpha, ma)
        };

        out[i] = ma;
    }
}
512
/// AVX2 entry point for SAMA.
///
/// Currently delegates to `sama_scalar`: there is no AVX2-specific
/// implementation, so results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sama_avx2(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    sama_scalar(data, length, maj_length, min_length, first, out);
}

/// AVX-512 entry point for SAMA.
///
/// Currently delegates to `sama_scalar`: there is no AVX-512-specific
/// implementation, so results are identical to the scalar kernel.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn sama_avx512(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    sama_scalar(data, length, maj_length, min_length, first, out);
}

/// wasm32 SIMD128 entry point for SAMA.
///
/// Currently delegates to `sama_scalar`: there is no SIMD128-specific
/// implementation, so results are identical to the scalar kernel.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[inline]
pub fn sama_simd128(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
    first: usize,
    out: &mut [f64],
) {
    sama_scalar(data, length, maj_length, min_length, first, out);
}
551
/// Parameter sweep for batch SAMA.
///
/// Each axis is an inclusive `(start, end, step)` triple; a zero step (or
/// `start == end`) pins the axis to `start`.
#[derive(Debug, Clone)]
pub struct SamaBatchRange {
    pub length: (usize, usize, usize),
    pub maj_length: (usize, usize, usize),
    pub min_length: (usize, usize, usize),
}

/// Output of a batch SAMA run.
#[derive(Debug, Clone)]
pub struct SamaBatchOutput {
    // Row-major `rows x cols` matrix; row `r` holds the series for `combos[r]`.
    pub values: Vec<f64>,
    // Parameter set of each row, in row order.
    pub combos: Vec<SamaParams>,
    // Number of parameter combinations (== combos.len()).
    pub rows: usize,
    // Number of input samples per row.
    pub cols: usize,
}
566
567impl SamaBatchOutput {
568    #[inline]
569    pub fn row_for_params(&self, p: &SamaParams) -> Option<usize> {
570        self.combos.iter().position(|c| {
571            c.length.unwrap_or(200) == p.length.unwrap_or(200)
572                && c.maj_length.unwrap_or(14) == p.maj_length.unwrap_or(14)
573                && c.min_length.unwrap_or(6) == p.min_length.unwrap_or(6)
574        })
575    }
576
577    #[inline]
578    pub fn values_for(&self, p: &SamaParams) -> Option<&[f64]> {
579        self.row_for_params(p).map(|row| {
580            let start = row * self.cols;
581            &self.values[start..start + self.cols]
582        })
583    }
584}
585
586#[derive(Clone, Debug, Default)]
587pub struct SamaBatchBuilder {
588    range: SamaBatchRange,
589    kernel: Kernel,
590}
591
592impl SamaBatchBuilder {
593    pub fn new() -> Self {
594        Self::default()
595    }
596
597    pub fn kernel(mut self, k: Kernel) -> Self {
598        self.kernel = k;
599        self
600    }
601
602    #[inline]
603    pub fn length_range(mut self, start: usize, end: usize, step: usize) -> Self {
604        self.range.length = (start, end, step);
605        self
606    }
607
608    #[inline]
609    pub fn maj_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
610        self.range.maj_length = (start, end, step);
611        self
612    }
613
614    #[inline]
615    pub fn min_length_range(mut self, start: usize, end: usize, step: usize) -> Self {
616        self.range.min_length = (start, end, step);
617        self
618    }
619
620    #[inline]
621    pub fn length_static(mut self, v: usize) -> Self {
622        self.range.length = (v, v, 0);
623        self
624    }
625
626    #[inline]
627    pub fn maj_length_static(mut self, v: usize) -> Self {
628        self.range.maj_length = (v, v, 0);
629        self
630    }
631
632    #[inline]
633    pub fn min_length_static(mut self, v: usize) -> Self {
634        self.range.min_length = (v, v, 0);
635        self
636    }
637
638    pub fn apply_slice(self, data: &[f64]) -> Result<SamaBatchOutput, SamaError> {
639        sama_batch_with_kernel(data, &self.range, self.kernel)
640    }
641
642    pub fn apply_candles(self, c: &Candles, src: &str) -> Result<SamaBatchOutput, SamaError> {
643        self.apply_slice(source_type(c, src))
644    }
645
646    pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<SamaBatchOutput, SamaError> {
647        SamaBatchBuilder::new().kernel(k).apply_slice(data)
648    }
649
650    pub fn with_default_candles(c: &Candles) -> Result<SamaBatchOutput, SamaError> {
651        SamaBatchBuilder::new()
652            .kernel(Kernel::Auto)
653            .length_static(200)
654            .maj_length_static(14)
655            .min_length_static(6)
656            .apply_candles(c, "close")
657    }
658}
659
660impl Default for SamaBatchRange {
661    fn default() -> Self {
662        Self {
663            length: (200, 449, 1),
664            maj_length: (14, 14, 0),
665            min_length: (6, 6, 0),
666        }
667    }
668}
669
670#[inline(always)]
671fn axis_usize((start, end, step): (usize, usize, usize)) -> Result<Vec<usize>, SamaError> {
672    if start == end || step == 0 {
673        return Ok(vec![start]);
674    }
675    if start < end {
676        if step == 0 {
677            return Ok(vec![start]);
678        }
679        let mut v = Vec::new();
680        let mut x = start;
681        while x <= end {
682            v.push(x);
683            match x.checked_add(step) {
684                Some(nx) => x = nx,
685                None => break,
686            }
687        }
688        if v.is_empty() {
689            return Err(SamaError::InvalidRange { start, end, step });
690        }
691        Ok(v)
692    } else {
693        if step == 0 {
694            return Ok(vec![start]);
695        }
696        let mut v = Vec::new();
697        let mut x = start;
698        loop {
699            v.push(x);
700            if x <= end {
701                break;
702            }
703            x = x.saturating_sub(step);
704            if x == 0 && end > 0 && x < end {
705                break;
706            }
707            if v.len() > 1 && *v.last().unwrap() == x {
708                break;
709            }
710            if x == 0 && end == 0 {
711                break;
712            }
713            if x < end {
714                break;
715            }
716        }
717        if v.is_empty() {
718            return Err(SamaError::InvalidRange { start, end, step });
719        }
720        Ok(v)
721    }
722}
723
724#[inline(always)]
725fn expand_grid_sama(r: &SamaBatchRange) -> Result<Vec<SamaParams>, SamaError> {
726    let lens = axis_usize(r.length)?;
727    let maj = axis_usize(r.maj_length)?;
728    let min = axis_usize(r.min_length)?;
729    if lens.is_empty() || maj.is_empty() || min.is_empty() {
730        return Err(SamaError::InvalidRange {
731            start: 0,
732            end: 0,
733            step: 0,
734        });
735    }
736    let mut out = Vec::with_capacity(
737        lens.len()
738            .saturating_mul(maj.len())
739            .saturating_mul(min.len()),
740    );
741    for &l in &lens {
742        for &j in &maj {
743            for &m in &min {
744                out.push(SamaParams {
745                    length: Some(l),
746                    maj_length: Some(j),
747                    min_length: Some(m),
748                });
749            }
750        }
751    }
752    Ok(out)
753}
754
755pub fn sama_batch_with_kernel(
756    data: &[f64],
757    sweep: &SamaBatchRange,
758    k: Kernel,
759) -> Result<SamaBatchOutput, SamaError> {
760    let kernel = match k {
761        Kernel::Auto => detect_best_batch_kernel(),
762        other if other.is_batch() => other,
763        other => return Err(SamaError::InvalidKernelForBatch(other)),
764    };
765
766    let simd = match kernel {
767        Kernel::Avx512Batch => Kernel::Avx512,
768        Kernel::Avx2Batch => Kernel::Avx2,
769        Kernel::ScalarBatch => Kernel::Scalar,
770        _ => unreachable!(),
771    };
772
773    sama_batch_par_slice(data, sweep, simd)
774}
775
776#[inline(always)]
777pub fn sama_batch_par_slice(
778    data: &[f64],
779    sweep: &SamaBatchRange,
780    kern: Kernel,
781) -> Result<SamaBatchOutput, SamaError> {
782    sama_batch_inner(data, sweep, kern, true)
783}
784
785#[inline(always)]
786pub fn sama_batch_slice(
787    data: &[f64],
788    sweep: &SamaBatchRange,
789    kern: Kernel,
790) -> Result<SamaBatchOutput, SamaError> {
791    sama_batch_inner(data, sweep, kern, false)
792}
793
794#[inline(always)]
795fn sama_batch_inner(
796    data: &[f64],
797    sweep: &SamaBatchRange,
798    kern: Kernel,
799    parallel: bool,
800) -> Result<SamaBatchOutput, SamaError> {
801    let combos = expand_grid_sama(sweep)?;
802    let cols = data.len();
803    let rows = combos.len();
804    if cols == 0 {
805        return Err(SamaError::EmptyInputData);
806    }
807
808    let total = rows.checked_mul(cols).ok_or(SamaError::InvalidRange {
809        start: rows,
810        end: cols,
811        step: 0,
812    })?;
813
814    let mut buf_mu = make_uninit_matrix(rows, cols);
815
816    let first = data
817        .iter()
818        .position(|x| !x.is_nan())
819        .ok_or(SamaError::AllValuesNaN)?;
820    let warm: Vec<usize> = combos.iter().map(|_| first).collect();
821    init_matrix_prefixes(&mut buf_mu, cols, &warm);
822
823    let mut guard = core::mem::ManuallyDrop::new(buf_mu);
824    let out: &mut [f64] =
825        unsafe { core::slice::from_raw_parts_mut(guard.as_mut_ptr() as *mut f64, guard.len()) };
826
827    sama_batch_inner_into(data, &combos, first, kern, parallel, out)?;
828
829    let values = unsafe {
830        Vec::from_raw_parts(
831            guard.as_mut_ptr() as *mut f64,
832            guard.len(),
833            guard.capacity(),
834        )
835    };
836
837    Ok(SamaBatchOutput {
838        values,
839        combos,
840        rows,
841        cols,
842    })
843}
844
845#[inline(always)]
846fn sama_batch_inner_into(
847    data: &[f64],
848    combos: &[SamaParams],
849    first: usize,
850    kern: Kernel,
851    parallel: bool,
852    out_flat: &mut [f64],
853) -> Result<(), SamaError> {
854    if combos.is_empty() {
855        return Ok(());
856    }
857
858    let rows = combos.len();
859    let cols = data.len();
860
861    if cols - first < 1 {
862        return Err(SamaError::NotEnoughValidData {
863            needed: 1,
864            valid: cols - first,
865        });
866    }
867
868    use std::collections::HashMap;
869    let mut uniq_lengths: Vec<usize> = Vec::new();
870    uniq_lengths.reserve(combos.len());
871    for prm in combos.iter() {
872        let l = prm.length.unwrap_or(200);
873        if !uniq_lengths.contains(&l) {
874            uniq_lengths.push(l);
875        }
876    }
877
878    fn build_rolling_extrema(data: &[f64], length: usize) -> (Vec<f64>, Vec<f64>) {
879        let n = data.len();
880        let cap = length + 1;
881        let mut max_idx = vec![0usize; cap];
882        let mut min_idx = vec![0usize; cap];
883        let mut max_head = 0usize;
884        let mut min_head = 0usize;
885        let mut max_len = 0usize;
886        let mut min_len = 0usize;
887        let mut hh = vec![f64::NAN; n];
888        let mut ll = vec![f64::NAN; n];
889
890        for i in 0..n {
891            let p = data[i];
892            let wstart = i.saturating_sub(length);
893
894            while max_len > 0 {
895                let idx = max_idx[max_head];
896                if idx >= wstart {
897                    break;
898                }
899                max_head += 1;
900                if max_head == cap {
901                    max_head = 0;
902                }
903                max_len -= 1;
904            }
905            while min_len > 0 {
906                let idx = min_idx[min_head];
907                if idx >= wstart {
908                    break;
909                }
910                min_head += 1;
911                if min_head == cap {
912                    min_head = 0;
913                }
914                min_len -= 1;
915            }
916
917            if !p.is_nan() {
918                while max_len > 0 {
919                    let last_pos = (max_head + max_len - 1) % cap;
920                    let last_idx = max_idx[last_pos];
921                    if data[last_idx] <= p {
922                        max_len -= 1;
923                    } else {
924                        break;
925                    }
926                }
927                let ins_pos_max = (max_head + max_len) % cap;
928                max_idx[ins_pos_max] = i;
929                max_len += 1;
930
931                while min_len > 0 {
932                    let last_pos = (min_head + min_len - 1) % cap;
933                    let last_idx = min_idx[last_pos];
934                    if data[last_idx] >= p {
935                        min_len -= 1;
936                    } else {
937                        break;
938                    }
939                }
940                let ins_pos_min = (min_head + min_len) % cap;
941                min_idx[ins_pos_min] = i;
942                min_len += 1;
943            }
944
945            if max_len > 0 && min_len > 0 {
946                hh[i] = data[max_idx[max_head]];
947                ll[i] = data[min_idx[min_head]];
948            }
949        }
950        (hh, ll)
951    }
952
953    let mut hh_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(uniq_lengths.len());
954    let mut ll_map: HashMap<usize, Vec<f64>> = HashMap::with_capacity(uniq_lengths.len());
955    for &l in &uniq_lengths {
956        let (hh, ll) = build_rolling_extrema(data, l);
957        hh_map.insert(l, hh);
958        ll_map.insert(l, ll);
959    }
960
961    let do_row = |row: usize, row_dst: &mut [f64]| {
962        let prm = &combos[row];
963        let length = prm.length.unwrap_or(200);
964        let maj_length = prm.maj_length.unwrap_or(14);
965        let min_length = prm.min_length.unwrap_or(6);
966
967        let maj_alpha = 2.0 / (maj_length as f64 + 1.0);
968        let min_alpha = 2.0 / (min_length as f64 + 1.0);
969        let delta = min_alpha - maj_alpha;
970
971        let hh = &hh_map[&length];
972        let ll = &ll_map[&length];
973
974        let mut sama_val = f64::NAN;
975        for i in first..cols {
976            let p = data[i];
977            if p.is_nan() {
978                row_dst[i] = f64::NAN;
979                continue;
980            }
981
982            let h = hh[i];
983            let l = ll[i];
984            let denom = h - l;
985            let c = (p + p) - (h + l);
986            let mult = if denom > 0.0 { c.abs() / denom } else { 0.0 };
987            let a = mult.mul_add(delta, maj_alpha);
988            let alpha = a * a;
989
990            if sama_val.is_nan() {
991                sama_val = p;
992            } else {
993                let diff = p - sama_val;
994                sama_val = diff.mul_add(alpha, sama_val);
995            }
996            row_dst[i] = sama_val;
997        }
998    };
999
1000    if parallel {
1001        #[cfg(not(target_arch = "wasm32"))]
1002        {
1003            use rayon::prelude::*;
1004            out_flat
1005                .par_chunks_mut(cols)
1006                .enumerate()
1007                .for_each(|(row, dst)| do_row(row, dst));
1008        }
1009        #[cfg(target_arch = "wasm32")]
1010        {
1011            for (row, dst) in out_flat.chunks_mut(cols).enumerate() {
1012                do_row(row, dst);
1013            }
1014        }
1015    } else {
1016        for (row, dst) in out_flat.chunks_mut(cols).enumerate() {
1017            do_row(row, dst);
1018        }
1019    }
1020
1021    Ok(())
1022}
1023
/// Incremental SAMA calculator that consumes one sample per `update` call.
///
/// Construct with `SamaStream::try_new` (or `SamaBuilder::into_stream`).
#[derive(Debug, Clone)]
pub struct SamaStream {
    // Rolling high/low lookback; `update` keeps sample indices in `[t - length, t]`.
    length: usize,
    maj_length: usize,
    min_length: usize,

    // 2 / (maj_length + 1), the slow smoothing factor.
    maj_alpha: f64,
    // 2 / (min_length + 1), the fast smoothing factor.
    min_alpha: f64,
    // min_alpha - maj_alpha, cached for the adaptive-alpha interpolation.
    delta: f64,

    // Ring size, `length + 1`; shared by `buf`, `max_idx` and `min_idx`.
    cap: usize,
    // Ring of recent raw inputs (NaN-initialised in `try_new`), written at `head`.
    buf: Vec<f64>,
    // Next write slot in `buf`; wraps to 0 at `cap`.
    head: usize,
    // Sample counter `t`; stored in the deques as each sample's index.
    tick: usize,

    // Ring-buffered monotonic deques of sample indices for the rolling max / min.
    // `*_head` is the ring position of the deque front, `*_len` the live entry count.
    max_idx: Vec<usize>,
    min_idx: Vec<usize>,
    max_head: usize,
    min_head: usize,
    max_len: usize,
    min_len: usize,

    // Current smoothed value (initialised to NaN in `try_new`).
    sama_val: f64,
}
1048
1049impl SamaStream {
1050    pub fn try_new(params: SamaParams) -> Result<Self, SamaError> {
1051        let length = params.length.unwrap_or(200);
1052        let maj = params.maj_length.unwrap_or(14);
1053        let min = params.min_length.unwrap_or(6);
1054        if length == 0 || maj == 0 || min == 0 {
1055            return Err(SamaError::InvalidPeriod {
1056                period: 0,
1057                data_len: 0,
1058            });
1059        }
1060
1061        let cap = length + 1;
1062        let maj_alpha = 2.0 / (maj as f64 + 1.0);
1063        let min_alpha = 2.0 / (min as f64 + 1.0);
1064
1065        Ok(Self {
1066            length,
1067            maj_length: maj,
1068            min_length: min,
1069
1070            maj_alpha,
1071            min_alpha,
1072            delta: min_alpha - maj_alpha,
1073
1074            cap,
1075            buf: vec![f64::NAN; cap],
1076            head: 0,
1077            tick: 0,
1078
1079            max_idx: vec![0usize; cap],
1080            min_idx: vec![0usize; cap],
1081            max_head: 0,
1082            min_head: 0,
1083            max_len: 0,
1084            min_len: 0,
1085
1086            sama_val: f64::NAN,
1087        })
1088    }
1089
1090    #[inline]
1091    pub fn update(&mut self, value: f64) -> Option<f64> {
1092        let t = self.tick;
1093        let cap = self.cap;
1094
1095        let oldest = t.saturating_sub(self.length);
1096
1097        while self.max_len > 0 {
1098            let idx = self.max_idx[self.max_head];
1099            if idx >= oldest {
1100                break;
1101            }
1102            self.max_head += 1;
1103            if self.max_head == cap {
1104                self.max_head = 0;
1105            }
1106            self.max_len -= 1;
1107        }
1108        while self.min_len > 0 {
1109            let idx = self.min_idx[self.min_head];
1110            if idx >= oldest {
1111                break;
1112            }
1113            self.min_head += 1;
1114            if self.min_head == cap {
1115                self.min_head = 0;
1116            }
1117            self.min_len -= 1;
1118        }
1119
1120        self.buf[self.head] = value;
1121
1122        if value.is_nan() {
1123            self.head += 1;
1124            if self.head == cap {
1125                self.head = 0;
1126            }
1127            self.tick = t.wrapping_add(1);
1128            return Some(f64::NAN);
1129        }
1130
1131        while self.max_len > 0 {
1132            let back_pos = (self.max_head + self.max_len - 1) % cap;
1133            let back_idx = self.max_idx[back_pos];
1134            let back_val = self.buf[back_idx % cap];
1135            if back_val <= value {
1136                self.max_len -= 1;
1137            } else {
1138                break;
1139            }
1140        }
1141        let ins_pos_max = (self.max_head + self.max_len) % cap;
1142        self.max_idx[ins_pos_max] = t;
1143        self.max_len += 1;
1144
1145        while self.min_len > 0 {
1146            let back_pos = (self.min_head + self.min_len - 1) % cap;
1147            let back_idx = self.min_idx[back_pos];
1148            let back_val = self.buf[back_idx % cap];
1149            if back_val >= value {
1150                self.min_len -= 1;
1151            } else {
1152                break;
1153            }
1154        }
1155        let ins_pos_min = (self.min_head + self.min_len) % cap;
1156        self.min_idx[ins_pos_min] = t;
1157        self.min_len += 1;
1158
1159        let hh = self.buf[self.max_idx[self.max_head] % cap];
1160        let ll = self.buf[self.min_idx[self.min_head] % cap];
1161
1162        let denom = hh - ll;
1163        let c = (value + value) - (hh + ll);
1164        let mult = if denom > 0.0 { c.abs() / denom } else { 0.0 };
1165
1166        let a = mult.mul_add(self.delta, self.maj_alpha);
1167        let alpha = a * a;
1168
1169        if self.sama_val.is_nan() {
1170            self.sama_val = value;
1171        } else {
1172            let diff = value - self.sama_val;
1173            self.sama_val = diff.mul_add(alpha, self.sama_val);
1174        }
1175
1176        self.head += 1;
1177        if self.head == cap {
1178            self.head = 0;
1179        }
1180        self.tick = t.wrapping_add(1);
1181
1182        Some(self.sama_val)
1183    }
1184
1185    #[inline]
1186    pub fn next(&mut self, value: f64) -> f64 {
1187        self.update(value).unwrap_or(f64::NAN)
1188    }
1189
1190    #[inline]
1191    pub fn reset(&mut self) {
1192        self.buf.fill(f64::NAN);
1193        self.head = 0;
1194        self.tick = 0;
1195        self.max_len = 0;
1196        self.min_len = 0;
1197        self.max_head = 0;
1198        self.min_head = 0;
1199        self.sama_val = f64::NAN;
1200    }
1201}
1202
/// Computes SAMA over a 1-D float64 array.
///
/// `kernel` optionally names the compute kernel; `None` lets the library choose.
/// Raises `ValueError` when the kernel name is invalid or the computation fails.
#[cfg(feature = "python")]
#[pyfunction(name = "sama")]
#[pyo3(signature = (data, length, maj_length, min_length, kernel=None))]
pub fn sama_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length: usize,
    maj_length: usize,
    min_length: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = SamaInput::from_slice(
        prices,
        SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        },
    );

    // The heavy lifting runs with the GIL released.
    let computed = py.allow_threads(|| sama_with_kernel(&input, chosen).map(|out| out.values));
    match computed {
        Ok(values) => Ok(values.into_pyarray(py)),
        Err(e) => Err(PyValueError::new_err(e.to_string())),
    }
}
1227
/// Computes SAMA for every parameter combination of a sweep.
///
/// Each range is a `(start, end, step)` triple expanded by `expand_grid_sama`.
/// Returns a dict with:
/// * `values` – 2-D float64 array of shape `(rows, len(data))`, one row per combination;
/// * `lengths`, `maj_lengths`, `min_lengths` – the parameters of each row.
///
/// Raises `ValueError` for an invalid sweep, an unknown kernel name, an
/// output size that overflows `usize`, empty or all-NaN input, or any error
/// reported by the batch kernel.
#[cfg(feature = "python")]
#[pyfunction(name = "sama_batch")]
#[pyo3(signature = (data, length_range, maj_length_range, min_length_range, kernel=None))]
pub fn sama_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    length_range: (usize, usize, usize),
    maj_length_range: (usize, usize, usize),
    min_length_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    let slice_in = data.as_slice()?;
    let sweep = SamaBatchRange {
        length: length_range,
        maj_length: maj_length_range,
        min_length: min_length_range,
    };

    let combos = expand_grid_sama(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    // Validate the kernel name before allocating the (possibly large) output.
    let kern = validate_kernel(kernel, true)?;

    let rows = combos.len();
    let cols = slice_in.len();
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("sama_batch: rows * cols overflows usize"))?;

    // SAFETY: `PyArray1::new(.., false)` creates an uninitialised C-contiguous
    // array that nothing else references yet, so taking a mutable slice is sound.
    // NOTE(review): correctness relies on `sama_batch_inner_into` writing every
    // element (warmup prefix included) before the array reaches Python — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            // Report empty input as such, matching the single-series API,
            // instead of letting it fall through to `AllValuesNaN`.
            if slice_in.is_empty() {
                return Err(SamaError::EmptyInputData);
            }

            let mapped = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // The batch driver expects the per-row (non-batch) kernel variant.
            let simd = match mapped {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => Kernel::Scalar,
            };

            let first = slice_in
                .iter()
                .position(|x| !x.is_nan())
                .ok_or(SamaError::AllValuesNaN)?;
            sama_batch_inner_into(slice_in, &combos, first, simd, true, slice_out).map(|_| combos)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "lengths",
        combos
            .iter()
            .map(|p| p.length.unwrap_or(200) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "maj_lengths",
        combos
            .iter()
            .map(|p| p.maj_length.unwrap_or(14) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    dict.set_item(
        "min_lengths",
        combos
            .iter()
            .map(|p| p.min_length.unwrap_or(6) as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;
    Ok(dict.into())
}
1307
/// Runs a SAMA parameter sweep with the CUDA backend on device `device_id`.
///
/// Ranges are `(start, end, step)` triples; the defaults describe the single
/// combination `(200, 14, 6)`. The result is returned as a
/// `DeviceArrayF32SamaPy` (presumably kept on the GPU — see
/// `CudaSama::sama_batch_dev`). Raises `ValueError` when CUDA is unavailable
/// or the backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sama_cuda_batch_dev")]
#[pyo3(signature = (data_f32, length_range=(200, 200, 0), maj_length_range=(14, 14, 0), min_length_range=(6, 6, 0), device_id=0))]
pub fn sama_cuda_batch_dev_py(
    py: Python<'_>,
    data_f32: PyReadonlyArray1<'_, f32>,
    length_range: (usize, usize, usize),
    maj_length_range: (usize, usize, usize),
    min_length_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<DeviceArrayF32SamaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host = data_f32.as_slice()?;
    let sweep = SamaBatchRange {
        length: length_range,
        maj_length: maj_length_range,
        min_length: min_length_range,
    };

    // Device setup and the kernel launch both happen with the GIL released.
    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_out = cuda
            .sama_batch_dev(host, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(device_out)
    })?;

    Ok(DeviceArrayF32SamaPy { inner })
}
1338
/// Runs SAMA with one parameter set over many series using the CUDA backend.
///
/// `data_tm_f32` is read as a 2-D array of shape `(series_len, num_series)`
/// (time-major, as the backend call's name indicates) and must be contiguous,
/// since `as_slice` errors otherwise. Raises `ValueError` when CUDA is
/// unavailable, any length is zero, or the backend reports an error.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "sama_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, length, maj_length, min_length, device_id=0))]
pub fn sama_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: PyReadonlyArray2<'_, f32>,
    length: usize,
    maj_length: usize,
    min_length: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32SamaPy> {
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    if [length, maj_length, min_length].contains(&0) {
        return Err(PyValueError::new_err(
            "length, maj_length, and min_length must be positive",
        ));
    }

    let flat = data_tm_f32.as_slice()?;
    let (series_len, num_series) = {
        let dims = data_tm_f32.shape();
        (dims[0], dims[1])
    };
    let params = SamaParams {
        length: Some(length),
        maj_length: Some(maj_length),
        min_length: Some(min_length),
    };

    let inner = py.allow_threads(|| -> PyResult<_> {
        let cuda = CudaSama::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_out = cuda
            .sama_many_series_one_param_time_major_dev(flat, num_series, series_len, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(device_out)
    })?;

    Ok(DeviceArrayF32SamaPy { inner })
}
1377
/// Python-facing wrapper around the Rust `SamaStream`, exported to Python
/// under the class name `SamaStream`.
#[cfg(feature = "python")]
#[pyclass(name = "SamaStream")]
pub struct SamaStreamPy {
    /// Underlying streaming state; every Python call delegates to it.
    stream: SamaStream,
}
1383
#[cfg(feature = "python")]
#[pymethods]
impl SamaStreamPy {
    /// Creates a streaming SAMA with explicit lengths.
    ///
    /// Raises `ValueError` if any length is zero.
    #[new]
    fn new(length: usize, maj_length: usize, min_length: usize) -> PyResult<Self> {
        SamaStream::try_new(SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        })
        .map(|stream| SamaStreamPy { stream })
        .map_err(|e| PyValueError::new_err(e.to_string()))
    }

    /// Feeds one value and returns the updated average (NaN for a NaN input).
    fn update(&mut self, value: f64) -> Option<f64> {
        let inner = &mut self.stream;
        inner.update(value)
    }
}
1403
/// Computes SAMA for `data` and returns a new array of the same length.
///
/// Errors from the computation are surfaced to JS as string values.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_js(
    data: &[f64],
    length: usize,
    maj_length: usize,
    min_length: usize,
) -> Result<Vec<f64>, JsValue> {
    let input = SamaInput::from_slice(
        data,
        SamaParams {
            length: Some(length),
            maj_length: Some(maj_length),
            min_length: Some(min_length),
        },
    );
    let mut values = vec![0.0; data.len()];
    match sama_into_slice(&mut values, &input, detect_best_kernel()) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1423
/// Parameter-sweep configuration deserialised from the JS object passed to
/// `sama_batch`; the field names below are the expected JS keys.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SamaBatchConfig {
    /// `(start, end, step)` for `length`.
    pub length_range: (usize, usize, usize),
    /// `(start, end, step)` for `maj_length`.
    pub maj_length_range: (usize, usize, usize),
    /// `(start, end, step)` for `min_length`.
    pub min_length_range: (usize, usize, usize),
}
1431
/// Result object serialised back to JS by `sama_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct SamaBatchJsOutput {
    /// Flattened results, presumably row-major `rows × cols` with one row per
    /// entry of `combos` — confirm against `sama_batch_inner`.
    pub values: Vec<f64>,
    /// Parameter combination used for each row.
    pub combos: Vec<SamaParams>,
    /// Number of parameter combinations.
    pub rows: usize,
    /// Number of values per row (the input length).
    pub cols: usize,
}
1440
/// Runs a SAMA parameter sweep from JS.
///
/// `config` must deserialise into `SamaBatchConfig` (keys `length_range`,
/// `maj_length_range`, `min_length_range`). Returns an object with
/// `values`, `combos`, `rows` and `cols`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = sama_batch)]
pub fn sama_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let SamaBatchConfig {
        length_range,
        maj_length_range,
        min_length_range,
    }: SamaBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = SamaBatchRange {
        length: length_range,
        maj_length: maj_length_range,
        min_length: min_length_range,
    };
    let batch = sama_batch_inner(data, &sweep, detect_best_kernel(), false)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    let payload = SamaBatchJsOutput {
        values: batch.values,
        combos: batch.combos,
        rows: batch.rows,
        cols: batch.cols,
    };
    serde_wasm_bindgen::to_value(&payload)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1461
/// Allocates an uninitialised buffer with room for `len` `f64` values and
/// transfers ownership of it to the caller.
///
/// Release it with `sama_free(ptr, len)`, passing the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1470
/// Frees a buffer previously returned by `sama_alloc(len)`.
///
/// `len` must be the value that was passed to `sama_alloc`. A null pointer is
/// ignored instead of being handed to `Vec::from_raw_parts`, which would be
/// undefined behaviour.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` came from `sama_alloc(len)`, i.e. a leaked `Vec<f64>` created
    // with `with_capacity(len)`; rebuilding it releases that allocation. Treating all
    // `len` slots as initialised is harmless because `f64` has no drop glue.
    // NOTE(review): assumes `with_capacity(len)` reserved exactly `len` slots.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1478
/// Computes SAMA for `len` values at `in_ptr`, writing `len` results to `out_ptr`.
///
/// Intended for JS callers managing buffers via `sama_alloc` / `sama_free`.
/// The input and output regions may be identical or overlap: in that case the
/// result is computed into a temporary buffer and copied out afterwards.
///
/// # Errors
///
/// Returns a JS string error for null pointers or any SAMA computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length: usize,
    maj_length: usize,
    min_length: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    // Any overlap (not only identical start pointers) must go through the temporary
    // buffer; otherwise a `&[f64]` and a `&mut [f64]` would cover the same memory.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span));

    let params = SamaParams {
        length: Some(length),
        maj_length: Some(maj_length),
        min_length: Some(min_length),
    };

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads and `out_ptr`
    // for `len` writes. The mutable view is only created alongside the input view
    // when the two regions are disjoint.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let input = SamaInput::from_slice(data, params);

        if overlaps {
            let mut tmp = vec![0.0; len];
            sama_into_slice(&mut tmp, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            out.copy_from_slice(&tmp);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            sama_into_slice(out, &input, detect_best_kernel())
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }
    }
    Ok(())
}
1517
/// Runs a SAMA parameter sweep over `len` inputs at `in_ptr`, writing one row of
/// `len` results per parameter combination to `out_ptr`.
///
/// `out_ptr` must be valid for `rows * len` writes, where `rows` is the number of
/// combinations produced by the three `(start, end, step)` ranges. Input and output
/// may overlap; in that case results are computed into a temporary buffer and
/// copied out afterwards.
///
/// Returns the number of rows written.
///
/// # Errors
///
/// Returns a JS string error for null pointers, an invalid sweep, an output size
/// that overflows `usize`, all-NaN input, or any batch computation error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn sama_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    length_start: usize,
    length_end: usize,
    length_step: usize,
    maj_start: usize,
    maj_end: usize,
    maj_step: usize,
    min_start: usize,
    min_end: usize,
    min_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer"));
    }

    let sweep = SamaBatchRange {
        length: (length_start, length_end, length_step),
        maj_length: (maj_start, maj_end, maj_step),
        min_length: (min_start, min_end, min_step),
    };
    let combos = expand_grid_sama(&sweep).map_err(|_| JsValue::from_str("Invalid range"))?;
    let rows = combos.len();
    let cols = len;
    // `usize` is 32-bit on wasm32, so an unchecked product could silently wrap and
    // produce an undersized output view.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("rows * cols overflows usize"))?;

    // Byte ranges of the input and output regions, used only to detect aliasing.
    let elem = std::mem::size_of::<f64>();
    let in_addr = in_ptr as usize;
    let out_addr = out_ptr as usize;
    let overlaps = in_addr < out_addr.saturating_add(total.saturating_mul(elem))
        && out_addr < in_addr.saturating_add(len.saturating_mul(elem));

    // SAFETY: the caller guarantees `in_ptr` is valid for `len` reads.
    let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
    let first = data
        .iter()
        .position(|x| !x.is_nan())
        .ok_or_else(|| JsValue::from_str("All NaN"))?;

    let kern = detect_best_kernel();
    if overlaps {
        let mut tmp = vec![0.0; total];
        sama_batch_inner_into(data, &combos, first, kern, false, &mut tmp)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        // SAFETY: `out_ptr` is valid for `total` writes; `data` is not used past this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&tmp);
    } else {
        // SAFETY: `out_ptr` is valid for `total` writes and does not overlap `data`.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        sama_batch_inner_into(data, &combos, first, kern, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}
1557
1558#[cfg(test)]
1559mod tests {
1560    use super::*;
1561    use crate::skip_if_unsupported;
1562    use crate::utilities::data_loader::read_candles_from_csv;
1563    use std::error::Error;
1564
1565    fn check_sama_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1566        skip_if_unsupported!(kernel, test_name);
1567        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1568        let candles = read_candles_from_csv(file_path)?;
1569
1570        let input = SamaInput::from_candles(&candles, "close", SamaParams::default());
1571        let result = sama_with_kernel(&input, kernel)?;
1572
1573        assert_eq!(result.values.len(), candles.close.len());
1574
1575        let valid_count = result.values.iter().filter(|&&v| !v.is_nan()).count();
1576        assert!(
1577            valid_count > 0,
1578            "[{}] SAMA should produce valid values",
1579            test_name
1580        );
1581
1582        Ok(())
1583    }
1584
1585    #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
1586    #[test]
1587    fn test_sama_into_matches_api() -> Result<(), Box<dyn Error>> {
1588        let mut data = Vec::with_capacity(256);
1589        for _ in 0..5 {
1590            data.push(f64::NAN);
1591        }
1592        for i in 0..251 {
1593            let x = (i as f64).sin() * 0.12345 + 50.0 + ((i % 11) as f64) * 0.001;
1594            data.push(x);
1595        }
1596
1597        let input = SamaInput::from_slice(&data, SamaParams::default());
1598        let baseline = sama(&input)?.values;
1599
1600        let mut out = vec![0.0; data.len()];
1601        sama_into(&input, &mut out)?;
1602
1603        assert_eq!(baseline.len(), out.len());
1604        fn eq_or_both_nan(a: f64, b: f64) -> bool {
1605            (a.is_nan() && b.is_nan()) || (a == b) || (a - b).abs() <= 1e-12
1606        }
1607        for (i, (&a, &b)) in baseline.iter().zip(out.iter()).enumerate() {
1608            assert!(
1609                eq_or_both_nan(a, b),
1610                "mismatch at index {}: api={} into={}",
1611                i,
1612                a,
1613                b
1614            );
1615        }
1616
1617        Ok(())
1618    }
1619
1620    fn check_sama_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1621        skip_if_unsupported!(kernel, test_name);
1622        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1623        let candles = read_candles_from_csv(file_path)?;
1624
1625        let default_params = SamaParams {
1626            length: None,
1627            maj_length: None,
1628            min_length: None,
1629        };
1630        let input = SamaInput::from_candles(&candles, "close", default_params);
1631        let output = sama_with_kernel(&input, kernel)?;
1632        assert_eq!(output.values.len(), candles.close.len());
1633
1634        Ok(())
1635    }
1636
1637    fn check_sama_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1638        skip_if_unsupported!(kernel, test_name);
1639        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1640        let candles = read_candles_from_csv(file_path)?;
1641
1642        let input = SamaInput::with_default_candles(&candles);
1643        match input.data {
1644            SamaData::Candles { source, .. } => assert_eq!(source, "close"),
1645            _ => panic!("Expected SamaData::Candles"),
1646        }
1647        let output = sama_with_kernel(&input, kernel)?;
1648        assert_eq!(output.values.len(), candles.close.len());
1649
1650        Ok(())
1651    }
1652
1653    fn check_sama_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1654        skip_if_unsupported!(kernel, test_name);
1655        let input_data = [10.0, 20.0, 30.0];
1656        let params = SamaParams {
1657            length: Some(0),
1658            maj_length: None,
1659            min_length: None,
1660        };
1661        let input = SamaInput::from_slice(&input_data, params);
1662        let res = sama_with_kernel(&input, kernel);
1663        assert!(
1664            res.is_err(),
1665            "[{}] SAMA should fail with zero period",
1666            test_name
1667        );
1668        Ok(())
1669    }
1670
1671    fn check_sama_period_exceeds_length(
1672        test_name: &str,
1673        kernel: Kernel,
1674    ) -> Result<(), Box<dyn Error>> {
1675        skip_if_unsupported!(kernel, test_name);
1676        let data_small = [10.0, 20.0, 30.0];
1677        let params = SamaParams {
1678            length: Some(10),
1679            maj_length: None,
1680            min_length: None,
1681        };
1682        let input = SamaInput::from_slice(&data_small, params);
1683        let res = sama_with_kernel(&input, kernel);
1684        assert!(
1685            res.is_err(),
1686            "[{}] SAMA should fail with period exceeding length",
1687            test_name
1688        );
1689        Ok(())
1690    }
1691
1692    fn check_sama_very_small_dataset(
1693        test_name: &str,
1694        kernel: Kernel,
1695    ) -> Result<(), Box<dyn Error>> {
1696        skip_if_unsupported!(kernel, test_name);
1697        let single_point = [42.0];
1698        let params = SamaParams::default();
1699        let input = SamaInput::from_slice(&single_point, params);
1700        let res = sama_with_kernel(&input, kernel);
1701        assert!(
1702            res.is_err(),
1703            "[{}] SAMA should fail with insufficient data",
1704            test_name
1705        );
1706        Ok(())
1707    }
1708
1709    fn check_sama_empty_input(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1710        skip_if_unsupported!(kernel, test_name);
1711        let empty: [f64; 0] = [];
1712        let params = SamaParams::default();
1713        let input = SamaInput::from_slice(&empty, params);
1714        let res = sama_with_kernel(&input, kernel);
1715        assert!(
1716            matches!(res, Err(SamaError::EmptyInputData)),
1717            "[{}] SAMA should fail with empty input",
1718            test_name
1719        );
1720        Ok(())
1721    }
1722
1723    fn check_sama_all_nan(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1724        skip_if_unsupported!(kernel, test_name);
1725        let nan_data = [f64::NAN, f64::NAN, f64::NAN];
1726        let params = SamaParams::default();
1727        let input = SamaInput::from_slice(&nan_data, params);
1728        let res = sama_with_kernel(&input, kernel);
1729        assert!(
1730            matches!(res, Err(SamaError::AllValuesNaN)),
1731            "[{}] SAMA should fail with all NaN values",
1732            test_name
1733        );
1734        Ok(())
1735    }
1736
1737    fn check_sama_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1738        skip_if_unsupported!(kernel, test_name);
1739        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1740        let candles = read_candles_from_csv(file_path)?;
1741
1742        let first_params = SamaParams {
1743            length: Some(50),
1744            maj_length: Some(14),
1745            min_length: Some(6),
1746        };
1747        let first_input = SamaInput::from_candles(&candles, "close", first_params);
1748        let first_result = sama_with_kernel(&first_input, kernel)?;
1749
1750        let second_params = SamaParams {
1751            length: Some(50),
1752            maj_length: Some(14),
1753            min_length: Some(6),
1754        };
1755        let second_input = SamaInput::from_slice(&first_result.values, second_params);
1756        let second_result = sama_with_kernel(&second_input, kernel)?;
1757
1758        assert_eq!(second_result.values.len(), first_result.values.len());
1759        Ok(())
1760    }
1761
1762    fn check_sama_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1763        skip_if_unsupported!(kernel, test_name);
1764        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1765        let candles = read_candles_from_csv(file_path)?;
1766
1767        let params = SamaParams::default();
1768        let input = SamaInput::from_candles(&candles, "close", params);
1769        let output = sama_with_kernel(&input, kernel)?;
1770
1771        let first_valid = candles.close.iter().position(|x| !x.is_nan()).unwrap_or(0);
1772        let warmup = first_valid + input.get_length();
1773
1774        for (i, &val) in output
1775            .values
1776            .iter()
1777            .enumerate()
1778            .skip(warmup.min(output.values.len()))
1779        {
1780            assert!(
1781                !val.is_nan(),
1782                "[{}] Unexpected NaN at index {}",
1783                test_name,
1784                i
1785            );
1786        }
1787        Ok(())
1788    }
1789
1790    fn check_sama_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1791        skip_if_unsupported!(kernel, test_name);
1792        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1793        let candles = read_candles_from_csv(file_path)?;
1794
1795        let params = SamaParams::default();
1796
1797        let batch_input = SamaInput::from_candles(&candles, "close", params.clone());
1798        let batch_result = sama_with_kernel(&batch_input, kernel)?;
1799
1800        let mut stream = SamaStream::try_new(params)?;
1801        let mut stream_results = Vec::with_capacity(candles.close.len());
1802        for &price in &candles.close {
1803            stream_results.push(stream.next(price));
1804        }
1805
1806        assert_eq!(batch_result.values.len(), stream_results.len());
1807
1808        for (i, (&batch_val, &stream_val)) in batch_result
1809            .values
1810            .iter()
1811            .zip(stream_results.iter())
1812            .enumerate()
1813        {
1814            if batch_val.is_nan() && stream_val.is_nan() {
1815                continue;
1816            }
1817            if !batch_val.is_nan() && !stream_val.is_nan() {
1818                let diff = (batch_val - stream_val).abs();
1819                assert!(
1820                    diff < 1e-9,
1821                    "[{}] Stream mismatch at index {}: batch={}, stream={}, diff={}",
1822                    test_name,
1823                    i,
1824                    batch_val,
1825                    stream_val,
1826                    diff
1827                );
1828            }
1829        }
1830        Ok(())
1831    }
1832
1833    fn check_batch_sweep(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1834        skip_if_unsupported!(kernel, test_name);
1835        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1836        let candles = read_candles_from_csv(file_path)?;
1837
1838        let range = SamaBatchRange {
1839            length: (190, 210, 1),
1840            maj_length: (12, 16, 1),
1841            min_length: (4, 8, 1),
1842        };
1843
1844        let output = sama_batch_with_kernel(&candles.close, &range, kernel)?;
1845
1846        let expected_count = 21 * 5 * 5;
1847        assert_eq!(
1848            output.rows, expected_count,
1849            "[{}] Expected {} batch results",
1850            test_name, expected_count
1851        );
1852        assert_eq!(
1853            output.values.len(),
1854            expected_count * candles.close.len(),
1855            "[{}] Expected flattened array size",
1856            test_name
1857        );
1858
1859        Ok(())
1860    }
1861
1862    fn check_batch_default_row(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1863        skip_if_unsupported!(kernel, test_name);
1864        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1865        let candles = read_candles_from_csv(file_path)?;
1866
1867        let range = SamaBatchRange {
1868            length: (200, 200, 0),
1869            maj_length: (14, 14, 0),
1870            min_length: (6, 6, 0),
1871        };
1872
1873        let output = sama_batch_with_kernel(&candles.close, &range, kernel)?;
1874
1875        assert_eq!(
1876            output.rows, 1,
1877            "[{}] Should have 1 result for default params",
1878            test_name
1879        );
1880        assert_eq!(output.cols, candles.close.len());
1881
1882        Ok(())
1883    }
1884
1885    macro_rules! generate_all_sama_tests {
1886        ($($test_fn:ident),*) => {
1887            paste::paste! {
1888                $(
1889                    #[test]
1890                    fn [<$test_fn _scalar>]() -> Result<(), Box<dyn Error>> {
1891                        $test_fn(stringify!([<$test_fn _scalar>]), Kernel::Scalar)
1892                    }
1893                )*
1894
1895                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1896                $(
1897                    #[test]
1898                    fn [<$test_fn _avx2>]() -> Result<(), Box<dyn Error>> {
1899                        $test_fn(stringify!([<$test_fn _avx2>]), Kernel::Avx2)
1900                    }
1901
1902                    #[test]
1903                    fn [<$test_fn _avx512>]() -> Result<(), Box<dyn Error>> {
1904                        $test_fn(stringify!([<$test_fn _avx512>]), Kernel::Avx512)
1905                    }
1906                )*
1907
1908                #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
1909                $(
1910                    #[test]
1911                    fn [<$test_fn _simd128>]() -> Result<(), Box<dyn Error>> {
1912                        $test_fn(stringify!([<$test_fn _simd128>]), Kernel::Scalar)
1913                    }
1914                )*
1915            }
1916        }
1917    }
1918
1919    macro_rules! gen_batch_tests {
1920        ($fn_name:ident) => {
1921            paste::paste! {
1922                #[test]
1923                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
1924                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
1925                }
1926
1927                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1928                #[test]
1929                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
1930                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
1931                }
1932
1933                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
1934                #[test]
1935                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
1936                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
1937                }
1938
1939                #[test]
1940                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
1941                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
1942                }
1943            }
1944        };
1945    }
1946
1947    generate_all_sama_tests!(
1948        check_sama_accuracy,
1949        check_sama_partial_params,
1950        check_sama_default_candles,
1951        check_sama_zero_period,
1952        check_sama_period_exceeds_length,
1953        check_sama_very_small_dataset,
1954        check_sama_empty_input,
1955        check_sama_all_nan,
1956        check_sama_reinput,
1957        check_sama_nan_handling,
1958        check_sama_streaming
1959    );
1960
1961    gen_batch_tests!(check_batch_sweep);
1962    gen_batch_tests!(check_batch_default_row);
1963
1964    #[cfg(debug_assertions)]
1965    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1966        skip_if_unsupported!(kernel, test);
1967
1968        let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1969        let c = read_candles_from_csv(file)?;
1970
1971        let cfgs = vec![
1972            (190, 200, 1, 12, 14, 1, 4, 8, 1),
1973            (200, 200, 0, 14, 14, 0, 6, 6, 0),
1974            (195, 205, 5, 10, 16, 2, 3, 9, 2),
1975        ];
1976
1977        for (ls, le, lstep, js, je, jstep, ms, me, mstep) in cfgs {
1978            let out = SamaBatchRange {
1979                length: (ls, le, lstep),
1980                maj_length: (js, je, jstep),
1981                min_length: (ms, me, mstep),
1982            };
1983            let res = sama_batch_with_kernel(&c.close, &out, kernel)?;
1984            for &v in &res.values {
1985                if v.is_nan() {
1986                    continue;
1987                }
1988                let b = v.to_bits();
1989                assert_ne!(b, 0x1111_1111_1111_1111);
1990                assert_ne!(b, 0x2222_2222_2222_2222);
1991                assert_ne!(b, 0x3333_3333_3333_3333);
1992            }
1993        }
1994        Ok(())
1995    }
1996
1997    #[cfg(debug_assertions)]
1998    gen_batch_tests!(check_batch_no_poison);
1999
2000    #[cfg(not(debug_assertions))]
2001    fn check_batch_no_poison_scalar() -> Result<(), Box<dyn Error>> {
2002        Ok(())
2003    }
2004
2005    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
2006    #[test]
2007    fn test_sama_simd128_correctness() {
2008        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
2009        let params = SamaParams::default();
2010        let input = SamaInput::from_slice(&data, params);
2011        let scalar = sama_with_kernel(&input, Kernel::Scalar).unwrap();
2012        let simd = sama_with_kernel(&input, Kernel::Scalar).unwrap();
2013        assert_eq!(scalar.values.len(), simd.values.len());
2014        for (a, b) in scalar.values.iter().zip(simd.values.iter()) {
2015            assert!((a - b).abs() < 1e-10);
2016        }
2017    }
2018
2019    #[cfg(debug_assertions)]
2020    #[test]
2021    fn test_sama_no_poison_values() -> Result<(), Box<dyn Error>> {
2022        let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
2023        let candles = read_candles_from_csv(file_path)?;
2024
2025        let input = SamaInput::from_candles(&candles, "close", SamaParams::default());
2026        let output = sama(&input)?;
2027
2028        for &v in &output.values {
2029            if v.is_nan() {
2030                continue;
2031            }
2032            let b = v.to_bits();
2033
2034            assert_ne!(
2035                b, 0x11111111_11111111,
2036                "Found poison value 0x11111111_11111111"
2037            );
2038            assert_ne!(
2039                b, 0x22222222_22222222,
2040                "Found poison value 0x22222222_22222222"
2041            );
2042            assert_ne!(
2043                b, 0x33333333_33333333,
2044                "Found poison value 0x33333333_33333333"
2045            );
2046            assert_ne!(
2047                b, 0xDEADBEEF_DEADBEEF,
2048                "Found poison value 0xDEADBEEF_DEADBEEF"
2049            );
2050            assert_ne!(
2051                b, 0xFEEEFEEE_FEEEFEEE,
2052                "Found poison value 0xFEEEFEEE_FEEEFEEE"
2053            );
2054        }
2055        Ok(())
2056    }
2057
2058    #[test]
2059    fn test_sama_stream_incremental() -> Result<(), Box<dyn Error>> {
2060        let params = SamaParams {
2061            length: Some(10),
2062            maj_length: Some(5),
2063            min_length: Some(3),
2064        };
2065
2066        let mut stream = SamaStream::try_new(params)?;
2067        let data = vec![
2068            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
2069        ];
2070
2071        let mut results = Vec::new();
2072        for &val in &data {
2073            let result = stream.next(val);
2074            if !result.is_nan() {
2075                results.push(result);
2076            }
2077        }
2078
2079        assert!(
2080            !results.is_empty(),
2081            "Stream should produce results immediately"
2082        );
2083
2084        Ok(())
2085    }
2086
2087    #[test]
2088    fn sama_pine_parity_head_start() -> Result<(), Box<dyn Error>> {
2089        let mut long = vec![0.0; 5000];
2090        for i in 0..long.len() {
2091            long[i] = (i as f64).sin() + (i as f64 * 0.01).cos();
2092        }
2093
2094        let pine_params = SamaParams::default();
2095        let pine_out = SamaInput::from_slice(&long, pine_params.clone());
2096        let a = sama_with_kernel(&pine_out, Kernel::Scalar)?.values;
2097
2098        let tail = &long[2000..];
2099        let pine_like_tail = SamaInput::from_slice(tail, pine_params);
2100        let b = sama_with_kernel(&pine_like_tail, Kernel::Scalar)?.values;
2101
2102        let tol = 0.1;
2103        for (i, (&x, &y)) in a[2000..].iter().zip(b.iter()).enumerate().skip(100) {
2104            if x.is_finite() && y.is_finite() {
2105                assert!((x - y).abs() < tol, "i={}, |Δ|={}", i, (x - y).abs());
2106            }
2107        }
2108
2109        Ok(())
2110    }
2111}
2112
#[cfg(all(feature = "proptest", not(target_arch = "wasm32")))]
mod prop_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Property test over random finite series and window lengths (scalar kernel):
        //   1. the output has exactly one value per input sample, and
        //   2. from `first + length` onward, every index whose trailing `length + 1` inputs are
        //      all finite produces a finite output.
        #[test]
        fn sama_properties(data in prop::collection::vec(-1e6f64..1e6, 5..400),
                           len in 2usize..64,
                           maj in 2usize..64,
                           min in 2usize..64) {
            // Series that do not exceed the main window are out of scope for this property.
            if data.len() <= len {
                return Ok(());
            }

            let params = SamaParams {
                length: Some(len),
                maj_length: Some(maj),
                min_length: Some(min),
            };
            let input = SamaInput::from_slice(&data, params);
            let SamaOutput { values: out } = sama_with_kernel(&input, Kernel::Scalar).unwrap();

            // Check the length first so a short output is reported as a property failure
            // (and shrunk) rather than surfacing as an out-of-bounds panic on `out[i]` below.
            prop_assert_eq!(
                out.len(),
                data.len(),
                "output length {} differs from input length {}",
                out.len(),
                data.len()
            );

            // The generator yields only finite values, so `first` is 0 here; it is kept so the
            // check stays correct if NaN-containing inputs are ever generated.
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let warm = first + len;
            for i in warm..data.len() {
                let window = &data[i - len..=i];
                if window.iter().all(|v| v.is_finite()) {
                    let y = out[i];
                    prop_assert!(
                        y.is_finite(),
                        "Output {} at index {} is not finite",
                        y, i
                    );
                }
            }
        }
    }
}
2159
#[cfg(all(feature = "python", feature = "cuda"))]
/// Python-visible handle to a SAMA float32 result matrix held in CUDA device memory.
///
/// The buffer is exposed to Python via `__cuda_array_interface__` and the DLPack protocol
/// (`__dlpack__` / `__dlpack_device__`). Declared `unsendable`, so PyO3 only allows access from
/// the thread that created it (presumably because the CUDA context is thread-affine — confirm).
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct DeviceArrayF32SamaPy {
    /// Owned device result: buffer plus `rows`/`cols` shape, CUDA context and device id.
    pub(crate) inner: DeviceArrayF32Sama,
}
2165
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32SamaPy {
    /// CUDA Array Interface (version 3) description of the device buffer.
    ///
    /// Reports a row-major `(rows, cols)` float32 array with explicit byte strides. The data
    /// pointer is reported as 0 when the array has no elements, which includes the 0x0
    /// placeholder left behind after `__dlpack__` has handed the buffer off.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let d = PyDict::new(py);
        d.set_item("shape", (self.inner.rows, self.inner.cols))?;
        // "<f4": little-endian 4-byte float, matching the f32 elements.
        d.set_item("typestr", "<f4")?;

        // Byte strides for a row-major layout: a row step skips `cols` f32 values, a column
        // step skips one f32.
        d.set_item(
            "strides",
            (
                self.inner.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;

        // An overflowing element count is treated as empty instead of panicking.
        let len = self.inner.rows.checked_mul(self.inner.cols).unwrap_or(0);
        let ptr_usize = if len == 0 {
            0usize
        } else {
            self.inner.device_ptr() as usize
        };
        // CAI "data" is (pointer, read_only); `false` advertises the buffer as writable.
        d.set_item("data", (ptr_usize, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`: 2 is `kDLCUDA` in DLPack's
    /// `DLDeviceType` enum, and the id is the device id recorded on the inner buffer.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.inner.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, moving ownership of the device memory out of
    /// this object.
    ///
    /// - `dl_device`: if it extracts as `(i32, i32)` and differs from `__dlpack_device__()`,
    ///   raises `ValueError` ("device copy not implemented ..." when `copy` is truthy,
    ///   "dl_device mismatch ..." otherwise). Values that do not extract are ignored.
    /// - `stream`: accepted for protocol compatibility but otherwise unused.
    ///   NOTE(review): DLPack expects the producer to make the data safe for the consumer's
    ///   stream; confirm the results are already synchronized by this point.
    /// - `max_version`: forwarded to `export_f32_cuda_dlpack_2d`.
    /// - `copy`: only consulted on a device mismatch. NOTE(review): `copy=True` on a matching
    ///   device is not rejected here; confirm whether the exporter honors it.
    ///
    /// The inner array is swapped for an empty 0x0 placeholder *before* the export call, so
    /// this object is left empty even if the export itself returns an error.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    // A non-bool `copy` is treated as false.
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }

        let _ = stream;

        // Zero-length placeholder buffer that takes the place of the exported one.
        let dummy =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let device_id = self.inner.device_id;
        // The placeholder reuses a clone of the existing context (presumably to keep it alive
        // for the remaining object — confirm).
        let ctx = self.inner.ctx.clone();
        let inner = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32Sama {
                buf: dummy,
                rows: 0,
                cols: 0,
                ctx,
                device_id,
            },
        );

        let rows = inner.rows;
        let cols = inner.cols;
        let buf = inner.buf;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        // `buf` is moved into the exporter, which builds the capsule.
        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}