1#[cfg(all(feature = "python", feature = "cuda"))]
2use crate::cuda::moving_averages::{CudaLinreg, DeviceArrayF32};
3use crate::utilities::data_loader::{source_type, Candles};
4use crate::utilities::enums::Kernel;
5use crate::utilities::helpers::{
6 alloc_with_nan_prefix, detect_best_batch_kernel, detect_best_kernel, init_matrix_prefixes,
7 make_uninit_matrix,
8};
9#[cfg(feature = "python")]
10use crate::utilities::kernel_validation::validate_kernel;
11use aligned_vec::{AVec, CACHELINE_ALIGN};
12#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
13use core::arch::x86_64::*;
14#[cfg(all(feature = "python", feature = "cuda"))]
15use cust::context::Context;
16#[cfg(all(feature = "python", feature = "cuda"))]
17use cust::memory::DeviceBuffer;
18#[cfg(not(target_arch = "wasm32"))]
19use rayon::prelude::*;
20#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
21use serde::{Deserialize, Serialize};
22use std::convert::AsRef;
23use std::error::Error;
24use std::mem::MaybeUninit;
25#[cfg(all(feature = "python", feature = "cuda"))]
26use std::sync::Arc;
27use thiserror::Error;
28
29#[derive(Debug, Clone)]
30pub enum LinRegData<'a> {
31 Candles {
32 candles: &'a Candles,
33 source: &'a str,
34 },
35 Slice(&'a [f64]),
36}
37
/// Result of a single linear-regression run.
#[derive(Debug, Clone)]
pub struct LinRegOutput {
    /// Output series written by the kernel (one entry per input sample).
    pub values: Vec<f64>,
}
42
/// Tunable parameters of the linear-regression indicator.
#[derive(Debug, Clone)]
#[cfg_attr(all(target_arch = "wasm32", feature = "wasm"), derive(Serialize, Deserialize))]
pub struct LinRegParams {
    /// Regression window length. Consumers in this file fall back to 14 when it is `None`.
    pub period: Option<usize>,
}

impl Default for LinRegParams {
    /// Parameters with an explicit window of 14.
    fn default() -> Self {
        LinRegParams { period: Some(14) }
    }
}
57
58#[derive(Debug, Clone)]
59pub struct LinRegInput<'a> {
60 pub data: LinRegData<'a>,
61 pub params: LinRegParams,
62}
63
64impl<'a> AsRef<[f64]> for LinRegInput<'a> {
65 #[inline(always)]
66 fn as_ref(&self) -> &[f64] {
67 match &self.data {
68 LinRegData::Slice(slice) => slice,
69 LinRegData::Candles { candles, source } => source_type(candles, source),
70 }
71 }
72}
73
74impl<'a> LinRegInput<'a> {
75 #[inline]
76 pub fn from_candles(c: &'a Candles, s: &'a str, p: LinRegParams) -> Self {
77 Self {
78 data: LinRegData::Candles {
79 candles: c,
80 source: s,
81 },
82 params: p,
83 }
84 }
85 #[inline]
86 pub fn from_slice(sl: &'a [f64], p: LinRegParams) -> Self {
87 Self {
88 data: LinRegData::Slice(sl),
89 params: p,
90 }
91 }
92 #[inline]
93 pub fn with_default_candles(c: &'a Candles) -> Self {
94 Self::from_candles(c, "close", LinRegParams::default())
95 }
96 #[inline]
97 pub fn get_period(&self) -> usize {
98 self.params.period.unwrap_or(14)
99 }
100}
101
102#[derive(Copy, Clone, Debug)]
103pub struct LinRegBuilder {
104 period: Option<usize>,
105 kernel: Kernel,
106}
107
108impl Default for LinRegBuilder {
109 fn default() -> Self {
110 Self {
111 period: None,
112 kernel: Kernel::Auto,
113 }
114 }
115}
116
117impl LinRegBuilder {
118 #[inline(always)]
119 pub fn new() -> Self {
120 Self::default()
121 }
122 #[inline(always)]
123 pub fn period(mut self, n: usize) -> Self {
124 self.period = Some(n);
125 self
126 }
127 #[inline(always)]
128 pub fn kernel(mut self, k: Kernel) -> Self {
129 self.kernel = k;
130 self
131 }
132 #[inline(always)]
133 pub fn apply(self, c: &Candles) -> Result<LinRegOutput, LinRegError> {
134 let p = LinRegParams {
135 period: self.period,
136 };
137 let i = LinRegInput::from_candles(c, "close", p);
138 linreg_with_kernel(&i, self.kernel)
139 }
140 #[inline(always)]
141 pub fn apply_slice(self, d: &[f64]) -> Result<LinRegOutput, LinRegError> {
142 let p = LinRegParams {
143 period: self.period,
144 };
145 let i = LinRegInput::from_slice(d, p);
146 linreg_with_kernel(&i, self.kernel)
147 }
148 #[inline(always)]
149 pub fn into_stream(self) -> Result<LinRegStream, LinRegError> {
150 let p = LinRegParams {
151 period: self.period,
152 };
153 LinRegStream::try_new(p)
154 }
155}
156
157#[derive(Debug, Error)]
158pub enum LinRegError {
159 #[error("linreg: No data provided (All values are NaN).")]
160 EmptyInputData,
161 #[error("linreg: All values are NaN.")]
162 AllValuesNaN,
163 #[error("linreg: Invalid period: period = {period}, data length = {data_len}")]
164 InvalidPeriod { period: usize, data_len: usize },
165 #[error("linreg: Not enough valid data: needed = {needed}, valid = {valid}")]
166 NotEnoughValidData { needed: usize, valid: usize },
167 #[error("linreg: Output length mismatch: expected = {expected}, got = {got}")]
168 OutputLengthMismatch { expected: usize, got: usize },
169 #[error("linreg: Invalid range: start = {start}, end = {end}, step = {step}")]
170 InvalidRange {
171 start: usize,
172 end: usize,
173 step: usize,
174 },
175 #[error("linreg: Invalid kernel for batch API: {0:?}")]
176 InvalidKernelForBatch(Kernel),
177 #[error("linreg: arithmetic overflow when computing {what}")]
178 ArithmeticOverflow { what: &'static str },
179}
180
181#[inline]
182pub fn linreg(input: &LinRegInput) -> Result<LinRegOutput, LinRegError> {
183 linreg_with_kernel(input, Kernel::Auto)
184}
185
186#[inline(always)]
187fn linreg_prepare<'a>(
188 input: &'a LinRegInput,
189 kernel: Kernel,
190) -> Result<(&'a [f64], usize, usize, Kernel), LinRegError> {
191 let data: &[f64] = input.as_ref();
192 if data.is_empty() {
193 return Err(LinRegError::EmptyInputData);
194 }
195 let first = data
196 .iter()
197 .position(|x| !x.is_nan())
198 .ok_or(LinRegError::AllValuesNaN)?;
199 let len = data.len();
200 let period = input.get_period();
201
202 if period == 0 || period > len {
203 return Err(LinRegError::InvalidPeriod {
204 period,
205 data_len: len,
206 });
207 }
208 if (len - first) < period {
209 return Err(LinRegError::NotEnoughValidData {
210 needed: period,
211 valid: len - first,
212 });
213 }
214
215 let chosen = match kernel {
216 Kernel::Auto => Kernel::Scalar,
217 other => other,
218 };
219
220 Ok((data, period, first, chosen))
221}
222
223pub fn linreg_with_kernel(
224 input: &LinRegInput,
225 kernel: Kernel,
226) -> Result<LinRegOutput, LinRegError> {
227 let (data, period, first, chosen) = linreg_prepare(input, kernel)?;
228
229 let warm = first + period;
230 let mut out = alloc_with_nan_prefix(data.len(), warm);
231
232 unsafe {
233 match chosen {
234 Kernel::Scalar | Kernel::ScalarBatch => linreg_scalar(data, period, first, &mut out),
235 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
236 Kernel::Avx2 | Kernel::Avx2Batch => linreg_avx2(data, period, first, &mut out),
237 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
238 Kernel::Avx512 | Kernel::Avx512Batch => linreg_avx512(data, period, first, &mut out),
239 _ => unreachable!(),
240 }
241 }
242
243 Ok(LinRegOutput { values: out })
244}
245
246#[cfg(not(all(target_arch = "wasm32", feature = "wasm")))]
247#[inline]
248pub fn linreg_into(input: &LinRegInput, out: &mut [f64]) -> Result<(), LinRegError> {
249 linreg_compute_into(input, Kernel::Scalar, out)
250}
251
252pub fn linreg_compute_into(
253 input: &LinRegInput,
254 kernel: Kernel,
255 out: &mut [f64],
256) -> Result<(), LinRegError> {
257 let data: &[f64] = input.as_ref();
258 if data.is_empty() {
259 return Err(LinRegError::EmptyInputData);
260 }
261 let first = data
262 .iter()
263 .position(|x| !x.is_nan())
264 .ok_or(LinRegError::AllValuesNaN)?;
265 let len = data.len();
266 let period = input.get_period();
267
268 if period == 0 || period > len {
269 return Err(LinRegError::InvalidPeriod {
270 period,
271 data_len: len,
272 });
273 }
274 if (len - first) < period {
275 return Err(LinRegError::NotEnoughValidData {
276 needed: period,
277 valid: len - first,
278 });
279 }
280 if out.len() != len {
281 return Err(LinRegError::OutputLengthMismatch {
282 expected: len,
283 got: out.len(),
284 });
285 }
286
287 let chosen = match kernel {
288 Kernel::Auto => Kernel::Scalar,
289 other => other,
290 };
291
292 let warm = first + period;
293
294 out[..warm].fill(f64::NAN);
295
296 unsafe {
297 match chosen {
298 Kernel::Scalar | Kernel::ScalarBatch => linreg_scalar(data, period, first, out),
299 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
300 Kernel::Avx2 | Kernel::Avx2Batch => linreg_avx2(data, period, first, out),
301 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
302 Kernel::Avx512 | Kernel::Avx512Batch => linreg_avx512(data, period, first, out),
303 _ => unreachable!(),
304 }
305 }
306
307 Ok(())
308}
309
/// Scalar rolling linear-regression kernel.
///
/// For every index `i >= first + period - 1` a least-squares line is fitted through
/// the last `period` samples (with x = 1..=period) and its value at x = `period` is
/// stored in `out[i]`. Entries of `out` before that index are left untouched.
///
/// The running sums `Σy` and `Σx·y` are updated incrementally, so each output costs
/// O(1) after an O(period) warm-up.
///
/// This function is safe to call with any arguments: all accesses are bounds-checked
/// once up front instead of using unchecked indexing, so a too-short `out` or an
/// out-of-range window panics rather than writing out of bounds.
///
/// # Panics
/// * if `first + period - 1` exceeds `data.len()` (or `period == 0`),
/// * if outputs must be written and `out.len() < data.len()`.
#[inline(always)]
fn linreg_scalar(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    let len = data.len();
    let period_f = period as f64;
    // Closed forms for Σx and Σx² with x = 1..=period.
    let x_sum = ((period * (period + 1)) / 2) as f64;
    let x2_sum = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let denom_inv = 1.0 / (period_f * x2_sum - x_sum * x_sum);
    let inv_period = 1.0 / period_f;

    // Index of the first output; the window for it is data[first..=start].
    let start = first + period - 1;

    // Seed the sums with the first `period - 1` samples (weights 1..period-1).
    let mut y_sum = 0.0;
    let mut xy_sum = 0.0;
    for (k, &v) in data[first..start].iter().enumerate() {
        y_sum += v;
        xy_sum += ((k + 1) as f64) * v;
    }

    if start >= len {
        // No complete window exists: nothing to write.
        return;
    }

    let dst = &mut out[start..len];
    let incoming = &data[start..];
    let outgoing = &data[first..];
    for ((slot, &new_val), &old_val) in dst.iter_mut().zip(incoming).zip(outgoing) {
        // Newest sample enters the window with weight `period`.
        y_sum += new_val;
        xy_sum += new_val * period_f;

        let b = (period_f * xy_sum - x_sum * y_sum) * denom_inv;
        let a = (y_sum - b * x_sum) * inv_period;
        *slot = a + b * period_f;

        // Slide the window: every weight drops by one, then the oldest sample leaves.
        xy_sum -= y_sum;
        y_sum -= old_val;
    }
}
349
/// AVX2/FMA variant of the rolling linear-regression kernel.
///
/// Only the warm-up (the first `period - 1` samples of the window) is vectorised,
/// four lanes at a time; the rolling update afterwards is scalar and uses a fused
/// multiply-add for `Σx·y`.
///
/// # Safety
/// * The CPU must support the `avx2` and `fma` target features.
/// * `first + period - 1 <= data.len()` must hold so the warm-up reads stay in bounds.
/// * `out.len() >= data.len()`; outputs are written with unchecked indexing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx2,fma")]
pub unsafe fn linreg_avx2(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    const LANES: usize = 4;

    let n = period as f64;
    let sum_x = ((period * (period + 1)) / 2) as f64;
    let sum_x2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let inv_denom = 1.0 / (n * sum_x2 - sum_x * sum_x);
    let inv_n = 1.0 / n;

    let seed_len = period.saturating_sub(1);
    let full_blocks = seed_len / LANES;
    let leftover = seed_len & (LANES - 1);

    let mut sum_y = 0.0f64;
    let mut sum_xy = 0.0f64;
    let mut cursor = data.as_ptr().add(first);

    if full_blocks > 0 {
        // Per-lane weights 1..=4, shifted by the running block offset.
        let lane_weights = _mm256_setr_pd(1.0, 2.0, 3.0, 4.0);
        let mut block_base = 0.0f64;
        let mut acc_y = _mm256_set1_pd(0.0);
        let mut acc_xy = _mm256_set1_pd(0.0);

        for _ in 0..full_blocks {
            let ys = _mm256_loadu_pd(cursor);
            let xs = _mm256_add_pd(lane_weights, _mm256_set1_pd(block_base));
            acc_y = _mm256_add_pd(acc_y, ys);
            acc_xy = _mm256_fmadd_pd(ys, xs, acc_xy);
            cursor = cursor.add(LANES);
            block_base += 4.0;
        }

        // Horizontal reduction through a small stack buffer.
        let mut lanes = [0.0f64; LANES];
        _mm256_storeu_pd(lanes.as_mut_ptr(), acc_y);
        sum_y += lanes.iter().sum::<f64>();
        _mm256_storeu_pd(lanes.as_mut_ptr(), acc_xy);
        sum_xy += lanes.iter().sum::<f64>();
    }

    // Remaining warm-up samples that did not fill a whole vector.
    let first_tail_weight = full_blocks * LANES + 1;
    for j in 0..leftover {
        let v = *cursor.add(j);
        sum_y += v;
        sum_xy += ((first_tail_weight + j) as f64) * v;
    }

    let start = first + period - 1;
    for (step, idx) in (start..data.len()).enumerate() {
        let incoming = *data.get_unchecked(idx);
        sum_y += incoming;
        sum_xy = f64::mul_add(n, incoming, sum_xy);

        let slope = (n * sum_xy - sum_x * sum_y) * inv_denom;
        let intercept = (sum_y - slope * sum_x) * inv_n;
        *out.get_unchecked_mut(idx) = intercept + slope * n;

        // Slide the window: lower every weight by one, then drop the oldest sample.
        sum_xy -= sum_y;
        sum_y -= *data.get_unchecked(first + step);
    }
}
418
/// AVX-512 variant of the rolling linear-regression kernel.
///
/// Only the warm-up (the first `period - 1` samples of the window) is vectorised,
/// eight lanes at a time; the rolling update afterwards is scalar and uses a fused
/// multiply-add for `Σx·y`.
///
/// # Safety
/// * The CPU must support the `avx512f`, `avx512dq` and `fma` target features.
/// * `first + period - 1 <= data.len()` must hold so the warm-up reads stay in bounds.
/// * `out.len() >= data.len()`; outputs are written with unchecked indexing.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512dq,fma")]
pub unsafe fn linreg_avx512(data: &[f64], period: usize, first: usize, out: &mut [f64]) {
    use core::arch::x86_64::*;

    const LANES: usize = 8;

    let n = period as f64;
    let sum_x = ((period * (period + 1)) / 2) as f64;
    let sum_x2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let inv_denom = 1.0 / (n * sum_x2 - sum_x * sum_x);
    let inv_n = 1.0 / n;

    let seed_len = period.saturating_sub(1);
    let full_blocks = seed_len / LANES;
    let leftover = seed_len & (LANES - 1);

    let mut sum_y = 0.0f64;
    let mut sum_xy = 0.0f64;
    let mut cursor = data.as_ptr().add(first);

    if full_blocks > 0 {
        // Per-lane weights 1..=8, shifted by the running block offset.
        let lane_weights = _mm512_setr_pd(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        let mut block_base = 0.0f64;
        let mut acc_y = _mm512_set1_pd(0.0);
        let mut acc_xy = _mm512_set1_pd(0.0);

        for _ in 0..full_blocks {
            let ys = _mm512_loadu_pd(cursor);
            let xs = _mm512_add_pd(lane_weights, _mm512_set1_pd(block_base));
            acc_y = _mm512_add_pd(acc_y, ys);
            acc_xy = _mm512_fmadd_pd(ys, xs, acc_xy);
            cursor = cursor.add(LANES);
            block_base += 8.0;
        }

        // Horizontal reduction through a small stack buffer.
        let mut lanes = [0.0f64; LANES];
        _mm512_storeu_pd(lanes.as_mut_ptr(), acc_y);
        sum_y += lanes.iter().sum::<f64>();
        _mm512_storeu_pd(lanes.as_mut_ptr(), acc_xy);
        sum_xy += lanes.iter().sum::<f64>();
    }

    // Remaining warm-up samples that did not fill a whole vector.
    let first_tail_weight = full_blocks * LANES + 1;
    for j in 0..leftover {
        let v = *cursor.add(j);
        sum_y += v;
        sum_xy += ((first_tail_weight + j) as f64) * v;
    }

    let start = first + period - 1;
    for (step, idx) in (start..data.len()).enumerate() {
        let incoming = *data.get_unchecked(idx);
        sum_y += incoming;
        sum_xy = f64::mul_add(n, incoming, sum_xy);

        let slope = (n * sum_xy - sum_x * sum_y) * inv_denom;
        let intercept = (sum_y - slope * sum_x) * inv_n;
        *out.get_unchecked_mut(idx) = intercept + slope * n;

        // Slide the window: lower every weight by one, then drop the oldest sample.
        sum_xy -= sum_y;
        sum_y -= *data.get_unchecked(first + step);
    }
}
487
/// Parameter sweep for the batch API.
#[derive(Clone, Debug)]
pub struct LinRegBatchRange {
    /// Period sweep as `(start, end, step)`.
    pub period: (usize, usize, usize),
}

impl Default for LinRegBatchRange {
    /// Sweeps periods from 14 to 263 in steps of 1.
    fn default() -> Self {
        LinRegBatchRange {
            period: (14, 263, 1),
        }
    }
}
500
501#[derive(Clone, Debug, Default)]
502pub struct LinRegBatchBuilder {
503 range: LinRegBatchRange,
504 kernel: Kernel,
505}
506
507impl LinRegBatchBuilder {
508 pub fn new() -> Self {
509 Self::default()
510 }
511 pub fn kernel(mut self, k: Kernel) -> Self {
512 self.kernel = k;
513 self
514 }
515 pub fn period_range(mut self, start: usize, end: usize, step: usize) -> Self {
516 self.range.period = (start, end, step);
517 self
518 }
519 pub fn period_static(mut self, p: usize) -> Self {
520 self.range.period = (p, p, 0);
521 self
522 }
523 pub fn apply_slice(self, data: &[f64]) -> Result<LinRegBatchOutput, LinRegError> {
524 linreg_batch_with_kernel(data, &self.range, self.kernel)
525 }
526 pub fn with_default_slice(data: &[f64], k: Kernel) -> Result<LinRegBatchOutput, LinRegError> {
527 LinRegBatchBuilder::new().kernel(k).apply_slice(data)
528 }
529 pub fn apply_candles(self, c: &Candles, src: &str) -> Result<LinRegBatchOutput, LinRegError> {
530 let slice = source_type(c, src);
531 self.apply_slice(slice)
532 }
533 pub fn with_default_candles(c: &Candles) -> Result<LinRegBatchOutput, LinRegError> {
534 LinRegBatchBuilder::new()
535 .kernel(Kernel::Auto)
536 .apply_candles(c, "close")
537 }
538}
539
540#[derive(Clone, Debug)]
541#[cfg_attr(
542 all(target_arch = "wasm32", feature = "wasm"),
543 derive(Serialize, Deserialize)
544)]
545pub struct LinRegBatchOutput {
546 pub values: Vec<f64>,
547 pub combos: Vec<LinRegParams>,
548 pub rows: usize,
549 pub cols: usize,
550}
551
552impl LinRegBatchOutput {
553 pub fn row_for_params(&self, p: &LinRegParams) -> Option<usize> {
554 self.combos
555 .iter()
556 .position(|c| c.period.unwrap_or(14) == p.period.unwrap_or(14))
557 }
558 pub fn values_for(&self, p: &LinRegParams) -> Option<&[f64]> {
559 self.row_for_params(p).map(|row| {
560 let start = row * self.cols;
561 &self.values[start..start + self.cols]
562 })
563 }
564}
565
566pub fn linreg_batch_with_kernel(
567 data: &[f64],
568 sweep: &LinRegBatchRange,
569 k: Kernel,
570) -> Result<LinRegBatchOutput, LinRegError> {
571 let kernel = match k {
572 Kernel::Auto => Kernel::ScalarBatch,
573 other if other.is_batch() => other,
574 _ => return Err(LinRegError::InvalidKernelForBatch(k)),
575 };
576 let simd = match kernel {
577 Kernel::Avx512Batch => Kernel::Avx512,
578 Kernel::Avx2Batch => Kernel::Avx2,
579 Kernel::ScalarBatch => Kernel::Scalar,
580 _ => unreachable!(),
581 };
582 linreg_batch_par_slice(data, sweep, simd)
583}
584
585#[inline(always)]
586pub fn linreg_batch_slice(
587 data: &[f64],
588 sweep: &LinRegBatchRange,
589 kern: Kernel,
590) -> Result<LinRegBatchOutput, LinRegError> {
591 linreg_batch_inner(data, sweep, kern, false)
592}
593
594#[inline(always)]
595pub fn linreg_batch_par_slice(
596 data: &[f64],
597 sweep: &LinRegBatchRange,
598 kern: Kernel,
599) -> Result<LinRegBatchOutput, LinRegError> {
600 linreg_batch_inner(data, sweep, kern, true)
601}
602
603#[inline(always)]
604fn linreg_batch_inner(
605 data: &[f64],
606 sweep: &LinRegBatchRange,
607 kern: Kernel,
608 parallel: bool,
609) -> Result<LinRegBatchOutput, LinRegError> {
610 let combos = expand_grid_linreg(sweep);
611 if combos.is_empty() {
612 let (s, e, t) = sweep.period;
613 return Err(LinRegError::InvalidRange {
614 start: s,
615 end: e,
616 step: t,
617 });
618 }
619 let first = data
620 .iter()
621 .position(|x| !x.is_nan())
622 .ok_or(LinRegError::AllValuesNaN)?;
623 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
624 if data.len() - first < max_p {
625 return Err(LinRegError::NotEnoughValidData {
626 needed: max_p,
627 valid: data.len() - first,
628 });
629 }
630
631 let rows = combos.len();
632 let cols = data.len();
633 let _ = rows
634 .checked_mul(cols)
635 .ok_or(LinRegError::ArithmeticOverflow { what: "rows*cols" })?;
636
637 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
638
639 let mut raw = make_uninit_matrix(rows, cols);
640 unsafe { init_matrix_prefixes(&mut raw, cols, &warm) };
641
642 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
643 let period = combos[row].period.unwrap();
644
645 let out_row =
646 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
647
648 match kern {
649 Kernel::Scalar => linreg_row_scalar(data, first, period, out_row),
650 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
651 Kernel::Avx2 => linreg_row_avx2(data, first, period, out_row),
652 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
653 Kernel::Avx512 => linreg_row_avx512(data, first, period, out_row),
654 _ => unreachable!(),
655 }
656 };
657
658 if parallel {
659 #[cfg(not(target_arch = "wasm32"))]
660 {
661 raw.par_chunks_mut(cols)
662 .enumerate()
663 .for_each(|(row, slice)| do_row(row, slice));
664 }
665
666 #[cfg(target_arch = "wasm32")]
667 {
668 for (row, slice) in raw.chunks_mut(cols).enumerate() {
669 do_row(row, slice);
670 }
671 }
672 } else {
673 for (row, slice) in raw.chunks_mut(cols).enumerate() {
674 do_row(row, slice);
675 }
676 }
677
678 let values: Vec<f64> = unsafe { std::mem::transmute(raw) };
679
680 Ok(LinRegBatchOutput {
681 values,
682 combos,
683 rows,
684 cols,
685 })
686}
687
688pub fn linreg_batch_inner_into(
689 data: &[f64],
690 sweep: &LinRegBatchRange,
691 kern: Kernel,
692 parallel: bool,
693 out: &mut [f64],
694) -> Result<Vec<LinRegParams>, LinRegError> {
695 let combos = expand_grid_linreg(sweep);
696 if combos.is_empty() {
697 let (s, e, t) = sweep.period;
698 return Err(LinRegError::InvalidRange {
699 start: s,
700 end: e,
701 step: t,
702 });
703 }
704 let first = data
705 .iter()
706 .position(|x| !x.is_nan())
707 .ok_or(LinRegError::AllValuesNaN)?;
708 let max_p = combos.iter().map(|c| c.period.unwrap()).max().unwrap();
709 if data.len() - first < max_p {
710 return Err(LinRegError::NotEnoughValidData {
711 needed: max_p,
712 valid: data.len() - first,
713 });
714 }
715
716 let rows = combos.len();
717 let cols = data.len();
718 let expected = rows
719 .checked_mul(cols)
720 .ok_or(LinRegError::ArithmeticOverflow { what: "rows*cols" })?;
721
722 if out.len() != expected {
723 return Err(LinRegError::OutputLengthMismatch {
724 expected,
725 got: out.len(),
726 });
727 }
728
729 let out_uninit = unsafe {
730 std::slice::from_raw_parts_mut(out.as_mut_ptr() as *mut MaybeUninit<f64>, out.len())
731 };
732
733 let warm: Vec<usize> = combos.iter().map(|c| first + c.period.unwrap()).collect();
734
735 unsafe { init_matrix_prefixes(out_uninit, cols, &warm) };
736
737 let do_row = |row: usize, dst_mu: &mut [MaybeUninit<f64>]| unsafe {
738 let period = combos[row].period.unwrap();
739
740 let out_row =
741 core::slice::from_raw_parts_mut(dst_mu.as_mut_ptr() as *mut f64, dst_mu.len());
742
743 match kern {
744 Kernel::Scalar => linreg_row_scalar(data, first, period, out_row),
745 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
746 Kernel::Avx2 => linreg_row_avx2(data, first, period, out_row),
747 #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
748 Kernel::Avx512 => linreg_row_avx512(data, first, period, out_row),
749 _ => unreachable!(),
750 }
751 };
752
753 if parallel {
754 #[cfg(not(target_arch = "wasm32"))]
755 {
756 out_uninit
757 .par_chunks_mut(cols)
758 .enumerate()
759 .for_each(|(row, slice)| do_row(row, slice));
760 }
761
762 #[cfg(target_arch = "wasm32")]
763 {
764 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
765 do_row(row, slice);
766 }
767 }
768 } else {
769 for (row, slice) in out_uninit.chunks_mut(cols).enumerate() {
770 do_row(row, slice);
771 }
772 }
773
774 Ok(combos)
775}
776
777#[inline(always)]
778unsafe fn linreg_row_scalar(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
779 linreg_scalar(data, period, first, out)
780}
781
/// Fills one row from precomputed prefix sums instead of a rolling update.
///
/// With `pos = idx - first + 1`, the window sums are taken as
/// `s[pos] - s[pos - period]` (Σy) and `sp[pos] - sp[pos - period]` (Σ j·y with the
/// position `j` counted from `first`); the latter is re-based to weights
/// `1..=period` by subtracting `(pos - period) · Σy`.
///
/// # Safety
/// All indexing is unchecked. The caller must guarantee that
/// * `s` and `sp` have at least `data.len() - first + 1` elements,
/// * `out.len() >= data.len()`,
/// * `period >= 1`.
#[inline(always)]
unsafe fn linreg_row_prefix_sums_scalar(
    data: &[f64],
    first: usize,
    period: usize,
    out: &mut [f64],
    s: &[f64],
    sp: &[f64],
) {
    let total = data.len();
    let n = period as f64;
    let sum_x = ((period * (period + 1)) / 2) as f64;
    let sum_x2 = ((period * (period + 1) * (2 * period + 1)) / 6) as f64;
    let inv_denom = 1.0 / (n * sum_x2 - sum_x * sum_x);
    let inv_n = 1.0 / n;

    for idx in (first + period - 1)..total {
        // Window covers prefix positions (lo, hi].
        let hi = idx - first + 1;
        let lo = hi - period;

        let sum_y = *s.get_unchecked(hi) - *s.get_unchecked(lo);
        let sum_xy = (*sp.get_unchecked(hi) - *sp.get_unchecked(lo)) - (lo as f64) * sum_y;

        let slope = (n * sum_xy - sum_x * sum_y) * inv_denom;
        let intercept = (sum_y - slope * sum_x) * inv_n;
        *out.get_unchecked_mut(idx) = intercept + slope * n;
    }
}
813
/// Fills one batch row with the AVX2 kernel.
///
/// Thin adapter that reorders the arguments for `linreg_avx2`.
///
/// # Safety
/// Same requirements as `linreg_avx2`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn linreg_row_avx2(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    linreg_avx2(data, period, first, out);
}
819
/// Fills one batch row with the AVX-512 kernel.
///
/// Thin adapter that reorders the arguments for `linreg_avx512`.
///
/// # Safety
/// Same requirements as `linreg_avx512`.
#[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
#[inline(always)]
unsafe fn linreg_row_avx512(data: &[f64], first: usize, period: usize, out: &mut [f64]) {
    linreg_avx512(data, period, first, out);
}
825
826#[derive(Debug, Clone)]
827pub struct LinRegStream {
828 period: usize,
829 buffer: Vec<f64>,
830 head: usize,
831 filled: bool,
832 x_sum: f64,
833 x2_sum: f64,
834}
835
836impl LinRegStream {
837 pub fn try_new(params: LinRegParams) -> Result<Self, LinRegError> {
838 let period = params.period.unwrap_or(14);
839 if period == 0 {
840 return Err(LinRegError::InvalidPeriod {
841 period,
842 data_len: 0,
843 });
844 }
845 let mut x_sum = 0.0;
846 let mut x2_sum = 0.0;
847 for i in 1..=period {
848 let xi = i as f64;
849 x_sum += xi;
850 x2_sum += xi * xi;
851 }
852 Ok(Self {
853 period,
854 buffer: vec![f64::NAN; period],
855 head: 0,
856 filled: false,
857 x_sum,
858 x2_sum,
859 })
860 }
861
862 #[inline(always)]
863 pub fn update(&mut self, value: f64) -> Option<f64> {
864 self.buffer[self.head] = value;
865 self.head = (self.head + 1) % self.period;
866 if !self.filled && self.head == 0 {
867 self.filled = true;
868 }
869 if !self.filled {
870 return None;
871 }
872 Some(self.dot_ring())
873 }
874
875 #[inline(always)]
876 fn dot_ring(&self) -> f64 {
877 let mut y_sum = 0.0;
878 let mut xy_sum = 0.0;
879 for (i, &y) in
880 (1..=self.period).zip(self.buffer.iter().cycle().skip(self.head).take(self.period))
881 {
882 y_sum += y;
883 xy_sum += y * (i as f64);
884 }
885 let pf = self.period as f64;
886 let bd = 1.0 / (pf * self.x2_sum - self.x_sum * self.x_sum);
887 let b = (pf * xy_sum - self.x_sum * y_sum) * bd;
888 let a = (y_sum - b * self.x_sum) / pf;
889 a + b * pf
890 }
891}
892
/// Rounds `x` up to the next multiple of 8 (multiples of 8 map to themselves).
#[inline(always)]
fn round_up8(x: usize) -> usize {
    let bumped = x + 7;
    bumped - (bumped % 8)
}
897
898#[inline(always)]
899pub fn expand_grid_linreg(r: &LinRegBatchRange) -> Vec<LinRegParams> {
900 fn axis_usize((start, end, step): (usize, usize, usize)) -> Vec<usize> {
901 if step == 0 || start == end {
902 return vec![start];
903 }
904 let (lo, hi) = if start <= end {
905 (start, end)
906 } else {
907 (end, start)
908 };
909 let mut v = Vec::new();
910 let mut x = lo;
911 while x <= hi {
912 v.push(x);
913 match x.checked_add(step) {
914 Some(nx) => x = nx,
915 None => break,
916 }
917 }
918 v
919 }
920 let periods = axis_usize(r.period);
921 let mut out = Vec::with_capacity(periods.len());
922 for &p in &periods {
923 out.push(LinRegParams { period: Some(p) });
924 }
925 out
926}
927
928#[cfg(test)]
929mod tests {
930 use super::*;
931 use crate::skip_if_unsupported;
932 use crate::utilities::data_loader::read_candles_from_csv;
933
934 fn check_linreg_accuracy(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
935 skip_if_unsupported!(kernel, test_name);
936 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
937 let candles = read_candles_from_csv(file_path)?;
938 let close_prices = candles.select_candle_field("close")?;
939 let params = LinRegParams { period: Some(14) };
940 let input = LinRegInput::from_candles(&candles, "close", params);
941 let linreg_result = linreg_with_kernel(&input, kernel)?;
942 let expected_last_five = [
943 58929.37142857143,
944 58899.42857142857,
945 58918.857142857145,
946 59100.6,
947 58987.94285714286,
948 ];
949 assert!(linreg_result.values.len() >= 5);
950 assert_eq!(linreg_result.values.len(), close_prices.len());
951 let start_index = linreg_result.values.len() - 5;
952 let result_last_five = &linreg_result.values[start_index..];
953 for (i, &value) in result_last_five.iter().enumerate() {
954 let expected_value = expected_last_five[i];
955 assert!(
956 (value - expected_value).abs() < 1e-1,
957 "Mismatch at index {}: expected {}, got {}",
958 i,
959 expected_value,
960 value
961 );
962 }
963 Ok(())
964 }
965
966 fn check_linreg_partial_params(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
967 skip_if_unsupported!(kernel, test_name);
968 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
969 let candles = read_candles_from_csv(file_path)?;
970 let default_params = LinRegParams { period: None };
971 let input = LinRegInput::from_candles(&candles, "close", default_params);
972 let output = linreg_with_kernel(&input, kernel)?;
973 assert_eq!(output.values.len(), candles.close.len());
974 Ok(())
975 }
976
977 fn check_linreg_default_candles(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
978 skip_if_unsupported!(kernel, test_name);
979 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
980 let candles = read_candles_from_csv(file_path)?;
981 let input = LinRegInput::with_default_candles(&candles);
982 match input.data {
983 LinRegData::Candles { source, .. } => assert_eq!(source, "close"),
984 _ => panic!("Expected LinRegData::Candles"),
985 }
986 let output = linreg_with_kernel(&input, kernel)?;
987 assert_eq!(output.values.len(), candles.close.len());
988 Ok(())
989 }
990
991 #[test]
992 fn test_linreg_into_matches_api() -> Result<(), Box<dyn Error>> {
993 let mut data = Vec::with_capacity(5 + 256);
994 for _ in 0..5 {
995 data.push(f64::NAN);
996 }
997 for i in 0..256u32 {
998 let x = i as f64;
999 let v = (x * 0.137).sin() * 3.0 + x * 0.25;
1000 data.push(v);
1001 }
1002
1003 let params = LinRegParams { period: Some(14) };
1004 let input = LinRegInput::from_slice(&data, params);
1005
1006 let baseline = linreg(&input)?.values;
1007
1008 let mut out = vec![0.0; data.len()];
1009 linreg_into(&input, &mut out)?;
1010
1011 assert_eq!(out.len(), baseline.len());
1012 for (i, (&a, &b)) in out.iter().zip(baseline.iter()).enumerate() {
1013 if a.is_nan() || b.is_nan() {
1014 assert!(
1015 a.is_nan() && b.is_nan(),
1016 "NaN parity mismatch at index {}",
1017 i
1018 );
1019 } else {
1020 let diff = (a - b).abs();
1021 assert!(
1022 diff <= 1e-12,
1023 "Value mismatch at index {}: {} vs {} (diff={})",
1024 i,
1025 a,
1026 b,
1027 diff
1028 );
1029 }
1030 }
1031 Ok(())
1032 }
1033
1034 fn check_linreg_zero_period(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1035 skip_if_unsupported!(kernel, test_name);
1036 let input_data = [10.0, 20.0, 30.0];
1037 let params = LinRegParams { period: Some(0) };
1038 let input = LinRegInput::from_slice(&input_data, params);
1039 let res = linreg_with_kernel(&input, kernel);
1040 assert!(
1041 res.is_err(),
1042 "[{}] LINREG should fail with zero period",
1043 test_name
1044 );
1045 Ok(())
1046 }
1047
1048 fn check_linreg_period_exceeds_length(
1049 test_name: &str,
1050 kernel: Kernel,
1051 ) -> Result<(), Box<dyn Error>> {
1052 skip_if_unsupported!(kernel, test_name);
1053 let data_small = [10.0, 20.0, 30.0];
1054 let params = LinRegParams { period: Some(10) };
1055 let input = LinRegInput::from_slice(&data_small, params);
1056 let res = linreg_with_kernel(&input, kernel);
1057 assert!(
1058 res.is_err(),
1059 "[{}] LINREG should fail with period exceeding length",
1060 test_name
1061 );
1062 Ok(())
1063 }
1064
1065 fn check_linreg_very_small_dataset(
1066 test_name: &str,
1067 kernel: Kernel,
1068 ) -> Result<(), Box<dyn Error>> {
1069 skip_if_unsupported!(kernel, test_name);
1070 let single_point = [42.0];
1071 let params = LinRegParams { period: Some(14) };
1072 let input = LinRegInput::from_slice(&single_point, params);
1073 let res = linreg_with_kernel(&input, kernel);
1074 assert!(
1075 res.is_err(),
1076 "[{}] LINREG should fail with insufficient data",
1077 test_name
1078 );
1079 Ok(())
1080 }
1081
1082 fn check_linreg_reinput(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1083 skip_if_unsupported!(kernel, test_name);
1084 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1085 let candles = read_candles_from_csv(file_path)?;
1086 let first_params = LinRegParams { period: Some(14) };
1087 let first_input = LinRegInput::from_candles(&candles, "close", first_params);
1088 let first_result = linreg_with_kernel(&first_input, kernel)?;
1089 let second_params = LinRegParams { period: Some(10) };
1090 let second_input = LinRegInput::from_slice(&first_result.values, second_params);
1091 let second_result = linreg_with_kernel(&second_input, kernel)?;
1092 assert_eq!(second_result.values.len(), first_result.values.len());
1093 Ok(())
1094 }
1095
1096 fn check_linreg_nan_handling(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1097 skip_if_unsupported!(kernel, test_name);
1098 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1099 let candles = read_candles_from_csv(file_path)?;
1100 let input = LinRegInput::from_candles(&candles, "close", LinRegParams { period: Some(14) });
1101 let res = linreg_with_kernel(&input, kernel)?;
1102 assert_eq!(res.values.len(), candles.close.len());
1103 if res.values.len() > 240 {
1104 for (i, &val) in res.values[240..].iter().enumerate() {
1105 assert!(
1106 !val.is_nan(),
1107 "[{}] Found unexpected NaN at out-index {}",
1108 test_name,
1109 240 + i
1110 );
1111 }
1112 }
1113 Ok(())
1114 }
1115
1116 fn check_linreg_streaming(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1117 skip_if_unsupported!(kernel, test_name);
1118 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1119 let candles = read_candles_from_csv(file_path)?;
1120 let period = 14;
1121 let input = LinRegInput::from_candles(
1122 &candles,
1123 "close",
1124 LinRegParams {
1125 period: Some(period),
1126 },
1127 );
1128 let batch_output = linreg_with_kernel(&input, kernel)?.values;
1129 let mut stream = LinRegStream::try_new(LinRegParams {
1130 period: Some(period),
1131 })?;
1132 let mut stream_values = Vec::with_capacity(candles.close.len());
1133 for &price in &candles.close {
1134 match stream.update(price) {
1135 Some(val) => stream_values.push(val),
1136 None => stream_values.push(f64::NAN),
1137 }
1138 }
1139 assert_eq!(batch_output.len(), stream_values.len());
1140 for (i, (&b, &s)) in batch_output.iter().zip(stream_values.iter()).enumerate() {
1141 if b.is_nan() && s.is_nan() {
1142 continue;
1143 }
1144 let diff = (b - s).abs();
1145 assert!(
1146 diff < 1e-6,
1147 "[{}] LINREG streaming mismatch at idx {}: batch={}, stream={}, diff={}",
1148 test_name,
1149 i,
1150 b,
1151 s,
1152 diff
1153 );
1154 }
1155 Ok(())
1156 }
1157
/// Generates one `#[test]` per kernel for every listed check function.
///
/// For a check named `foo` this emits `foo_scalar_f64`, plus `foo_avx2_f64` and
/// `foo_avx512_f64` when the `nightly-avx` feature is enabled on x86_64.
///
/// Each generated test returns the check's `Result`, so an `Err` (unreadable fixture,
/// indicator failure) fails the test instead of being dropped.
macro_rules! generate_all_linreg_tests {
    ($($test_fn:ident),*) => {
        paste::paste! {
            $(
                #[test]
                fn [<$test_fn _scalar_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _scalar_f64>]), Kernel::Scalar)
                }

                // The cfg must sit on each item: an attribute placed in front of a
                // `$( ... )*` repetition only applies to the first expanded item.
                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx2_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx2_f64>]), Kernel::Avx2)
                }

                #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
                #[test]
                fn [<$test_fn _avx512_f64>]() -> Result<(), Box<dyn std::error::Error>> {
                    $test_fn(stringify!([<$test_fn _avx512_f64>]), Kernel::Avx512)
                }
            )*
        }
    }
}
1170
1171 #[cfg(debug_assertions)]
1172 fn check_linreg_no_poison(test_name: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1173 skip_if_unsupported!(kernel, test_name);
1174
1175 let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1176 let candles = read_candles_from_csv(file_path)?;
1177
1178 let test_periods = vec![2, 5, 10, 14, 20, 30, 50, 100, 200];
1179 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1180
1181 for period in &test_periods {
1182 for source in &test_sources {
1183 let input = LinRegInput::from_candles(
1184 &candles,
1185 source,
1186 LinRegParams {
1187 period: Some(*period),
1188 },
1189 );
1190 let output = linreg_with_kernel(&input, kernel)?;
1191
1192 for (i, &val) in output.values.iter().enumerate() {
1193 if val.is_nan() {
1194 continue;
1195 }
1196
1197 let bits = val.to_bits();
1198
1199 if bits == 0x11111111_11111111 {
1200 panic!(
1201 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1202 test_name, val, bits, i, period, source
1203 );
1204 }
1205
1206 if bits == 0x22222222_22222222 {
1207 panic!(
1208 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at index {} with period={}, source={}",
1209 test_name, val, bits, i, period, source
1210 );
1211 }
1212
1213 if bits == 0x33333333_33333333 {
1214 panic!(
1215 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at index {} with period={}, source={}",
1216 test_name, val, bits, i, period, source
1217 );
1218 }
1219 }
1220 }
1221 }
1222
1223 Ok(())
1224 }
1225
// Counterpart of the debug-only poison scan for builds without debug assertions:
// it performs no checks and always succeeds, so the generated tests still compile.
#[cfg(not(debug_assertions))]
fn check_linreg_no_poison(_test_name: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
1230
/// Property-based check of LINREG on random windows of the candle close series.
///
/// For a random `period` (2..=50) and a random slice of 100..=200 closes it verifies:
/// * `kernel` and `Kernel::Scalar` either both succeed or both fail with the same
///   error variant;
/// * output length equals input length, the first non-NaN output is at
///   `first_non_nan_input + period`, and everything before it is NaN;
/// * every non-NaN output is finite and agrees with the scalar reference within
///   3 ULPs or an absolute difference below 1e-9;
/// * on the synthetic line `100 + 2*i` the non-NaN output at index `i` equals
///   `100 + 2*(i + 1)` within 1e-6, and on a constant series equals the constant
///   within 1e-9;
/// * every non-NaN output from the warmup index on stays within
///   `[min - range, max + range]` of its trailing `period`-wide input window.
#[cfg(feature = "proptest")]
fn check_linreg_property(
    test_name: &str,
    kernel: Kernel,
) -> Result<(), Box<dyn std::error::Error>> {
    use proptest::prelude::*;
    skip_if_unsupported!(kernel, test_name);

    let file_path = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
    let candles = read_candles_from_csv(file_path)?;
    let close_data = &candles.close;

    // (period, slice start, slice length)
    let strat = (
        2usize..=50,
        0usize..close_data.len().saturating_sub(200),
        100usize..=200,
    );

    proptest::test_runner::TestRunner::default()
        .run(&strat, |(period, start_idx, slice_len)| {
            let end_idx = (start_idx + slice_len).min(close_data.len());
            // Skip degenerate windows that are too short for the chosen period.
            if end_idx <= start_idx || end_idx - start_idx < period + 10 {
                return Ok(());
            }

            let data_slice = &close_data[start_idx..end_idx];
            let params = LinRegParams {
                period: Some(period),
            };
            let input = LinRegInput::from_slice(data_slice, params.clone());

            let result = linreg_with_kernel(&input, kernel);
            let scalar_result = linreg_with_kernel(&input, Kernel::Scalar);

            match (result, scalar_result) {
                (Ok(LinRegOutput { values: out }), Ok(LinRegOutput { values: ref_out })) => {
                    prop_assert_eq!(out.len(), data_slice.len());
                    prop_assert_eq!(ref_out.len(), data_slice.len());

                    let first = data_slice.iter().position(|x| !x.is_nan()).unwrap_or(0);
                    let expected_warmup = first + period;

                    let first_valid = out.iter().position(|x| !x.is_nan());
                    if let Some(first_idx) = first_valid {
                        prop_assert_eq!(
                            first_idx,
                            expected_warmup,
                            "First valid at {} but expected warmup is {}",
                            first_idx,
                            expected_warmup
                        );

                        for i in 0..first_idx {
                            prop_assert!(
                                out[i].is_nan(),
                                "Expected NaN at index {} during warmup, got {}",
                                i,
                                out[i]
                            );
                        }
                    }

                    for i in 0..out.len() {
                        let y = out[i];
                        let r = ref_out[i];

                        if y.is_nan() {
                            prop_assert!(
                                r.is_nan(),
                                "Kernel mismatch at {}: {} vs {}",
                                i,
                                y,
                                r
                            );
                            continue;
                        }

                        prop_assert!(y.is_finite(), "Non-finite value at index {}: {}", i, y);

                        // Distance between the raw bit patterns. `abs_diff` on the
                        // unsigned bits cannot overflow, whereas subtracting the bits
                        // as `i64` overflows when the two values have opposite signs.
                        let ulps_diff = if y == r {
                            0
                        } else {
                            y.to_bits().abs_diff(r.to_bits())
                        };

                        prop_assert!(
                            ulps_diff <= 3 || (y - r).abs() < 1e-9,
                            "Kernel mismatch at {}: {} vs {} (diff: {}, ulps: {})",
                            i,
                            y,
                            r,
                            (y - r).abs(),
                            ulps_diff
                        );
                    }

                    if first_valid.is_some() {
                        // Synthetic straight line: value at index i is 100 + 2*i.
                        let linear_data: Vec<f64> = (0..period + 5)
                            .map(|i| 100.0 + i as f64 * 2.0)
                            .collect();
                        let linear_input =
                            LinRegInput::from_slice(&linear_data, params.clone());
                        if let Ok(LinRegOutput { values: linear_out }) =
                            linreg_with_kernel(&linear_input, kernel)
                        {
                            for i in period..linear_data.len() {
                                if !linear_out[i].is_nan() {
                                    let expected = 100.0 + (i + 1) as f64 * 2.0;
                                    prop_assert!(
                                        (linear_out[i] - expected).abs() < 1e-6,
                                        "Linear prediction error at {}: got {} expected {}",
                                        i,
                                        linear_out[i],
                                        expected
                                    );
                                }
                            }
                        }

                        // Constant series: the output must reproduce the constant.
                        let constant_val = 42.0;
                        let constant_data = vec![constant_val; period + 5];
                        let const_input = LinRegInput::from_slice(&constant_data, params);
                        if let Ok(LinRegOutput { values: const_out }) =
                            linreg_with_kernel(&const_input, kernel)
                        {
                            for i in period..constant_data.len() {
                                if !const_out[i].is_nan() {
                                    prop_assert!(
                                        (const_out[i] - constant_val).abs() < 1e-9,
                                        "Constant prediction error at {}: got {} expected {}",
                                        i,
                                        const_out[i],
                                        constant_val
                                    );
                                }
                            }
                        }

                        // Loose sanity bound relative to the trailing input window.
                        for i in expected_warmup..out.len() {
                            if !out[i].is_nan() {
                                let window_start = i + 1 - period;
                                let window_end = i + 1;
                                let window = &data_slice[window_start..window_end];

                                let min_val =
                                    window.iter().copied().fold(f64::INFINITY, f64::min);
                                let max_val =
                                    window.iter().copied().fold(f64::NEG_INFINITY, f64::max);

                                let range = max_val - min_val;
                                let lower_bound = min_val - range;
                                let upper_bound = max_val + range;

                                prop_assert!(
                                    out[i] >= lower_bound && out[i] <= upper_bound,
                                    "Output {} at index {} outside reasonable bounds [{}, {}]",
                                    out[i],
                                    i,
                                    lower_bound,
                                    upper_bound
                                );
                            }
                        }
                    }

                    Ok(())
                }
                (Err(e1), Err(e2)) => {
                    prop_assert_eq!(
                        std::mem::discriminant(&e1),
                        std::mem::discriminant(&e2),
                        "Different error types: {:?} vs {:?}",
                        e1,
                        e2
                    );
                    Ok(())
                }
                _ => {
                    prop_assert!(
                        false,
                        "Kernel consistency failed - one succeeded, one failed"
                    );
                    Ok(())
                }
            }
        })
        .map_err(|e| e.into())
}
1423
// Instantiate the per-kernel `#[test]` functions for every single-series check.
generate_all_linreg_tests!(
    check_linreg_accuracy,
    check_linreg_partial_params,
    check_linreg_default_candles,
    check_linreg_zero_period,
    check_linreg_period_exceeds_length,
    check_linreg_very_small_dataset,
    check_linreg_reinput,
    check_linreg_nan_handling,
    check_linreg_streaming,
    check_linreg_no_poison
);

// The property check is itself gated on the `proptest` feature, so its tests are too.
#[cfg(feature = "proptest")]
generate_all_linreg_tests!(check_linreg_property);
1439
1440 fn check_batch_default_row(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1441 skip_if_unsupported!(kernel, test);
1442 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1443 let c = read_candles_from_csv(file)?;
1444 let output = LinRegBatchBuilder::new()
1445 .kernel(kernel)
1446 .apply_candles(&c, "close")?;
1447 let def = LinRegParams::default();
1448 let row = output.values_for(&def).expect("default row missing");
1449 assert_eq!(row.len(), c.close.len());
1450 Ok(())
1451 }
1452
/// Generates the batch-kernel `#[test]` functions for one check function.
///
/// For a check named `foo` this emits `foo_scalar`, `foo_auto_detect`, and — with the
/// `nightly-avx` feature on x86_64 — `foo_avx2` and `foo_avx512`.
///
/// Each generated test returns the check's `Result`, so an `Err` fails the test
/// instead of being dropped.
macro_rules! gen_batch_tests {
    ($fn_name:ident) => {
        paste::paste! {
            #[test]
            fn [<$fn_name _scalar>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _scalar>]), Kernel::ScalarBatch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx2>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx2>]), Kernel::Avx2Batch)
            }

            #[cfg(all(feature = "nightly-avx", target_arch = "x86_64"))]
            #[test]
            fn [<$fn_name _avx512>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _avx512>]), Kernel::Avx512Batch)
            }

            #[test]
            fn [<$fn_name _auto_detect>]() -> Result<(), Box<dyn std::error::Error>> {
                $fn_name(stringify!([<$fn_name _auto_detect>]), Kernel::Auto)
            }
        }
    };
}
1473
1474 #[cfg(debug_assertions)]
1475 fn check_batch_no_poison(test: &str, kernel: Kernel) -> Result<(), Box<dyn Error>> {
1476 skip_if_unsupported!(kernel, test);
1477
1478 let file = "src/data/2018-09-01-2024-Bitfinex_Spot-4h.csv";
1479 let c = read_candles_from_csv(file)?;
1480
1481 let test_sources = vec!["open", "high", "low", "close", "hl2", "hlc3", "ohlc4"];
1482
1483 for source in &test_sources {
1484 let output = LinRegBatchBuilder::new()
1485 .kernel(kernel)
1486 .period_range(2, 200, 3)
1487 .apply_candles(&c, source)?;
1488
1489 for (idx, &val) in output.values.iter().enumerate() {
1490 if val.is_nan() {
1491 continue;
1492 }
1493
1494 let bits = val.to_bits();
1495 let row = idx / output.cols;
1496 let col = idx % output.cols;
1497
1498 if bits == 0x11111111_11111111 {
1499 panic!(
1500 "[{}] Found alloc_with_nan_prefix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1501 test, val, bits, row, col, idx, source
1502 );
1503 }
1504
1505 if bits == 0x22222222_22222222 {
1506 panic!(
1507 "[{}] Found init_matrix_prefixes poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1508 test, val, bits, row, col, idx, source
1509 );
1510 }
1511
1512 if bits == 0x33333333_33333333 {
1513 panic!(
1514 "[{}] Found make_uninit_matrix poison value {} (0x{:016X}) at row {} col {} (flat index {}) with source={}",
1515 test, val, bits, row, col, idx, source
1516 );
1517 }
1518 }
1519 }
1520
1521 let edge_case_ranges = vec![(2, 5, 1), (190, 200, 2), (50, 100, 10)];
1522 for (start, end, step) in edge_case_ranges {
1523 let output = LinRegBatchBuilder::new()
1524 .kernel(kernel)
1525 .period_range(start, end, step)
1526 .apply_candles(&c, "close")?;
1527
1528 for (idx, &val) in output.values.iter().enumerate() {
1529 if val.is_nan() {
1530 continue;
1531 }
1532
1533 let bits = val.to_bits();
1534 let row = idx / output.cols;
1535 let col = idx % output.cols;
1536
1537 if bits == 0x11111111_11111111
1538 || bits == 0x22222222_22222222
1539 || bits == 0x33333333_33333333
1540 {
1541 panic!(
1542 "[{}] Found poison value {} (0x{:016X}) at row {} col {} with range ({},{},{})",
1543 test, val, bits, row, col, start, end, step
1544 );
1545 }
1546 }
1547 }
1548
1549 Ok(())
1550 }
1551
// Counterpart of the debug-only batch poison scan for builds without debug
// assertions: it performs no checks and always succeeds.
#[cfg(not(debug_assertions))]
fn check_batch_no_poison(_test: &str, _kernel: Kernel) -> Result<(), Box<dyn Error>> {
    Ok(())
}
1556
// Instantiate the per-kernel batch tests for both batch checks.
gen_batch_tests!(check_batch_default_row);
gen_batch_tests!(check_batch_no_poison);
1559}
1560
1561#[cfg(feature = "python")]
1562use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
1563#[cfg(feature = "python")]
1564use pyo3::exceptions::PyValueError;
1565#[cfg(feature = "python")]
1566use pyo3::prelude::*;
1567
1568#[cfg(feature = "python")]
1569use numpy::IntoPyArray;
1570#[cfg(feature = "python")]
1571use pyo3::types::PyDict;
1572
// Python entry point `linreg(data, period, kernel=None)`.
//
// Validates the kernel name, computes the indicator over `data` with the GIL
// released, and returns the result as a new 1-D float64 NumPy array. Computation
// errors are surfaced as `ValueError` carrying the error's display text.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "linreg", signature = (data, period, kernel=None))]
pub fn linreg_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period: usize,
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    use numpy::{IntoPyArray, PyArrayMethods};

    let prices = data.as_slice()?;
    let selected = validate_kernel(kernel, false)?;
    let input = LinRegInput::from_slice(
        prices,
        LinRegParams {
            period: Some(period),
        },
    );

    // The heavy work runs without the GIL; only the conversion below needs it.
    let computed = py.allow_threads(|| linreg_with_kernel(&input, selected).map(|o| o.values));
    let values: Vec<f64> = match computed {
        Ok(values) => values,
        Err(e) => return Err(PyValueError::new_err(e.to_string())),
    };

    Ok(values.into_pyarray(py))
}
1597
// Python entry point `linreg_batch(data, period_range, kernel=None)`.
//
// Expands `period_range` (passed through as the `period` field of
// `LinRegBatchRange`) into parameter combinations, computes one output row per
// combination directly into a NumPy buffer, and returns a dict with:
//   "values"  -> 2-D float64 array of shape (rows, cols), cols == len(data)
//   "periods" -> 1-D uint64 array with the period of each row
// Errors from the computation, and an overflowing rows * cols, become `ValueError`.
#[cfg(feature = "python")]
#[pyfunction]
#[pyo3(name = "linreg_batch", signature = (data, period_range, kernel=None))]
pub fn linreg_batch_py<'py>(
    py: Python<'py>,
    data: PyReadonlyArray1<'py, f64>,
    period_range: (usize, usize, usize),
    kernel: Option<&str>,
) -> PyResult<Bound<'py, PyDict>> {
    use numpy::{IntoPyArray, PyArray1, PyArrayMethods};
    use pyo3::types::PyDict;

    let slice_in = data.as_slice()?;
    let kern = validate_kernel(kernel, true)?;
    let sweep = LinRegBatchRange {
        period: period_range,
    };

    let combos = expand_grid_linreg(&sweep);
    let rows = combos.len();
    let cols = slice_in.len();

    // The buffer below is allocated uninitialised with this many elements, so a
    // silently wrapped product would hand the kernel an undersized slice.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| PyValueError::new_err("linreg_batch: rows * cols overflows usize"))?;

    // SAFETY: the array is freshly created and not yet visible to Python code, so
    // no other reference to its storage exists. It is uninitialised at this point;
    // NOTE(review): this relies on `linreg_batch_inner_into` writing every element
    // on success — confirm against its implementation.
    let out_arr = unsafe { PyArray1::<f64>::new(py, [total], false) };
    let slice_out = unsafe { out_arr.as_slice_mut()? };

    let combos = py
        .allow_threads(|| {
            let batch_kernel = match kern {
                Kernel::Auto => detect_best_batch_kernel(),
                k => k,
            };
            // Map the batch kernel onto the per-row kernel the inner routine takes.
            let simd = match batch_kernel {
                Kernel::Avx512Batch => Kernel::Avx512,
                Kernel::Avx2Batch => Kernel::Avx2,
                Kernel::ScalarBatch => Kernel::Scalar,
                _ => batch_kernel,
            };

            linreg_batch_inner_into(slice_in, &sweep, simd, true, slice_out)
        })
        .map_err(|e| PyValueError::new_err(e.to_string()))?;

    let dict = PyDict::new(py);
    dict.set_item("values", out_arr.reshape((rows, cols))?)?;
    dict.set_item(
        "periods",
        combos
            .iter()
            .map(|p| p.period.unwrap() as u64)
            .collect::<Vec<_>>()
            .into_pyarray(py),
    )?;

    Ok(dict)
}
1653
// Python entry point `linreg_cuda_batch_dev(data_f32, period_range, device_id=0)`.
//
// Runs the batch sweep through `CudaLinreg` on the requested device with the GIL
// released and returns a pair: the device-array wrapper, and a dict whose
// "periods" entry lists the period of each parameter combination as uint64.
// Raises `ValueError` when CUDA is unavailable or the CUDA call fails.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linreg_cuda_batch_dev")]
#[pyo3(signature = (data_f32, period_range, device_id=0))]
pub fn linreg_cuda_batch_dev_py<'py>(
    py: Python<'py>,
    data_f32: PyReadonlyArray1<'py, f32>,
    period_range: (usize, usize, usize),
    device_id: usize,
) -> PyResult<(DeviceArrayF32LinregPy, Bound<'py, PyDict>)> {
    use crate::cuda::cuda_available;
    use numpy::IntoPyArray;
    use pyo3::types::PyDict;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let host_data = data_f32.as_slice()?;
    let sweep = LinRegBatchRange {
        period: period_range,
    };

    let (device_array, combos, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaLinreg::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        // Grab the context handle and device id before launching the computation.
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        let (device_array, combos) = cuda
            .linreg_batch_dev(host_data, &sweep)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((device_array, combos, ctx, dev_id))
    })?;

    let periods: Vec<u64> = combos.iter().map(|c| c.period.unwrap() as u64).collect();
    let meta = PyDict::new(py);
    meta.set_item("periods", periods.into_pyarray(py))?;

    let wrapped = DeviceArrayF32LinregPy::new(device_array, ctx, dev_id);
    Ok((wrapped, meta))
}
1691
// Python entry point
// `linreg_cuda_many_series_one_param_dev(data_tm_f32, period, device_id=0)`.
//
// Takes a 2-D float32 array, applies a single `period` via
// `CudaLinreg::linreg_multi_series_one_param_time_major_dev` with the GIL released,
// and returns the device-array wrapper. Raises `ValueError` when CUDA is
// unavailable or the CUDA call fails.
// NOTE(review): the parameter name suggests a time-major layout (rows = time steps,
// columns = series) — confirm against the CUDA wrapper.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyfunction(name = "linreg_cuda_many_series_one_param_dev")]
#[pyo3(signature = (data_tm_f32, period, device_id=0))]
pub fn linreg_cuda_many_series_one_param_dev_py(
    py: Python<'_>,
    data_tm_f32: numpy::PyReadonlyArray2<'_, f32>,
    period: usize,
    device_id: usize,
) -> PyResult<DeviceArrayF32LinregPy> {
    use crate::cuda::cuda_available;
    use numpy::PyUntypedArrayMethods;

    if !cuda_available() {
        return Err(PyValueError::new_err("CUDA not available"));
    }

    let flat_in = data_tm_f32.as_slice()?;
    let rows = data_tm_f32.shape()[0];
    let cols = data_tm_f32.shape()[1];
    let params = LinRegParams {
        period: Some(period),
    };

    let (inner, ctx, dev_id) = py.allow_threads(|| {
        let cuda = CudaLinreg::new(device_id).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let ctx = cuda.ctx();
        let dev_id = cuda.device_id();
        // The last argument must be a borrow of `params`; the previous text had the
        // mis-encoded token `¶ms` here, which is not valid Rust.
        let arr = cuda
            .linreg_multi_series_one_param_time_major_dev(flat_in, cols, rows, &params)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok::<_, PyErr>((arr, ctx, dev_id))
    })?;

    Ok(DeviceArrayF32LinregPy::new(inner, ctx, dev_id))
}
1727
// Python-visible wrapper (`ta_indicators.cuda.DeviceArrayF32Linreg`) around a
// device-resident f32 matrix produced by the CUDA LINREG entry points.
#[cfg(all(feature = "python", feature = "cuda"))]
#[pyclass(
    module = "ta_indicators.cuda",
    name = "DeviceArrayF32Linreg",
    unsendable
)]
pub struct DeviceArrayF32LinregPy {
    // Device buffer together with its row/column counts.
    pub(crate) inner: DeviceArrayF32,
    // Never read in this file; presumably held so the CUDA context outlives the
    // buffer — TODO confirm.
    _ctx_guard: Arc<Context>,
    // Device ordinal reported by `__dlpack_device__`.
    _device_id: u32,
}
1739
#[cfg(all(feature = "python", feature = "cuda"))]
#[pymethods]
impl DeviceArrayF32LinregPy {
    // Direct construction from Python is rejected; instances come from the CUDA
    // functions.
    #[new]
    fn py_new() -> PyResult<Self> {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "use factory methods from CUDA functions",
        ))
    }

    // Builds the `__cuda_array_interface__` dict (version 3) describing the buffer
    // as a (rows, cols) little-endian f32 array with row stride `cols * 4` bytes.
    // An empty array reports a null data pointer.
    #[getter]
    fn __cuda_array_interface__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
        let (rows, cols) = (self.inner.rows, self.inner.cols);
        let itemsize = std::mem::size_of::<f32>();

        let iface = PyDict::new(py);
        iface.set_item("shape", (rows, cols))?;
        iface.set_item("typestr", "<f4")?;
        iface.set_item("strides", (cols * itemsize, itemsize))?;

        let ptr_val: usize = match rows.saturating_mul(cols) {
            0 => 0,
            _ => self.inner.buf.as_device_ptr().as_raw() as usize,
        };
        iface.set_item("data", (ptr_val, false))?;
        iface.set_item("version", 3)?;
        Ok(iface)
    }

    // Returns the `(device_type, device_id)` pair; the device type is always 2.
    fn __dlpack_device__(&self) -> (i32, i32) {
        (2, self._device_id as i32)
    }

    // Exports the buffer through DLPack.
    //
    // If `dl_device` is given, parses as an `(i32, i32)` pair and differs from this
    // array's device, a `ValueError` is raised (with a distinct message when
    // `copy=True` was requested). `stream` is accepted but unused. On export the
    // device buffer is moved out and this object is left holding an empty 0x0 array.
    #[pyo3(signature = (stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &mut self,
        py: Python<'py>,
        stream: Option<PyObject>,
        max_version: Option<PyObject>,
        dl_device: Option<PyObject>,
        copy: Option<PyObject>,
    ) -> PyResult<PyObject> {
        use crate::utilities::dlpack_cuda::export_f32_cuda_dlpack_2d;

        let (kdl, alloc_dev) = self.__dlpack_device__();

        // A `dl_device` that is not an (i32, i32) pair is ignored rather than rejected.
        let requested = dl_device
            .as_ref()
            .and_then(|obj| obj.extract::<(i32, i32)>(py).ok());
        if let Some((dev_ty, dev_id)) = requested {
            if dev_ty != kdl || dev_id != alloc_dev {
                let wants_copy = copy
                    .as_ref()
                    .and_then(|c| c.extract::<bool>(py).ok())
                    .unwrap_or(false);
                let message = if wants_copy {
                    "device copy not implemented for __dlpack__"
                } else {
                    "dl_device mismatch for __dlpack__"
                };
                return Err(PyValueError::new_err(message));
            }
        }
        let _ = stream;

        // Allocate the placeholder first so a failure here leaves `self` untouched.
        let placeholder =
            DeviceBuffer::from_slice(&[]).map_err(|e| PyValueError::new_err(e.to_string()))?;
        let DeviceArrayF32 { buf, rows, cols } = std::mem::replace(
            &mut self.inner,
            DeviceArrayF32 {
                buf: placeholder,
                rows: 0,
                cols: 0,
            },
        );

        let max_version_bound = max_version.map(|obj| obj.into_bound(py));

        export_f32_cuda_dlpack_2d(py, buf, rows, cols, alloc_dev, max_version_bound)
    }
}
1823
#[cfg(all(feature = "python", feature = "cuda"))]
impl DeviceArrayF32LinregPy {
    /// Rust-side constructor: bundles a device array with the context handle and
    /// device id it was produced with. Not exposed to Python (the Python `__new__`
    /// always raises).
    pub fn new(inner: DeviceArrayF32, ctx_guard: Arc<Context>, device_id: u32) -> Self {
        Self {
            inner,
            _ctx_guard: ctx_guard,
            _device_id: device_id,
        }
    }
}
1834
// Python class `LinRegStream`: a thin wrapper that owns a Rust `LinRegStream`.
#[cfg(feature = "python")]
#[pyclass(name = "LinRegStream")]
pub struct LinRegStreamPy {
    // The wrapped streaming state; all calls are forwarded to it.
    inner: LinRegStream,
}
1840
#[cfg(feature = "python")]
#[pymethods]
impl LinRegStreamPy {
    // Creates a stream for the given `period`; a construction failure is raised
    // as `ValueError` prefixed with "LinRegStream error: ".
    #[new]
    pub fn new(period: usize) -> PyResult<Self> {
        LinRegStream::try_new(LinRegParams {
            period: Some(period),
        })
        .map(|inner| Self { inner })
        .map_err(|e| PyValueError::new_err(format!("LinRegStream error: {}", e)))
    }

    // Pushes one value into the wrapped stream and returns whatever it yields
    // (`None` maps to Python `None`).
    pub fn update(&mut self, value: f64) -> Option<f64> {
        self.inner.update(value)
    }
}
1859
1860#[inline]
1861pub fn linreg_into_slice(
1862 dst: &mut [f64],
1863 input: &LinRegInput,
1864 kern: Kernel,
1865) -> Result<(), LinRegError> {
1866 let data: &[f64] = input.as_ref();
1867
1868 if dst.len() != data.len() {
1869 return Err(LinRegError::OutputLengthMismatch {
1870 expected: data.len(),
1871 got: dst.len(),
1872 });
1873 }
1874
1875 linreg_compute_into(input, kern, dst)
1876}
1877
1878#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
1879use wasm_bindgen::prelude::*;
1880
// WASM entry point: computes LINREG over `data` with the scalar kernel and returns
// a freshly allocated vector of the same length. Errors are reported as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_js(data: &[f64], period: usize) -> Result<Vec<f64>, JsValue> {
    let input = LinRegInput::from_slice(
        data,
        LinRegParams {
            period: Some(period),
        },
    );

    let mut output = vec![0.0; data.len()];
    match linreg_into_slice(&mut output, &input, Kernel::Scalar) {
        Ok(()) => Ok(output),
        Err(e) => Err(JsValue::from_str(&e.to_string())),
    }
}
1896
/// Configuration object deserialized from the JS side by `linreg_batch`.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[derive(Serialize, Deserialize)]
pub struct LinRegBatchConfig {
    /// Forwarded unchanged as the `period` field of `LinRegBatchRange`;
    /// presumably `(start, end, step)` — TODO confirm against `expand_grid_linreg`.
    pub period_range: (usize, usize, usize),
}
1902
// WASM entry point exported as `linreg_batch`: deserializes `config` into a
// `LinRegBatchConfig`, runs the batch sweep with the scalar kernel, and returns the
// serialized batch output. Each failure stage is reported as a JS string.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen(js_name = linreg_batch)]
pub fn linreg_batch_unified_js(data: &[f64], config: JsValue) -> Result<JsValue, JsValue> {
    let parsed: LinRegBatchConfig = match serde_wasm_bindgen::from_value(config) {
        Ok(cfg) => cfg,
        Err(e) => return Err(JsValue::from_str(&format!("Invalid config: {}", e))),
    };

    let sweep = LinRegBatchRange {
        period: parsed.period_range,
    };

    let batch = match linreg_batch_slice(data, &sweep, Kernel::Scalar) {
        Ok(batch) => batch,
        Err(e) => return Err(JsValue::from_str(&e.to_string())),
    };

    serde_wasm_bindgen::to_value(&batch)
        .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
1919
// Reserves room for `len` f64 values and hands the raw pointer to the caller.
// The allocation is deliberately not dropped here; `linreg_free` with the same
// `len` reconstructs and releases it.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_alloc(len: usize) -> *mut f64 {
    // `ManuallyDrop` suppresses the Vec's destructor so the memory stays allocated.
    let mut buffer = std::mem::ManuallyDrop::new(Vec::<f64>::with_capacity(len));
    buffer.as_mut_ptr()
}
1928
// Releases a buffer by rebuilding a Vec of length and capacity `len` from `ptr`
// and dropping it. A null pointer is ignored.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_free(ptr: *mut f64, len: usize) {
    if ptr.is_null() {
        return;
    }
    // SAFETY (caller contract, not checked here): `ptr` and `len` are expected to
    // describe an allocation obtained from `linreg_alloc(len)` that has not been
    // freed yet.
    unsafe {
        drop(Vec::from_raw_parts(ptr, len, len));
    }
}
1938
// Pointer-based WASM entry point: reads `len` values from `in_ptr`, computes LINREG
// with the scalar kernel and writes `len` values to `out_ptr`.
//
// Null pointers and a period of 0 or greater than `len` are rejected with a JS
// string error. `in_ptr == out_ptr` is supported: the result is computed into a
// temporary buffer and copied back afterwards.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period: usize,
) -> Result<(), JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str("null pointer passed to linreg_into"));
    }

    // SAFETY (caller contract, not checked here): both pointers are expected to
    // reference at least `len` f64 values each.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);

        if period == 0 || period > len {
            return Err(JsValue::from_str("Invalid period"));
        }

        let input = LinRegInput::from_slice(
            data,
            LinRegParams {
                period: Some(period),
            },
        );

        if in_ptr == out_ptr {
            // In-place request: compute into scratch space so the input is not
            // overwritten while it is still being read.
            let mut scratch = vec![0.0; len];
            linreg_into_slice(&mut scratch, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;

            std::slice::from_raw_parts_mut(out_ptr, len).copy_from_slice(&scratch);
        } else {
            let out = std::slice::from_raw_parts_mut(out_ptr, len);
            linreg_into_slice(out, &input, Kernel::Scalar)
                .map_err(|e| JsValue::from_str(&e.to_string()))?;
        }

        Ok(())
    }
}
1979
// Pointer-based WASM batch entry point.
//
// Reads `len` values from `in_ptr`, expands `(period_start, period_end, period_step)`
// into parameter combinations, writes one row of `len` values per combination to
// `out_ptr`, and returns the number of rows. Null pointers, an overflowing
// `rows * len`, and computation errors are reported as JS string errors.
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[wasm_bindgen]
pub fn linreg_batch_into(
    in_ptr: *const f64,
    out_ptr: *mut f64,
    len: usize,
    period_start: usize,
    period_end: usize,
    period_step: usize,
) -> Result<usize, JsValue> {
    if in_ptr.is_null() || out_ptr.is_null() {
        return Err(JsValue::from_str(
            "null pointer passed to linreg_batch_into",
        ));
    }

    let sweep = LinRegBatchRange {
        period: (period_start, period_end, period_step),
    };

    let combos = expand_grid_linreg(&sweep);
    let rows = combos.len();
    let cols = len;

    // `usize` is 32 bits on wasm32, so the product can wrap for large sweeps; a
    // wrapped length would make the output slice below shorter than the data the
    // kernel writes.
    let total = rows
        .checked_mul(cols)
        .ok_or_else(|| JsValue::from_str("linreg_batch_into: rows * cols overflows usize"))?;

    // SAFETY (caller contract, not checked here): `in_ptr` is expected to reference
    // `len` f64 values and `out_ptr` at least `rows * len` f64 values.
    unsafe {
        let data = std::slice::from_raw_parts(in_ptr, len);
        let out = std::slice::from_raw_parts_mut(out_ptr, total);

        linreg_batch_inner_into(data, &sweep, Kernel::Scalar, false, out)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
    }

    Ok(rows)
}