1use crate::utilities::data_loader::{source_type, Candles};
2use crate::utilities::enums::Kernel;
3use crate::utilities::helpers::{
4 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
5 make_uninit_matrix,
6};
7#[cfg(feature = "python")]
8use crate::utilities::kernel_validation::validate_kernel;
9use aligned_vec::{AVec, CACHELINE_ALIGN};
10#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
11use core::arch::x86_64::*;
12#[cfg(feature = "python")]
13use pyo3::exceptions::PyValueError;
14#[cfg(feature = "python")]
15use pyo3::prelude::*;
16#[cfg(not(target_arch = "wasm32"))]
17use rayon::prelude::*;
18#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
19use serde::{Deserialize, Serialize};
20use std::convert::AsRef;
21use std::error::Error;
22use thiserror::Error;
23#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
24use wasm_bindgen::prelude::*;
25
26impl<'a> AsRef<[f64]> for LinearRegInterceptInput<'a> {
27 #[inline(always)]
28 fn as_ref(&self) -> &[f64] {
29 match &self.data {
30 LinearRegInterceptData::Slice(slice) => slice,
31 LinearRegInterceptData::Candles { candles, source } => source_type(candles, source),
32 }
33 }
34}
35
/// Where a [`LinearRegInterceptInput`] reads its series from.
#[derive(Debug, Clone)]
pub enum LinearRegInterceptData<'a> {
    /// A named column of a [`Candles`] set (e.g. `"close"`), resolved through
    /// `source_type`.
    Candles {
        candles: &'a Candles,
        source: &'a str,
    },
    /// A caller-provided slice, used as-is.
    Slice(&'a [f64]),
}
44
/// Result of a single-series linear-regression-intercept computation.
#[derive(Debug, Clone)]
pub struct LinearRegInterceptOutput {
    /// One value per input element. The warm-up prefix (up to the first full
    /// window after the first non-NaN input) is NaN.
    pub values: Vec<f64>,
}
49
/// Parameters for the linear-regression intercept.
#[derive(Debug, Clone)]
#[cfg_attr(
    all(target_arch = "wasm32", feature = "wasm"),
    derive(serde::Serialize, serde::Deserialize)
)]
pub struct LinearRegInterceptParams {
    /// Regression window length in bars; `None` falls back to 14.
    pub period: Option<usize>,
}
58
59impl Default for LinearRegInterceptParams {
60 fn default() -> Self {
61 Self { period: Some(14) }
62 }
63}
64
/// Input to the linear-regression intercept: a data source plus parameters.
#[derive(Debug, Clone)]
pub struct LinearRegInterceptInput<'a> {
    /// Series source (candle column or raw slice).
    pub data: LinearRegInterceptData<'a>,
    /// Computation parameters.
    pub params: LinearRegInterceptParams,
}
70
71impl<'a> LinearRegInterceptInput<'a> {
72 #[inline]
73 pub fn from_candles(c: &'a Candles, s: &'a str, p: LinearRegInterceptParams) -> Self {
74 Self {
75 data: LinearRegInterceptData::Candles {
76 candles: c,
77 source: s,
78 },
79 params: p,
80 }
81 }
82 #[inline]
83 pub fn from_slice(sl: &'a [f64], p: LinearRegInterceptParams) -> Self {
84 Self {
85 data: LinearRegInterceptData::Slice(sl),
86 params: p,
87 }
88 }
89 #[inline]
90 pub fn with_default_candles(c: &'a Candles) -> Self {
91 Self::from_candles(c, "close", LinearRegInterceptParams::default())
92 }
93 #[inline]
94 pub fn get_period(&self) -> usize {
95 self.params.period.unwrap_or(14)
96 }
97}
98
/// Fluent builder for single-series computations and streams.
#[derive(Copy, Clone, Debug)]
pub struct LinearRegInterceptBuilder {
    /// Window length; `None` means the default of 14 is applied downstream.
    period: Option<usize>,
    /// Kernel forwarded to `linearreg_intercept_with_kernel`.
    kernel: Kernel,
}
104
105impl Default for LinearRegInterceptBuilder {
106 fn default() -> Self {
107 Self {
108 period: None,
109 kernel: Kernel::Auto,
110 }
111 }
112}
113
114impl LinearRegInterceptBuilder {
115 #[inline(always)]
116 pub fn new() -> Self {
117 Self::default()
118 }
119 #[inline(always)]
120 pub fn period(mut self, n: usize) -> Self {
121 self.period = Some(n);
122 self
123 }
124 #[inline(always)]
125 pub fn kernel(mut self, k: Kernel) -> Self {
126 self.kernel = k;
127 self
128 }
129 #[inline(always)]
130 pub fn apply(self, c: &Candles) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
131 let p = LinearRegInterceptParams {
132 period: self.period,
133 };
134 let i = LinearRegInterceptInput::from_candles(c, "close", p);
135 linearreg_intercept_with_kernel(&i, self.kernel)
136 }
137 #[inline(always)]
138 pub fn apply_slice(
139 self,
140 d: &[f64],
141 ) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
142 let p = LinearRegInterceptParams {
143 period: self.period,
144 };
145 let i = LinearRegInterceptInput::from_slice(d, p);
146 linearreg_intercept_with_kernel(&i, self.kernel)
147 }
148 #[inline(always)]
149 pub fn into_stream(self) -> Result<LinearRegInterceptStream, LinearRegInterceptError> {
150 let p = LinearRegInterceptParams {
151 period: self.period,
152 };
153 LinearRegInterceptStream::try_new(p)
154 }
155}
156
/// Errors produced by the linear-regression-intercept entry points.
#[derive(Debug, Error)]
pub enum LinearRegInterceptError {
    /// The input series has no elements.
    #[error("linearreg_intercept: Input data slice is empty.")]
    EmptyInputData,
    /// Every input value is NaN.
    #[error("linearreg_intercept: All values are NaN.")]
    AllValuesNaN,
    /// The period is zero or longer than the input (`data_len` is 0 when raised by
    /// the stream constructor, which has no data yet).
    #[error("linearreg_intercept: Invalid period: period = {period}, data_len = {data_len}")]
    InvalidPeriod { period: usize, data_len: usize },
    /// Fewer than `needed` values follow the first non-NaN input.
    #[error("linearreg_intercept: Not enough valid data: needed = {needed}, valid = {valid}")]
    NotEnoughValidData { needed: usize, valid: usize },
    /// A caller-supplied output buffer has the wrong length.
    #[error("linearreg_intercept: Output length mismatch: expected {expected}, got {got}")]
    OutputLengthMismatch { expected: usize, got: usize },
    /// A batch period sweep could not be expanded, or its `rows * cols` size overflowed.
    #[error("linearreg_intercept: Invalid range: start={start}, end={end}, step={step}")]
    InvalidRange {
        start: String,
        end: String,
        step: String,
    },
    /// A non-batch kernel was passed to a batch entry point.
    #[error("linearreg_intercept: Invalid kernel for batch: {0:?}")]
    InvalidKernelForBatch(Kernel),
}
178
179#[inline]
180pub fn linearreg_intercept(
181 input: &LinearRegInterceptInput,
182) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
183 linearreg_intercept_with_kernel(input, Kernel::Auto)
184}
185
186pub fn linearreg_intercept_with_kernel(
187 input: &LinearRegInterceptInput,
188 kernel: Kernel,
189) -> Result<LinearRegInterceptOutput, LinearRegInterceptError> {
190 let data: &[f64] = input.as_ref();
191
192 if data.is_empty() {
193 return Err(LinearRegInterceptError::EmptyInputData);
194 }
195
196 let first = data
197 .iter()
198 .position(|x| !x.is_nan())
199 .ok_or(LinearRegInterceptError::AllValuesNaN)?;
200 let len = data.len();
201 let period = input.get_period();
202
203 if period == 0 || period > len {
204 return Err(LinearRegInterceptError::InvalidPeriod {
205 period,
206 data_len: len,
207 });
208 }
209 if (len - first) < period {
210 return Err(LinearRegInterceptError::NotEnoughValidData {
211 needed: period,
212 valid: len - first,
213 });
214 }
215
216 let mut out = alloc_with_nan_prefix(len, first + period - 1);
217
218 let chosen = match kernel {
219 Kernel::Auto => Kernel::Scalar,
220 other => other,
221 };
222
223 unsafe {
224 match chosen {
225 Kernel::Scalar | Kernel::ScalarBatch => {
226 linearreg_intercept_scalar(data, period, first, &mut out)
227 }
228 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
229 Kernel::Avx2 | Kernel::Avx2Batch => {
230 linearreg_intercept_avx2(data, period, first, &mut out)
231 }
232 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
233 Kernel::Avx512 | Kernel::Avx512Batch => {
234 linearreg_intercept_avx512(data, period, first, &mut out)
235 }
236 _ => unreachable!(),
237 }
238 }
239
240 Ok(LinearRegInterceptOutput { values: out })
241}
242
243#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
244#[inline]
245pub fn linearreg_intercept_into(
246 input: &LinearRegInterceptInput,
247 dst: &mut [f64],
248) -> Result<(), LinearRegInterceptError> {
249 let data: &[f64] = input.as_ref();
250
251 if data.is_empty() {
252 return Err(LinearRegInterceptError::EmptyInputData);
253 }
254
255 let first = data
256 .iter()
257 .position(|x| !x.is_nan())
258 .ok_or(LinearRegInterceptError::AllValuesNaN)?;
259 let len = data.len();
260 let period = input.get_period();
261
262 if period == 0 || period > len {
263 return Err(LinearRegInterceptError::InvalidPeriod {
264 period,
265 data_len: len,
266 });
267 }
268 if (len - first) < period {
269 return Err(LinearRegInterceptError::NotEnoughValidData {
270 needed: period,
271 valid: len - first,
272 });
273 }
274
275 if dst.len() != data.len() {
276 return Err(LinearRegInterceptError::OutputLengthMismatch {
277 expected: data.len(),
278 got: dst.len(),
279 });
280 }
281
282 let warmup_end = first + period - 1;
283 for v in &mut dst[..warmup_end] {
284 *v = f64::from_bits(0x7ff8_0000_0000_0000);
285 }
286
287 let chosen = detect_best_kernel();
288
289 unsafe {
290 match chosen {
291 Kernel::Scalar | Kernel::ScalarBatch => {
292 linearreg_intercept_scalar(data, period, first, dst)
293 }
294 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
295 Kernel::Avx2 | Kernel::Avx2Batch => linearreg_intercept_avx2(data, period, first, dst),
296 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
297 Kernel::Avx512 | Kernel::Avx512Batch => {
298 linearreg_intercept_avx512(data, period, first, dst)
299 }
300 _ => unreachable!(),
301 }
302 }
303
304 Ok(())
305}
306
307#[inline]
308pub fn linearreg_intercept_into_slice(
309 dst: &mut [f64],
310 input: &LinearRegInterceptInput,
311 kern: Kernel,
312) -> Result<(), LinearRegInterceptError> {
313 let data: &[f64] = input.as_ref();
314
315 if data.is_empty() {
316 return Err(LinearRegInterceptError::EmptyInputData);
317 }
318
319 let first = data
320 .iter()
321 .position(|x| !x.is_nan())
322 .ok_or(LinearRegInterceptError::AllValuesNaN)?;
323 let len = data.len();
324 let period = input.get_period();
325
326 if period == 0 || period > len {
327 return Err(LinearRegInterceptError::InvalidPeriod {
328 period,
329 data_len: len,
330 });
331 }
332 if (len - first) < period {
333 return Err(LinearRegInterceptError::NotEnoughValidData {
334 needed: period,
335 valid: len - first,
336 });
337 }
338
339 if dst.len() != data.len() {
340 return Err(LinearRegInterceptError::OutputLengthMismatch {
341 expected: data.len(),
342 got: dst.len(),
343 });
344 }
345
346 let chosen = match kern {
347 Kernel::Auto => Kernel::Scalar,
348 other => other,
349 };
350
351 unsafe {
352 match chosen {
353 Kernel::Scalar | Kernel::ScalarBatch => {
354 linearreg_intercept_scalar(data, period, first, dst)
355 }
356 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
357 Kernel::Avx2 | Kernel::Avx2Batch => linearreg_intercept_avx2(data, period, first, dst),
358 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
359 Kernel::Avx512 | Kernel::Avx512Batch => {
360 linearreg_intercept_avx512(data, period, first, dst)
361 }
362 _ => unreachable!(),
363 }
364 }
365
366 let warmup_end = first + period - 1;
367 for v in &mut dst[..warmup_end] {
368 *v = f64::NAN;
369 }
370
371 Ok(())
372}
373
/// Scalar kernel for the linear-regression intercept.
///
/// Each window of `period` samples is regressed against x = 1..=period. Σy and Σxy are
/// maintained with O(1) rolling updates, and the emitted value is
/// `intercept + slope` (the fitted line evaluated at x = 1).
///
/// Writes `out[first_val + period - 1 .. data.len()]` and leaves earlier entries of `out`
/// untouched. Does nothing when `period == 0` or when fewer than `period` samples start
/// at `first_val`.
///
/// # Panics
/// Panics if there is a window to compute and `out` is shorter than `data`.
#[inline]
pub fn linearreg_intercept_scalar(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    let end = data.len();
    // No complete window (or a degenerate zero period, which previously underflowed
    // `first_val + period - 1`): nothing to write.
    let has_window =
        period != 0 && matches!(first_val.checked_add(period), Some(need) if need <= end);
    if !has_window {
        return;
    }

    if period == 1 {
        // A one-point window has no slope; the fit is the sample itself.
        out[first_val..end].copy_from_slice(&data[first_val..end]);
        return;
    }

    let n = period as f64;
    let inv_n = 1.0 / n;

    // Closed forms for Σx and Σx² over x = 1..=n.
    let sum_x = 0.5_f64 * n * (n + 1.0);
    let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
    let denom = n * sum_x2 - sum_x * sum_x;
    let bd = 1.0 / denom;
    // intercept + slope = Σy/n + slope·(1 − Σx/n)
    let k = 1.0 - sum_x * inv_n;

    let fit = |sum_y: f64, sum_xy: f64| ((n * sum_xy - sum_x * sum_y) * bd) * k + sum_y * inv_n;

    // Seed the sums with the first full window (oldest sample at x = 1).
    let mut sum_y = 0.0f64;
    let mut sum_xy = 0.0f64;
    for (j, &y) in data[first_val..first_val + period].iter().enumerate() {
        let x = (j as f64) + 1.0;
        sum_y += y;
        sum_xy += y * x;
    }

    let warm = first_val + period - 1;
    out[warm] = fit(sum_y, sum_xy);

    for i in (warm + 1)..end {
        let y_in = data[i];
        let y_out = data[i - period];

        // Sliding shifts every remaining weight down by one (subtract the old Σy, which
        // also removes the departing sample's weight of 1), then the new sample enters
        // at weight n.
        let prev_sum_y = sum_y;
        sum_y = prev_sum_y + y_in - y_out;
        sum_xy = (sum_xy - prev_sum_y) + n * y_in;

        out[i] = fit(sum_y, sum_xy);
    }
}
422
/// AVX-512 entry point for the linear-regression intercept.
///
/// Dispatches on `period` (≤ 32 → short variant, otherwise long). Both variants currently
/// forward to `linearreg_intercept_avx2`. Does nothing when `period == 0`.
///
/// # Panics
/// Panics if `out` is shorter than `data`. The underlying kernels write through raw
/// pointers without bounds checks, so this safe function must enforce that bound itself.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub fn linearreg_intercept_avx512(data: &[f64], period: usize, first_val: usize, out: &mut [f64]) {
    // A zero period has no window; the kernels would underflow `first_val + period - 1`.
    if period == 0 {
        return;
    }
    assert!(
        out.len() >= data.len(),
        "linearreg_intercept_avx512: output shorter than input"
    );
    // SAFETY: every index the kernels write is `< data.len() <= out.len()`, and
    // `period >= 1` keeps the warm-up index from underflowing.
    unsafe {
        if period <= 32 {
            linearreg_intercept_avx512_short(data, period, first_val, out)
        } else {
            linearreg_intercept_avx512_long(data, period, first_val, out)
        }
    }
}
432
/// AVX2-tier kernel for the linear-regression intercept.
///
/// Uses the same rolling Σy / Σxy recurrence as `linearreg_intercept_scalar`, evaluated
/// through raw pointers with fused multiply-adds (`f64::mul_add`), so results may differ
/// from the scalar kernel in the last bits. Writes
/// `out[first_val + period - 1 .. data.len()]`. Does nothing when `period == 0` or when
/// fewer than `period` samples start at `first_val`.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()`: output indices up to
/// `data.len() - 1` are written without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx2(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    // A zero period would wrap `start + period - 1` and write far out of bounds.
    if period == 0 {
        return;
    }
    debug_assert!(
        out.len() >= data.len(),
        "linearreg_intercept_avx2: output shorter than input"
    );

    if period == 1 {
        // A one-point window has no slope; the fit is the sample itself.
        let mut i = first_val;
        let end = data.len();
        let src = data.as_ptr();
        let dst = out.as_mut_ptr();
        while i < end {
            *dst.add(i) = *src.add(i);
            i += 1;
        }
        return;
    }

    let n = period as f64;
    let inv_n = 1.0 / n;
    let sum_x = 0.5_f64 * n * (n + 1.0);
    let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
    let denom = n.mul_add(sum_x2, -sum_x * sum_x);
    let bd = 1.0 / denom;
    let k = 1.0 - sum_x * inv_n;

    let start = first_val;
    let end = data.len();
    if end == 0 || end < start + period {
        return;
    }

    // Seed the sums with the first full window (oldest sample at x = 1).
    let mut sum_y = 0.0f64;
    let mut sum_xy = 0.0f64;
    let base = data.as_ptr().add(start);
    let mut j = 0usize;
    let mut x = 1.0f64;
    while j < period {
        let y = *base.add(j);
        sum_y += y;
        sum_xy = y.mul_add(x, sum_xy);
        x += 1.0;
        j += 1;
    }

    let mut i = start + period - 1;
    let outp = out.as_mut_ptr();
    let mut b = n.mul_add(sum_xy, -sum_x * sum_y) * bd;
    *outp.add(i) = b.mul_add(k, sum_y * inv_n);

    let dptr = data.as_ptr();
    while i + 1 < end {
        let y_in = *dptr.add(i + 1);
        let y_out = *dptr.add(i + 1 - period);

        // Shift all weights down by one, then add the new sample at weight n.
        let prev_sum_y = sum_y;
        sum_y = prev_sum_y + y_in - y_out;
        sum_xy = (sum_xy - prev_sum_y) + n * y_in;

        i += 1;
        b = n.mul_add(sum_xy, -sum_x * sum_y) * bd;
        *outp.add(i) = b.mul_add(k, sum_y * inv_n);
    }
}
499
/// AVX-512 variant for short periods (selected for `period <= 32`).
///
/// Currently forwards to `linearreg_intercept_avx2`.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()` and `period >= 1`. The forwarded
/// kernel writes output indices up to `data.len() - 1` without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx512_short(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx2(data, period, first_val, out)
}
510
/// AVX-512 variant for long periods (selected for `period > 32`).
///
/// Currently forwards to `linearreg_intercept_avx2`.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()` and `period >= 1`. The forwarded
/// kernel writes output indices up to `data.len() - 1` without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline]
pub unsafe fn linearreg_intercept_avx512_long(
    data: &[f64],
    period: usize,
    first_val: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx2(data, period, first_val, out)
}
521
/// Incremental (one value at a time) linear-regression intercept.
///
/// Keeps the last `period` inputs in a ring buffer together with rolling Σy and Σxy,
/// so each `update` is O(1).
#[derive(Debug, Clone)]
pub struct LinearRegInterceptStream {
    /// Window length.
    period: usize,
    /// Ring buffer of the most recent `period` inputs (initially NaN).
    buffer: Vec<f64>,
    /// Slot the next input is written to; once `filled`, that slot holds the oldest input.
    head: usize,
    /// Whether `period` inputs have been received.
    filled: bool,
    /// Σx for x = 1..=period.
    sum_x: f64,
    /// Σx² for x = 1..=period.
    sum_x2: f64,
    /// `period` as `f64`.
    n: f64,
    /// Reciprocal of the slope denominator `n·Σx² − (Σx)²` (0 when `period == 1`).
    bd: f64,
    /// Rolling Σy over the window.
    sum_y: f64,
    /// Rolling Σ(x·y) with the oldest sample at x = 1.
    sum_xy: f64,
}
535
536impl LinearRegInterceptStream {
537 #[inline]
538 pub fn try_new(params: LinearRegInterceptParams) -> Result<Self, LinearRegInterceptError> {
539 let period = params.period.unwrap_or(14);
540 if period == 0 {
541 return Err(LinearRegInterceptError::InvalidPeriod {
542 period,
543 data_len: 0,
544 });
545 }
546
547 let n = period as f64;
548 let sum_x = 0.5_f64 * n * (n + 1.0);
549 let sum_x2 = (n * (n + 1.0) * (2.0 * n + 1.0)) / 6.0;
550 let denom = n * sum_x2 - sum_x * sum_x;
551 let bd = if period == 1 { 0.0 } else { 1.0 / denom };
552
553 Ok(Self {
554 period,
555 buffer: vec![f64::NAN; period],
556 head: 0,
557 filled: false,
558 sum_x,
559 sum_x2,
560 n,
561 bd,
562 sum_y: 0.0,
563 sum_xy: 0.0,
564 })
565 }
566
567 #[inline(always)]
568 pub fn update(&mut self, value: f64) -> Option<f64> {
569 if self.period == 1 {
570 return Some(value);
571 }
572
573 let tail = self.head;
574 let y_out = self.buffer[tail];
575
576 self.buffer[tail] = value;
577 self.head = if self.head + 1 == self.period {
578 0
579 } else {
580 self.head + 1
581 };
582
583 if !self.filled {
584 let x = (tail as f64) + 1.0;
585 self.sum_y += value;
586 self.sum_xy = value.mul_add(x, self.sum_xy);
587
588 if self.head == 0 {
589 self.filled = true;
590 } else {
591 return None;
592 }
593 } else {
594 let sum_y_old = self.sum_y;
595 self.sum_y = sum_y_old + value - y_out;
596
597 self.sum_xy = (self.sum_xy - sum_y_old) + self.n * value;
598 }
599
600 let inv_n = 1.0 / self.n;
601 let k = 1.0 - self.sum_x * inv_n;
602
603 let t = self.n.mul_add(self.sum_xy, -(self.sum_x * self.sum_y));
604 let b = t * self.bd;
605 let y = self.sum_y.mul_add(inv_n, b * k);
606 Some(y)
607 }
608}
609
/// Period sweep for batch computation, as an inclusive `(start, end, step)` triple.
///
/// `step == 0` (or `start == end`) yields only `start`; `start > end` counts downward.
#[derive(Clone, Debug)]
pub struct LinearRegInterceptBatchRange {
    pub period: (usize, usize, usize),
}
614
615impl Default for LinearRegInterceptBatchRange {
616 fn default() -> Self {
617 Self {
618 period: (14, 263, 1),
619 }
620 }
621}
622
/// Fluent builder for batch (period-sweep) computations.
#[derive(Clone, Debug, Default)]
pub struct LinearRegInterceptBatchBuilder {
    /// Period sweep to evaluate.
    range: LinearRegInterceptBatchRange,
    /// Batch kernel (or `Kernel::Auto`) forwarded to `linearreg_intercept_batch_with_kernel`.
    kernel: Kernel,
}
628
629impl LinearRegInterceptBatchBuilder {
630 pub fn new() -> Self {
631 Self::default()
632 }
633 pub fn kernel(mut self, k: Kernel) -> Self {
634 self.kernel = k;
635 self
636 }
637 #[inline]
638 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
639 self.range.period = (start, end, step);
640 self
641 }
642 #[inline]
643 pub fn period_static(mut self, p: usize) -> Self {
644 self.range.period = (p, p, 0);
645 self
646 }
647 pub fn apply_slice(
648 self,
649 data: &[f64],
650 ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
651 linearreg_intercept_batch_with_kernel(data, &self.range, self.kernel)
652 }
653 pub fn with_default_slice(
654 data: &[f64],
655 k: Kernel,
656 ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
657 LinearRegInterceptBatchBuilder::new()
658 .kernel(k)
659 .apply_slice(data)
660 }
661 pub fn apply_candles(
662 self,
663 c: &Candles,
664 src: &str,
665 ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
666 let slice = source_type(c, src);
667 self.apply_slice(slice)
668 }
669 pub fn with_default_candles(
670 c: &Candles,
671 ) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
672 LinearRegInterceptBatchBuilder::new()
673 .kernel(Kernel::Auto)
674 .apply_candles(c, "close")
675 }
676}
677
678pub fn linearreg_intercept_batch_with_kernel(
679 data: &[f64],
680 sweep: &LinearRegInterceptBatchRange,
681 k: Kernel,
682) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
683 let kernel = match k {
684 Kernel::Auto => detect_best_batch_kernel(),
685 other if other.is_batch() => other,
686 other => return Err(LinearRegInterceptError::InvalidKernelForBatch(other)),
687 };
688
689 let simd = match kernel {
690 Kernel::Avx512Batch => Kernel::Avx512,
691 Kernel::Avx2Batch => Kernel::Avx2,
692 Kernel::ScalarBatch => Kernel::Scalar,
693 _ => unreachable!(),
694 };
695 linearreg_intercept_batch_par_slice(data, sweep, simd)
696}
697
/// Result of a batch sweep: a row-major `rows × cols` matrix with one row per period.
#[derive(Clone, Debug)]
pub struct LinearRegInterceptBatchOutput {
    /// Row-major values; row `r` corresponds to `combos[r]`.
    pub values: Vec<f64>,
    /// Parameters for each row, in row order.
    pub combos: Vec<LinearRegInterceptParams>,
    /// Number of rows (one per combination).
    pub rows: usize,
    /// Number of columns (the input length).
    pub cols: usize,
}
705impl LinearRegInterceptBatchOutput {
706 pub fn row_for_params(&self, p: &LinearRegInterceptParams) -> Option<usize> {
707 self.combos
708 .iter()
709 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
710 }
711 pub fn values_for(&self, p: &LinearRegInterceptParams) -> Option<&[f64]> {
712 self.row_for_params(p).map(|row| {
713 let start = row * self.cols;
714 &self.values[start..start + self.cols]
715 })
716 }
717}
718
719#[inline(always)]
720fn expand_grid(
721 r: &LinearRegInterceptBatchRange,
722) -> Result<Vec<LinearRegInterceptParams>, LinearRegInterceptError> {
723 fn axis_usize(
724 (start, end, step): (usize, usize, usize),
725 ) -> Result<Vec<usize>, LinearRegInterceptError> {
726 if step == 0 || start == end {
727 return Ok(vec![start]);
728 }
729
730 let mut values = Vec::new();
731 let step_u = step;
732
733 if start <= end {
734 let mut v = start;
735 loop {
736 if v > end {
737 break;
738 }
739 values.push(v);
740 match v.checked_add(step_u) {
741 Some(next) => v = next,
742 None => break,
743 }
744 }
745 } else {
746 let mut v = start;
747 loop {
748 if v < end {
749 break;
750 }
751 values.push(v);
752 match v.checked_sub(step_u) {
753 Some(next) => v = next,
754 None => break,
755 }
756 }
757 }
758
759 if values.is_empty() {
760 return Err(LinearRegInterceptError::InvalidRange {
761 start: start.to_string(),
762 end: end.to_string(),
763 step: step.to_string(),
764 });
765 }
766
767 Ok(values)
768 }
769
770 let periods = axis_usize(r.period)?;
771
772 let mut out = Vec::with_capacity(periods.len());
773 for p in periods {
774 out.push(LinearRegInterceptParams { period: Some(p) });
775 }
776 Ok(out)
777}
778
779#[inline(always)]
780pub fn linearreg_intercept_batch_slice(
781 data: &[f64],
782 sweep: &LinearRegInterceptBatchRange,
783 kern: Kernel,
784) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
785 linearreg_intercept_batch_inner(data, sweep, kern, false)
786}
787
788#[inline(always)]
789pub fn linearreg_intercept_batch_par_slice(
790 data: &[f64],
791 sweep: &LinearRegInterceptBatchRange,
792 kern: Kernel,
793) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
794 linearreg_intercept_batch_inner(data, sweep, kern, true)
795}
796
797#[inline(always)]
798fn linearreg_intercept_batch_inner_into(
799 data: &[f64],
800 sweep: &LinearRegInterceptBatchRange,
801 kern: Kernel,
802 parallel: bool,
803 out: &mut [f64],
804) -> Result<Vec<LinearRegInterceptParams>, LinearRegInterceptError> {
805 if data.is_empty() {
806 return Err(LinearRegInterceptError::EmptyInputData);
807 }
808
809 let combos = expand_grid(sweep)?;
810
811 let first = data
812 .iter()
813 .position(|x| !x.is_nan())
814 .ok_or(LinearRegInterceptError::AllValuesNaN)?;
815 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
816 if data.len() - first < max_p {
817 return Err(LinearRegInterceptError::NotEnoughValidData {
818 needed: max_p,
819 valid: data.len() - first,
820 });
821 }
822
823 let cols = data.len();
824
825 let chosen = match kern {
826 Kernel::Auto => Kernel::Scalar,
827 other => other,
828 };
829
830 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
831 let period = combos[row].period.unwrap();
832 match chosen {
833 Kernel::Scalar | Kernel::ScalarBatch => {
834 linearreg_intercept_row_scalar(data, first, period, out_row)
835 }
836 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
837 Kernel::Avx2 | Kernel::Avx2Batch => {
838 linearreg_intercept_row_avx2(data, first, period, out_row)
839 }
840 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
841 Kernel::Avx512 | Kernel::Avx512Batch => {
842 linearreg_intercept_row_avx512(data, first, period, out_row)
843 }
844 _ => unreachable!(),
845 }
846 };
847
848 if parallel {
849 #[cfg(not(target_arch = "wasm32"))]
850 {
851 out.par_chunks_mut(cols)
852 .enumerate()
853 .for_each(|(row, slice)| do_row(row, slice));
854 }
855
856 #[cfg(target_arch = "wasm32")]
857 {
858 for (row, slice) in out.chunks_mut(cols).enumerate() {
859 do_row(row, slice);
860 }
861 }
862 } else {
863 for (row, slice) in out.chunks_mut(cols).enumerate() {
864 do_row(row, slice);
865 }
866 }
867
868 Ok(combos)
869}
870
871#[inline(always)]
872fn linearreg_intercept_batch_inner(
873 data: &[f64],
874 sweep: &LinearRegInterceptBatchRange,
875 kern: Kernel,
876 parallel: bool,
877) -> Result<LinearRegInterceptBatchOutput, LinearRegInterceptError> {
878 if data.is_empty() {
879 return Err(LinearRegInterceptError::EmptyInputData);
880 }
881
882 let combos = expand_grid(sweep)?;
883
884 let first = data
885 .iter()
886 .position(|x| !x.is_nan())
887 .ok_or(LinearRegInterceptError::AllValuesNaN)?;
888 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
889 if data.len() - first < max_p {
890 return Err(LinearRegInterceptError::NotEnoughValidData {
891 needed: max_p,
892 valid: data.len() - first,
893 });
894 }
895
896 let rows = combos.len();
897 let cols = data.len();
898
899 let total = rows
900 .checked_mul(cols)
901 .ok_or_else(|| LinearRegInterceptError::InvalidRange {
902 start: sweep.period.0.to_string(),
903 end: sweep.period.1.to_string(),
904 step: sweep.period.2.to_string(),
905 })?;
906
907 let mut buf_mu = make_uninit_matrix(rows, cols);
908
909 let warm: Vec<usize> = combos
910 .iter()
911 .map(|c| first + c.period.unwrap() - 1)
912 .collect();
913 init_matrix_prefixes(&mut buf_mu, cols, &warm);
914
915 let mut values = unsafe {
916 let ptr = buf_mu.as_mut_ptr() as *mut f64;
917 std::mem::forget(buf_mu);
918 Vec::from_raw_parts(ptr, total, total)
919 };
920
921 let chosen = match kern {
922 Kernel::Auto => Kernel::Scalar,
923 other => other,
924 };
925
926 let do_row = |row: usize, out_row: &mut [f64]| unsafe {
927 let period = combos[row].period.unwrap();
928 match chosen {
929 Kernel::Scalar | Kernel::ScalarBatch => {
930 linearreg_intercept_row_scalar(data, first, period, out_row)
931 }
932 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
933 Kernel::Avx2 | Kernel::Avx2Batch => {
934 linearreg_intercept_row_avx2(data, first, period, out_row)
935 }
936 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
937 Kernel::Avx512 | Kernel::Avx512Batch => {
938 linearreg_intercept_row_avx512(data, first, period, out_row)
939 }
940 _ => unreachable!(),
941 }
942 };
943
944 if parallel {
945 #[cfg(not(target_arch = "wasm32"))]
946 {
947 values
948 .par_chunks_mut(cols)
949 .enumerate()
950 .for_each(|(row, slice)| do_row(row, slice));
951 }
952
953 #[cfg(target_arch = "wasm32")]
954 {
955 for (row, slice) in values.chunks_mut(cols).enumerate() {
956 do_row(row, slice);
957 }
958 }
959 } else {
960 for (row, slice) in values.chunks_mut(cols).enumerate() {
961 do_row(row, slice);
962 }
963 }
964
965 Ok(LinearRegInterceptBatchOutput {
966 values,
967 combos,
968 rows,
969 cols,
970 })
971}
972
973#[inline(always)]
974unsafe fn linearreg_intercept_row_scalar(
975 data: &[f64],
976 first: usize,
977 period: usize,
978 out: &mut [f64],
979) {
980 linearreg_intercept_scalar(data, period, first, out)
981}
982
/// Fills one batch row with the AVX2 kernel.
///
/// # Safety
/// `out.len() >= data.len()` must hold: the AVX2 kernel writes through raw pointers.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn linearreg_intercept_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    let (series, row) = (data, out);
    linearreg_intercept_avx2(series, period, first, row)
}
988
/// Fills one batch row with the AVX-512 entry point.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()` and `period >= 1`. These match
/// the requirements of the raw-pointer kernels this ultimately dispatches to.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn linearreg_intercept_row_avx512(
    data: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx512(data, period, first, out)
}
999
/// Fills one batch row with the short-period AVX-512 variant.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()` and `period >= 1`. The forwarded
/// kernel writes output indices up to `data.len() - 1` without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn linearreg_intercept_row_avx512_short(
    data: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx512_short(data, period, first, out)
}
1010
/// Fills one batch row with the long-period AVX-512 variant.
///
/// # Safety
/// The caller must guarantee `out.len() >= data.len()` and `period >= 1`. The forwarded
/// kernel writes output indices up to `data.len() - 1` without bounds checks.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
pub unsafe fn linearreg_intercept_row_avx512_long(
    data: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
) {
    linearreg_intercept_avx512_long(data, period, first, out)
}
1021
1022#[inline(always)]
1023fn expand_grid_reg(r: &LinearRegInterceptBatchRange) -> Vec<LinearRegInterceptParams> {
1024 expand_grid(r).unwrap_or_else(|_| Vec::new())
1025}
1026
1027#[cfg(all(feature = "python", feature = "cuda"))]
1028use crate::cuda::moving_averages::DeviceArrayF32;
1029#[cfg(all(feature = "python", feature = "cuda"))]
1030use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;
1031#[cfg(all(feature = "python", feature = "cuda"))]
1032use cust::context::Context;
1033#[cfg(all(feature = "python", feature = "cuda"))]
1034use cust::memory::DeviceBuffer;
1035#[cfg(all(feature = "python", feature = "cuda"))]
1036use std::sync::Arc;
1037
/// Python handle to a row-major `rows × cols` `f32` matrix held in a CUDA device buffer.
///
/// Exposed to Python through `__cuda_array_interface__` and `__dlpack__`. The buffer can
/// be exported via `__dlpack__` only once; afterwards `buf` is `None`.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(module = "ta_indicators.cuda", unsendable)]
pub struct LinearRegInterceptDeviceArrayF32Py {
    /// Device allocation; taken (left as `None`) by `__dlpack__`.
    pub(crate) buf: Option<DeviceBuffer<f32>>,
    /// Number of rows.
    pub(crate) rows: usize,
    /// Number of columns (row stride, in elements).
    pub(crate) cols: usize,
    /// CUDA context of the producing computation. NOTE(review): presumably held so the
    /// context outlives `buf`; confirm.
    pub(crate) ctx: Arc<Context>,
    /// Ordinal of the device owning `buf`.
    pub(crate) device_id: u32,
}
1047
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl LinearRegInterceptDeviceArrayF32Py {
    /// CUDA Array Interface dict (version 3) describing the buffer as a C-contiguous
    /// `rows × cols` float32 matrix.
    ///
    /// Raises `ValueError` if the buffer was already handed off via `__dlpack__`.
    #[getter]
    fn __cuda_array_interface__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
        let d = pyo3::types::PyDict::new(py);
        d.set_item("shape", (self.rows, self.cols))?;
        d.set_item("typestr", "<f4")?;
        // Byte strides: one row spans `cols` f32 values; one column step is a single f32.
        d.set_item(
            "strides",
            (
                self.cols * std::mem::size_of::<f32>(),
                std::mem::size_of::<f32>(),
            ),
        )?;
        let ptr = self
            .buf
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("buffer already exported via __dlpack__"))?
            .as_device_ptr()
            .as_raw() as usize;
        // `(device pointer, flag)`. NOTE(review): per the CUDA Array Interface the flag is
        // "read-only", so `false` advertises a writable buffer.
        d.set_item("data", (ptr, false))?;

        d.set_item("version", 3)?;
        Ok(d)
    }

    /// DLPack device tuple `(device_type, device_id)`. NOTE(review): type 2 presumably
    /// corresponds to DLPack's `kDLCUDA`.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self.device_id as i32)
    }

    /// Exports the buffer as a DLPack capsule, consuming it (callable only once).
    ///
    /// If `dl_device` is an `(int, int)` tuple that differs from this buffer's device, a
    /// `ValueError` is raised. The message depends on whether `copy` is truthy, since
    /// cross-device copies are not implemented. A `dl_device` that is not such a tuple is
    /// ignored. `stream` is accepted but unused. `max_version` is forwarded to the exporter.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<pyo3::PyObject>,
        max_version: Option<pyo3::PyObject>,
        dl_device: Option<pyo3::PyObject>,
        copy: Option<pyo3::PyObject>,
    ) -> PyResult<PyObject> {
        let (kdl, alloc_dev) = self.__dlpack_device__();
        if let Some(dev_obj) = dl_device.as_ref() {
            if let Ok((dev_ty, dev_id)) = dev_obj.extract::<(i32, i32)>(py) {
                if dev_ty != kdl || dev_id != alloc_dev {
                    let wants_copy = copy
                        .as_ref()
                        .and_then(|c| c.extract::<bool>(py).ok())
                        .unwrap_or(false);
                    if wants_copy {
                        return Err(PyValueError::new_err(
                            "device copy not implemented for __dlpack__",
                        ));
                    } else {
                        return Err(PyValueError::new_err("dl_device mismatch for __dlpack__"));
                    }
                }
            }
        }
        // No stream synchronization is performed here.
        let _ = stream;

        // Move the allocation out; later calls (and `__cuda_array_interface__`) see `None`.
        let buf = self
            .buf
            .take()
            .ok_or_else(|| PyValueError::new_err("__dlpack__ may only be called once"))?;

        let rows = self.rows;
        let cols = self.cols;

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1124
/// Python entry point: runs the CUDA batch linear-regression intercept for the period sweep
/// `period_range = (start, end, step)` over `data`. Returns the device-resident result
/// matrix without copying it back to the host.
///
/// Raises `ValueError` if CUDA is unavailable, the input is not contiguous, or the CUDA
/// setup or computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linearreg_intercept_cuda_batch_dev")]
#[pyo3(signature = (data, period_range, device_id=0))]
pub fn linearreg_intercept_cuda_batch_dev_py(
    py: Python<'_>,
    data: numpy::PyReadonlyArray1<'_, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<LinearRegInterceptDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaLinregIntercept;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = data.as_slice()?;
    let sweep = LinearRegInterceptBatchRange {
        period: period_range,
    };
    // The CUDA work runs inside `allow_threads`, i.e. without holding the GIL.
    let (dev, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaLinregIntercept::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let (dev, _combos) = cuda
            .linearreg_intercept_batch_dev(slice, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Keep the context with the result so the Python handle holds a reference to it.
        let ctx = cuda.context_arc();
        Ok::<_, pyo3::PyErr>((dev, ctx, cuda.device_id()))
    })?;
    let DeviceArrayF32 { buf, rows, cols } = dev;
    Ok(LinearRegInterceptDeviceArrayF32Py {
        buf: Some(buf),
        rows,
        cols,
        ctx,
        device_id: dev_id,
    })
}
1161
/// Python entry point: runs the CUDA "many series, one period" linear-regression intercept
/// over a flat time-major input and returns the device-resident result.
///
/// `data_tm` must contain exactly `cols * rows` values. NOTE(review): presumably `rows`
/// time steps × `cols` series; confirm against `CudaLinregIntercept`.
///
/// Raises `ValueError` if CUDA is unavailable, `rows * cols` overflows, the input length
/// does not match `rows * cols`, or the CUDA setup or computation fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linearreg_intercept_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm, cols, rows, period, device_id=0))]
pub fn linearreg_intercept_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm: numpy::PyReadonlyArray1<'_, f32>,
    cols: usize,
    rows: usize,
    period: usize,
    device_id: usize,
) -> PyResult<LinearRegInterceptDeviceArrayF32Py> {
    use crate::cuda::cuda_available;
    use crate::cuda::CudaLinregIntercept;
    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }
    let slice = data_tm.as_slice()?;
    let expected = cols
        .checked_mul(rows)
        .ok_or_else(|| PyValueError::new_err("rows*cols overflow"))?;
    if slice.len() != expected {
        return Err(PyValueError::new_err("time-major input length mismatch"));
    }
    let params = LinearRegInterceptParams {
        period: Some(period),
    };
    // The CUDA work runs inside `allow_threads`, i.e. without holding the GIL.
    let (dev, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaLinregIntercept::new(device_id)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        // `&params` (previously mis-encoded as `¶ms`, which does not compile).
        let dev = cuda
            .linearreg_intercept_many_series_one_param_time_major_dev(slice, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.context_arc();
        Ok::<_, pyo3::PyErr>((dev, ctx, cuda.device_id()))
    })?;
    Ok(LinearRegInterceptDeviceArrayF32Py {
        buf: Some(dev.buf),
        rows: dev.rows,
        cols: dev.cols,
        ctx,
        device_id: dev_id,
    })
}
1205
#[cfg(test)]
mod tests {
    //! Tests for the linear-regression intercept indicator. They cover single-series
    //! checks, a property test (behind the `proptest` feature), the `*_into` API and the
    //! batch builder.
    //!
    //! Every `check_*` function returns a `Result`. The generated `#[test]` wrappers return
    //! that `Result` unchanged, so an `Err` fails the test instead of being discarded. That
    //! covers a fixture that fails to load, an input the indicator rejects, and a failing
    //! property test.

    use super::*;
    use crate::skip_if_unsupported;
    use crate::utilities::data_loader::read_candles_from_csv;
    #[cfg(feature = "proptest")]
    use proptest::prelude::*;

    /// Candle fixture shared by every CSV-driven test.
    const CANDLES_CSV: &str = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";

    /// Expected last five outputs for the `close` series of [`CANDLES_CSV`] with
    /// [`LinearRegInterceptParams::default`]. Used by both the single-series and the batch
    /// default-row tests.
    const EXPECTED_LAST_FIVE: [f64; 5] = [
        60000.91428571429,
        59947.142857142855,
        59754.57142857143,
        59318.4,
        59321.91428571429,
    ];

    /// Debug poison bit patterns, each paired with the helper name reported on failure.
    /// Any output carrying one of these bit patterns fails the test.
    #[allow(dead_code)] // only referenced under `debug_assertions` or the `proptest` feature
    const POISON_PATTERNS: [(u64, &str); 3] = [
        (0x11111111_11111111, "alloc_with_nan_prefix"),
        (0x22222222_22222222, "init_matrix_prefixes"),
        (0x33333333_33333333, "make_uninit_matrix"),
    ];

    /// Returns the helper name whose poison pattern equals `bits`, or `None` for an
    /// ordinary value.
    #[allow(dead_code)] // see `POISON_PATTERNS`
    fn poison_source(bits: u64) -> Option<&'static str> {
        POISON_PATTERNS
            .iter()
            .find(|&&(pattern, _)| pattern == bits)
            .map(|&(_, name)| name)
    }

    /// Asserts that the indicator returns an error for `data` with the given `period`.
    fn expect_rejected(test_name: &str, kernel: Kernel, data: &[f64], period: usize, reason: &str) {
        let params = LinearRegInterceptParams {
            period: Some(period),
        };
        let input = LinearRegInterceptInput::from_slice(data, params);
        let res = linearreg_intercept_with_kernel(&input, kernel);
        assert!(
            res.is_err(),
            "[{}] LinReg should fail with {}",
            test_name,
            reason
        );
    }

    fn check_linreg_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        // `period: None` must fall back to the default period.
        let default_params = LinearRegInterceptParams { period: None };
        let input = LinearRegInterceptInput::from_candles(&candles, "close", default_params);
        let output = linearreg_intercept_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_linreg_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = LinearRegInterceptInput::from_candles(
            &candles,
            "close",
            LinearRegInterceptParams::default(),
        );
        let result = linearreg_intercept_with_kernel(&input, kernel)?;
        let start = result.values.len().saturating_sub(EXPECTED_LAST_FIVE.len());
        for (i, (&val, &expected)) in result.values[start..]
            .iter()
            .zip(EXPECTED_LAST_FIVE.iter())
            .enumerate()
        {
            assert!(
                (val - expected).abs() < 1e-1,
                "[{}] LinReg {:?} mismatch at idx {}: got {}, expected {}",
                test_name,
                kernel,
                i,
                val,
                expected
            );
        }
        Ok(())
    }

    fn check_linreg_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = LinearRegInterceptInput::with_default_candles(&candles);
        match input.data {
            LinearRegInterceptData::Candles { source, .. } => assert_eq!(source, "close"),
            _ => panic!("Expected LinearRegInterceptData::Candles"),
        }
        let output = linearreg_intercept_with_kernel(&input, kernel)?;
        assert_eq!(output.values.len(), candles.close.len());
        Ok(())
    }

    fn check_linreg_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        expect_rejected(test_name, kernel, &[10.0, 20.0, 30.0], 0, "zero period");
        Ok(())
    }

    fn check_linreg_period_exceeds_length(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        expect_rejected(
            test_name,
            kernel,
            &[10.0, 20.0, 30.0],
            10,
            "period exceeding length",
        );
        Ok(())
    }

    fn check_linreg_very_small_dataset(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        expect_rejected(test_name, kernel, &[42.0], 14, "insufficient data");
        Ok(())
    }

    fn check_linreg_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let first_params = LinearRegInterceptParams { period: Some(14) };
        let first_input = LinearRegInterceptInput::from_candles(&candles, "close", first_params);
        let first_result = linearreg_intercept_with_kernel(&first_input, kernel)?;

        // Feed the output (including its NaN warm-up prefix) back in as input.
        let second_params = LinearRegInterceptParams { period: Some(14) };
        let second_input = LinearRegInterceptInput::from_slice(&first_result.values, second_params);
        let second_result = linearreg_intercept_with_kernel(&second_input, kernel)?;
        assert_eq!(second_result.values.len(), first_result.values.len());

        // Once the first valid value appears, no NaN may follow.
        let start = second_result
            .values
            .iter()
            .position(|v| !v.is_nan())
            .unwrap_or(second_result.values.len());
        for (i, v) in second_result.values.iter().enumerate().skip(start) {
            assert!(
                !v.is_nan(),
                "[{}] Unexpected NaN at index {} after reinput",
                test_name,
                i
            );
        }
        Ok(())
    }

    fn check_linreg_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);
        let candles = read_candles_from_csv(CANDLES_CSV)?;
        let input = LinearRegInterceptInput::from_candles(
            &candles,
            "close",
            LinearRegInterceptParams { period: Some(14) },
        );
        let res = linearreg_intercept_with_kernel(&input, kernel)?;
        assert_eq!(res.values.len(), candles.close.len());
        // Index 40 is comfortably past the period-14 warm-up.
        for (i, &val) in res.values.iter().enumerate().skip(40) {
            assert!(
                !val.is_nan(),
                "[{}] Found unexpected NaN at out-index {}",
                test_name,
                i
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_linreg_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test_name);

        let candles = read_candles_from_csv(CANDLES_CSV)?;

        let test_params = [
            LinearRegInterceptParams::default(),
            LinearRegInterceptParams { period: Some(2) },
            LinearRegInterceptParams { period: Some(5) },
            LinearRegInterceptParams { period: Some(7) },
            LinearRegInterceptParams { period: Some(10) },
            LinearRegInterceptParams { period: Some(20) },
            LinearRegInterceptParams { period: Some(30) },
            LinearRegInterceptParams { period: Some(50) },
            LinearRegInterceptParams { period: Some(100) },
            LinearRegInterceptParams { period: Some(150) },
            LinearRegInterceptParams { period: Some(200) },
        ];

        for (param_idx, params) in test_params.iter().enumerate() {
            let input = LinearRegInterceptInput::from_candles(&candles, "close", params.clone());
            let output = linearreg_intercept_with_kernel(&input, kernel)?;

            for (i, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }
                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    panic!(
                        "[{}] Found {} poison value {} (0x{:016X}) at index {} \
                         with params: period={} (param set {})",
                        test_name,
                        source,
                        val,
                        bits,
                        i,
                        params.period.unwrap_or(14),
                        param_idx
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_linreg_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    #[cfg(feature = "proptest")]
    #[allow(clippy::float_cmp)]
    fn check_linearreg_intercept_property(
        test_name: &str,
        kernel: Kernel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        skip_if_unsupported!(kernel, test_name);

        /// Intercept this test expects for noise-free `y = slope * i + offset` data: the line
        /// evaluated at the window's first (oldest) index.
        fn expected_linear_intercept(window_start_idx: usize, data_slope: f64, data_intercept: f64) -> f64 {
            data_slope * window_start_idx as f64 + data_intercept
        }

        /// Relative comparison: the tolerance scales with the larger magnitude and is
        /// absolute for magnitudes below 1.0. Sliding-window sums over values up to ~1e3
        /// cannot be expected to match to a fixed 1e-9.
        fn approx_eq(a: f64, b: f64, rel: f64) -> bool {
            (a - b).abs() <= rel * a.abs().max(b.abs()).max(1.0)
        }

        // `len >= period` is guaranteed. A case with `period > len` makes the indicator
        // return an error (see `check_linreg_period_exceeds_length`) instead of exercising it.
        let strat = (1usize..=100)
            .prop_flat_map(|period| (Just(period), period.max(50)..500, 0usize..5, any::<u64>()))
            .prop_map(|(period, len, scenario, seed)| {
                let mut rng_state = seed.wrapping_mul(1664525).wrapping_add(1013904223);
                let mut data = Vec::with_capacity(len);

                match scenario {
                    // Uniform noise in [-100, 100].
                    0 => {
                        for _ in 0..len {
                            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
                            let val = (rng_state as f64 / u64::MAX as f64) * 200.0 - 100.0;
                            data.push(val);
                        }
                    }
                    // Constant series.
                    1 => {
                        data.resize(len, 42.0);
                    }
                    // Exact rising line.
                    2 => {
                        data.extend((0..len).map(|i| 2.0 * i as f64 + 10.0));
                    }
                    // Exact falling line.
                    3 => {
                        data.extend((0..len).map(|i| -1.5 * i as f64 + 100.0));
                    }
                    // Rising line plus bounded noise.
                    _ => {
                        for i in 0..len {
                            rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
                            let noise = ((rng_state as f64 / u64::MAX as f64) - 0.5) * 10.0;
                            data.push(0.5 * i as f64 + 50.0 + noise);
                        }
                    }
                }

                (data, period, scenario)
            });

        proptest::test_runner::TestRunner::default().run(&strat, |(data, period, scenario)| {
            let params = LinearRegInterceptParams {
                period: Some(period),
            };
            let input = LinearRegInterceptInput::from_slice(&data, params);

            let output = linearreg_intercept_with_kernel(&input, kernel)?;
            // The scalar kernel is the reference every other kernel must agree with.
            let ref_output = linearreg_intercept_with_kernel(&input, Kernel::Scalar)?;

            prop_assert_eq!(
                output.values.len(),
                data.len(),
                "[{}] Output length mismatch",
                test_name
            );

            if period == 1 {
                for (i, (&actual, &expected)) in output.values.iter().zip(&data).enumerate() {
                    prop_assert!(
                        approx_eq(actual, expected, 1e-9),
                        "[{}] Period=1: expected {}, got {} at index {}",
                        test_name,
                        expected,
                        actual,
                        i
                    );
                }
            } else {
                for (i, val) in output.values[..period - 1].iter().enumerate() {
                    prop_assert!(
                        val.is_nan(),
                        "[{}] Expected NaN during warmup at index {}",
                        test_name,
                        i
                    );
                }
                prop_assert!(
                    !output.values[period - 1].is_nan(),
                    "[{}] Expected valid value at index {}",
                    test_name,
                    period - 1
                );
            }

            if scenario == 1 && period < data.len() {
                for (i, &intercept) in output.values.iter().enumerate().skip(period - 1) {
                    if !intercept.is_nan() {
                        prop_assert!(
                            approx_eq(intercept, 42.0, 1e-9),
                            "[{}] Constant data: expected 42.0, got {} at index {}",
                            test_name,
                            intercept,
                            i
                        );
                    }
                }
            }

            if (scenario == 2 || scenario == 3) && period > 1 && period < data.len() {
                let (data_slope, data_intercept) = if scenario == 2 {
                    (2.0, 10.0)
                } else {
                    (-1.5, 100.0)
                };

                for i in (period - 1)..data.len().min(period * 5) {
                    let actual = output.values[i];
                    if actual.is_nan() {
                        continue;
                    }
                    let window_start = i + 1 - period;
                    let expected =
                        expected_linear_intercept(window_start, data_slope, data_intercept);
                    prop_assert!(
                        approx_eq(actual, expected, 1e-9),
                        "[{}] Linear trend (scenario {}): expected {:.6}, got {:.6} at index {} (window start {})",
                        test_name,
                        scenario,
                        expected,
                        actual,
                        i,
                        window_start
                    );
                }
            }

            for (i, (&y, &r)) in output.values.iter().zip(&ref_output.values).enumerate() {
                let bits = y.to_bits();
                prop_assert!(
                    poison_source(bits).is_none(),
                    "[{}] Found poison value at index {}: 0x{:016X}",
                    test_name,
                    i,
                    bits
                );

                if y.is_nan() && r.is_nan() {
                    continue;
                }

                if y.is_finite() && r.is_finite() {
                    prop_assert!(
                        approx_eq(y, r, 1e-9),
                        "[{}] Kernel mismatch at index {}: {} vs {} (diff: {})",
                        test_name,
                        i,
                        y,
                        r,
                        (y - r).abs()
                    );
                } else {
                    prop_assert_eq!(
                        y.is_nan(),
                        r.is_nan(),
                        "[{}] NaN mismatch at index {}: {} vs {}",
                        test_name,
                        i,
                        y,
                        r
                    );
                }
            }

            if period > 1 {
                for (i, &val) in output.values.iter().enumerate().skip(period - 1) {
                    prop_assert!(
                        val.is_finite(),
                        "[{}] Non-finite value {} at index {}",
                        test_name,
                        val,
                        i
                    );
                }
            }

            Ok(())
        })?;

        Ok(())
    }

    /// Generates one `#[test]` per kernel for each `check_*` function. The AVX variants are
    /// gated individually: an attribute placed before a `$(...)*` repetition would only
    /// attach to the first generated item.
    macro_rules! generate_all_linreg_tests {
        ($($test_fn:ident),*) => {
            paste::paste! {
                $(
                    #[test]
                    fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                    }

                    #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                    #[test]
                    fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn Error>> {
                        $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                    }
                )*
            }
        };
    }

    generate_all_linreg_tests!(
        check_linreg_partial_params,
        check_linreg_accuracy,
        check_linreg_default_candles,
        check_linreg_zero_period,
        check_linreg_period_exceeds_length,
        check_linreg_very_small_dataset,
        check_linreg_reinput,
        check_linreg_nan_handling,
        check_linreg_no_poison
    );

    #[cfg(feature = "proptest")]
    generate_all_linreg_tests!(check_linearreg_intercept_property);

    #[test]
    fn test_linearreg_intercept_into_matches_api() -> Result<(), Box<dyn Error>> {
        // Ten leading NaNs followed by a smooth trending signal.
        let len = 256usize;
        let data: Vec<f64> = (0..len)
            .map(|i| {
                if i < 10 {
                    f64::NAN
                } else {
                    let x = i as f64;
                    (0.1 * x).sin() * 3.0 + 0.05 * x + 2.0
                }
            })
            .collect();

        let input = LinearRegInterceptInput::from_slice(&data, LinearRegInterceptParams::default());

        let baseline = linearreg_intercept(&input)?.values;

        let mut out = vec![0.0; data.len()];
        #[allow(unused_variables)]
        {
            #[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
            {
                linearreg_intercept_into(&input, &mut out)?;
            }
            #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
            {
                linearreg_intercept_into_slice(&mut out, &input, Kernel::Auto)?;
            }
        }

        assert_eq!(baseline.len(), out.len());

        fn eq_or_both_nan_eps(a: f64, b: f64) -> bool {
            (a.is_nan() && b.is_nan()) || (a - b).abs() <= 1e-12
        }

        for (i, (&b, &o)) in baseline.iter().zip(&out).enumerate() {
            assert!(
                eq_or_both_nan_eps(b, o),
                "mismatch at index {}: baseline={} out={}",
                i,
                b,
                o
            );
        }

        Ok(())
    }

    fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        let output = LinearRegInterceptBatchBuilder::new()
            .kernel(kernel)
            .apply_candles(&c, "close")?;

        let def = LinearRegInterceptParams::default();
        let row = output.values_for(&def).ok_or("default row missing")?;

        assert_eq!(row.len(), c.close.len());

        let start = row.len().saturating_sub(EXPECTED_LAST_FIVE.len());
        for (i, (&v, &expected)) in row[start..]
            .iter()
            .zip(EXPECTED_LAST_FIVE.iter())
            .enumerate()
        {
            assert!(
                (v - expected).abs() < 1e-1,
                "[{test}] default-row mismatch at idx {i}: {v} vs {expected}"
            );
        }
        Ok(())
    }

    #[cfg(debug_assertions)]
    fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
        skip_if_unsupported!(kernel, test);

        let c = read_candles_from_csv(CANDLES_CSV)?;

        // (start, end, step) period sweeps, including a single-value sweep (step 0).
        let test_configs = [
            (2, 10, 2),
            (5, 25, 5),
            (10, 10, 0),
            (2, 5, 1),
            (30, 60, 15),
            (2, 14, 3),
            (50, 100, 25),
            (100, 200, 50),
        ];

        for (cfg_idx, &(p_start, p_end, p_step)) in test_configs.iter().enumerate() {
            let output = LinearRegInterceptBatchBuilder::new()
                .kernel(kernel)
                .period_range(p_start, p_end, p_step)
                .apply_candles(&c, "close")?;

            for (idx, &val) in output.values.iter().enumerate() {
                if val.is_nan() {
                    continue;
                }

                let bits = val.to_bits();
                if let Some(source) = poison_source(bits) {
                    // Values are laid out row-major: one row per parameter combination.
                    let row = idx / output.cols;
                    let col = idx % output.cols;
                    let combo = &output.combos[row];
                    panic!(
                        "[{}] Config {}: Found {} poison value {} (0x{:016X}) \
                         at row {} col {} (flat index {}) with params: period={}",
                        test,
                        cfg_idx,
                        source,
                        val,
                        bits,
                        row,
                        col,
                        idx,
                        combo.period.unwrap_or(14)
                    );
                }
            }
        }

        Ok(())
    }

    #[cfg(not(debug_assertions))]
    fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Generates one `#[test]` per batch kernel (plus auto-detection) for a batch check.
    macro_rules! gen_batch_tests {
        ($fn_name:ident) => {
            paste::paste! {
                #[test]
                fn [<$fn_name _scalar>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx2>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$fn_name _avx512>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
                }

                #[test]
                fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn Error>> {
                    $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
                }
            }
        };
    }
    gen_batch_tests!(check_batch_default_row);
    gen_batch_tests!(check_batch_no_poison);
}
1894
// Python binding: computes the linear-regression intercept of `data` over a trailing
// `period` window, optionally forcing a specific (non-batch) kernel.
// Indicator errors are raised as `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "linearreg_intercept")]
#[pyo3(signature = (data, period, kernel=None))]
pub fn linearreg_intercept_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, numpy::PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};

    let prices = data.as_slice()?;
    let chosen = validate_kernel(kernel, false)?;
    let input = LinearRegInterceptInput::from_slice(
        prices,
        LinearRegInterceptParams {
            period: Some(period),
        },
    );

    // The computation runs with the GIL released.
    let computed = py.allow_threads(|| linearreg_intercept_with_kernel(&input, chosen));
    match computed {
        Ok(output) => Ok(output.values.into_pyarray(py)),
        Err(err) => Err(PyValueError::new_err(err.to_string())),
    }
}
1919
// Python class `LinearRegInterceptStream`: wraps the Rust `LinearRegInterceptStream` so
// values can be pushed one at a time from Python.
#[cfg(feature = "python")]
#[pyclass(name = "LinearRegInterceptStream")]
pub struct LinearRegInterceptStreamPy {
    // Underlying Rust streaming state.
    stream: LinearRegInterceptStream,
}

#[cfg(feature = "python")]
#[pymethods]
impl LinearRegInterceptStreamPy {
    // Builds a stream for `period`; a rejected period is raised as `ValueError`.
    #[new]
    fn new(period: usize) -> PyResult<Self> {
        let params = LinearRegInterceptParams {
            period: Some(period),
        };
        LinearRegInterceptStream::try_new(params)
            .map(|stream| Self { stream })
            .map_err(|err| PyValueError::new_err(err.to_string()))
    }

    // Pushes `value` into the stream and returns whatever the underlying `update` yields.
    fn update(&mut self, value: f64) -> Option<f64> {
        self.stream.update(value)
    }
}
1943
// Python binding: evaluates the intercept for every period in the `(start, end, step)`
// sweep. Returns a dict with `values` (2-D, one row per period) and `periods` (the period
// of each row). Invalid sweeps, shape overflow and indicator errors raise `ValueError`.
#[cfg(feature = "python")]
#[pyfunction(name = "linearreg_intercept_batch")]
#[pyo3(signature = (data, period_range, kernel=None))]
pub fn linearreg_intercept_batch_py<'py>(
    py: Python<'py>,
    data: numpy::PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, pyo3::types::PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let prices = data.as_slice()?;
    let sweep = LinearRegInterceptBatchRange {
        period: period_range,
    };

    let grid = expand_grid(&sweep).map_err(|e| PyValueError::new_err(e.to_string()))?;
    let n_rows = grid.len();
    let n_cols = prices.len();
    let n_total = n_rows
        .checked_mul(n_cols)
        .ok_or_else(|| PyValueError::new_err("linearreg_intercept_batch: rows*cols overflow"))?;

    // SAFETY: the array is freshly created and not yet visible to Python. The warm-up
    // prefixes are NaN-filled below and the rest is handed to the batch driver to write.
    // NOTE(review): this relies on `linearreg_intercept_batch_inner_into` writing every
    // post-warm-up cell — confirm.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [n_total], false) };
    let out_slice = unsafe { out_arr.as_slice_mut()? };

    if n_rows > 0 && n_cols > 0 {
        // Leading NaNs in the input push every row's warm-up region further right.
        let first_valid = prices.iter().position(|x| !x.is_nan()).unwrap_or(0);
        for (row, combo) in out_slice.chunks_mut(n_cols).zip(&grid) {
            let warm = (first_valid + combo.period.unwrap() - 1).min(n_cols);
            row[..warm].fill(f64::NAN);
        }
    }

    let requested = validate_kernel(kernel, true)?;
    let evaluated = py
        .allow_threads(|| {
            let batch_kernel = match requested {
                Kernel::Auto => detect_best_batch_kernel(),
                explicit => explicit,
            };
            // Each batch kernel is mapped to its per-row counterpart for the batch driver.
            let row_kernel = match batch_kernel {
                Kernel::ScalarBatch => Kernel::Scalar,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::Avx512Batch => Kernel::Avx512,
                _ => unreachable!(),
            };
            linearreg_intercept_batch_inner_into(prices, &sweep, row_kernel, true, out_slice)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((n_rows, n_cols))?)?;
    let periods: Vec<u64> = evaluated
        .iter()
        .map(|combo| combo.period.unwrap() as u64)
        .collect();
    dict.set_item("periods", periods.into_pyarray(py))?;
    Ok(dict)
}
2011
/// JS binding: computes the intercept of `data` for `period` and returns a new array of
/// the same length. Indicator errors are returned as a JS string value.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linearreg_intercept_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = LinearRegInterceptInput::from_slice(
        data,
        LinearRegInterceptParams {
            period: Some(period),
        },
    );

    let mut values = vec![0.0; data.len()];
    match linearreg_intercept_into_slice(&mut values, &input, Kernel::Auto) {
        Ok(_) => Ok(values),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
2026
/// Allocates an uninitialized buffer with room for `len` `f64`s and hands ownership of it
/// to the caller as a raw pointer. The buffer must be released with
/// `linearreg_intercept_free(ptr, len)` using the same `len`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linearreg_intercept_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` keeps the allocation alive after this function returns.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
2035
/// Releases a buffer previously returned by `linearreg_intercept_alloc(len)`.
///
/// A null `ptr` is ignored. `len` must be the same value that was passed to the allocator.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linearreg_intercept_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: `ptr` and `len` must come from `linearreg_intercept_alloc(len)`, which leaked
    // a `Vec<f64>` with capacity `len`. The Vec is rebuilt with length 0, so no element is
    // required to be initialized (the caller may never have written the buffer). Only the
    // capacity is needed to deallocate correctly.
    unsafe {
        drop(Vec::<f64>::from_raw_parts(ptr, 0, len));
    }
}
2045
/// Raw-pointer JS binding: computes the intercept of the `len` values at `in_ptr` and
/// writes `len` results to `out_ptr`.
///
/// Both pointers must address `len` valid `f64`s in wasm memory. The input and output
/// may be the same buffer or overlap; in that case the result is computed into a scratch
/// buffer and then copied out.
///
/// # Errors
/// Returns a JS string value for null pointers or when the indicator rejects the input.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linearreg_intercept_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    // Any overlap, not only identical start pointers, would alias the shared input slice
    // with the mutable output slice, so every overlapping call goes through scratch.
    let span = len.saturating_mul(std::mem::size_of::<f64>());
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(span) && out_addr < in_addr.saturating_add(span));

    let params = LinearRegInterceptParams {
        period: Some(period),
    };

    if overlaps {
        let scratch = {
            // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s. This
            // shared view ends with the block, before the mutable output view is created.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let input = LinearRegInterceptInput::from_slice(data, params);
            let mut scratch = vec![0.0; len];
            linearreg_intercept_into_slice(&mut scratch, &input, Kernel::Auto)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            scratch
        };
        // SAFETY: the caller guarantees `out_ptr` addresses `len` writable f64s, and no
        // other reference into that memory is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, len) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees both ranges are valid for `len` f64s, and they were
        // checked above not to overlap, so the shared and mutable slices do not alias.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, len),
            )
        };
        let input = LinearRegInterceptInput::from_slice(data, params);
        linearreg_intercept_into_slice(out, &input, Kernel::Auto)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(())
}
2079
/// Configuration object deserialized from the `config` argument of the JS
/// `linearreg_intercept_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LinearRegInterceptBatchConfig {
    /// `(start, end, step)` sweep of the `period` parameter.
    pub period_range: (usize, usize, usize),
}
2085
/// Result object serialized back to JS by the `linearreg_intercept_batch` binding.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LinearRegInterceptBatchJsOutput {
    /// Flattened result matrix as produced by the batch computation.
    pub values: Vec<f64>,
    /// Parameter set evaluated for each row of `values`.
    pub combos: Vec<LinearRegInterceptParams>,
    /// Number of rows in `values`.
    pub rows: usize,
    /// Number of columns in `values` (the input length).
    pub cols: usize,
}
2094
/// JS binding (`linearreg_intercept_batch`): evaluates the intercept for every period in
/// `config.period_range`. Returns `{ values, combos, rows, cols }`, where `values` holds
/// one row per entry of `combos` and `cols` equals `data.len()`.
///
/// # Errors
/// Returns a JS string value for an invalid config, an indicator error, or a
/// serialization failure.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = linearreg_intercept_batch)]
pub fn linearreg_intercept_batch_unified_js(
    data: &[f64],
    config: JsValue,
) -> Result<JsValue, JsValue> {
    let config: LinearRegInterceptBatchConfig = serde_wasm_bindgen::from_value(config)
        .map_err(|e| JsValue::from_str(&format!("Invalid config: {}", e)))?;

    let sweep = LinearRegInterceptBatchRange {
        period: config.period_range,
    };

    let batch_output = linearreg_intercept_batch_with_kernel(data, &sweep, Kernel::Auto)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    // One row per parameter combination. Counting combos, rather than dividing the value
    // count by `data.len()`, cannot divide by zero when `data` is empty.
    let rows = batch_output.combos.len();
    let cols = data.len();
    let result = LinearRegInterceptBatchJsOutput {
        values: batch_output.values,
        combos: batch_output.combos,
        rows,
        cols,
    };

    serde_wasm_bindgen::to_value(&result).map_err(|e| JsValue::from_str(&e.to_string()))
}
2121
/// Raw-pointer JS binding for the period sweep `(period_start, period_end, period_step)`.
/// Reads `len` values from `in_ptr` and writes a `rows x len` matrix (one row per period)
/// to `out_ptr`. Returns the number of rows.
///
/// `out_ptr` must address `rows * len` writable `f64`s. The input and output may be the
/// same buffer or overlap; in that case the result is computed into a scratch buffer
/// first and then copied out.
///
/// # Errors
/// Returns a JS string value for null pointers, an invalid sweep, `rows * len` overflow,
/// or an indicator error.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linearreg_intercept_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("Null pointer provided"));
    }

    let sweep = LinearRegInterceptBatchRange {
        period: (period_start, period_end, period_step),
    };
    let combos = expand_grid(&sweep).map_err(|e| JsValue::from_str(&e.to_string()))?;
    let rows = combos.len();
    let cols = len;
    let total = rows.checked_mul(cols).ok_or_else(|| {
        JsValue::from_str("linearreg_intercept_batch_into: rows*cols overflow")
    })?;

    // Any overlap between the input range and the (larger) output range would alias a
    // shared slice with a mutable one, so such calls go through a scratch buffer.
    let elem = std::mem::size_of::<f64>();
    let (in_addr, out_addr) = (in_ptr as usize, out_ptr as usize);
    let overlaps = in_addr == out_addr
        || (in_addr < out_addr.saturating_add(total.saturating_mul(elem))
            && out_addr < in_addr.saturating_add(len.saturating_mul(elem)));

    // NaN-fills each row's warm-up prefix, shifted by the input's leading-NaN run.
    let fill_warmup = |buf: &mut [f64], first: usize| {
        for (r, prm) in combos.iter().enumerate() {
            let warm = (first + prm.period.unwrap() - 1).min(cols);
            buf[r * cols..r * cols + warm].fill(f64::NAN);
        }
    };

    if overlaps {
        let scratch = {
            // SAFETY: the caller guarantees `in_ptr` addresses `len` readable f64s. This
            // shared view ends with the block, before the mutable output view is created.
            let data = unsafe { std::slice::from_raw_parts(in_ptr, len) };
            let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
            let mut scratch = vec![0.0; total];
            fill_warmup(&mut scratch, first);
            linearreg_intercept_batch_inner_into(data, &sweep, Kernel::Auto, true, &mut scratch)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
            scratch
        };
        // SAFETY: the caller guarantees `out_ptr` addresses `total` writable f64s, and no
        // other reference into that memory is live at this point.
        let out = unsafe { std::slice::from_raw_parts_mut(out_ptr, total) };
        out.copy_from_slice(&scratch);
    } else {
        // SAFETY: the caller guarantees the input (`len`) and output (`total`) ranges are
        // valid, and they were checked above not to overlap.
        let (data, out) = unsafe {
            (
                std::slice::from_raw_parts(in_ptr, len),
                std::slice::from_raw_parts_mut(out_ptr, total),
            )
        };
        let first = data.iter().position(|x| !x.is_nan()).unwrap_or(0);
        fill_warmup(out, first);
        linearreg_intercept_batch_inner_into(data, &sweep, Kernel::Auto, true, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }
    Ok(rows)
}