1#![allow(non_snake_case)] use super::timeseries::TimeSeries;
29use ndarray::{Array1, Array2};
30use serde::{Deserialize, Serialize};
31use so_core::error::{Error, Result};
32use so_linalg;
33
34#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
36pub struct ARIMAOrder {
37 pub p: usize,
39 pub d: usize,
41 pub q: usize,
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
47pub struct SARIMAOrder {
48 pub order: ARIMAOrder,
50 pub seasonal_p: usize,
52 pub seasonal_d: usize,
54 pub seasonal_q: usize,
56 pub seasonal_period: usize,
58}
59
/// Estimation strategy requested for an ARIMA fit.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EstimationMethod {
    /// Conditional sum of squares.
    CSS,
    /// Maximum likelihood.
    ML,
    /// Exact maximum likelihood.
    ExactML,
}
70
71#[derive(Debug, Clone)]
73pub struct ARIMAConfig {
74 pub order: ARIMAOrder,
76 pub with_constant: bool,
78 pub method: EstimationMethod,
80 pub max_iter: usize,
82 pub tol: f64,
84}
85
86impl Default for ARIMAConfig {
87 fn default() -> Self {
88 Self {
89 order: ARIMAOrder { p: 1, d: 0, q: 1 },
90 with_constant: true,
91 method: EstimationMethod::CSS,
92 max_iter: 100,
93 tol: 1e-6,
94 }
95 }
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ARIMAResults {
101 pub ar_coef: Option<Array1<f64>>,
103 pub ma_coef: Option<Array1<f64>>,
105 pub constant: Option<f64>,
107 pub sigma2: f64,
109 pub log_likelihood: f64,
111 pub aic: f64,
113 pub bic: f64,
115 pub n_obs: usize,
117 pub residuals: Array1<f64>,
119 pub fitted: Array1<f64>,
121}
122
123pub struct ARIMABuilder {
125 config: ARIMAConfig,
126}
127
128impl ARIMABuilder {
129 pub fn new(p: usize, d: usize, q: usize) -> Self {
131 Self {
132 config: ARIMAConfig {
133 order: ARIMAOrder { p, d, q },
134 ..Default::default()
135 },
136 }
137 }
138
139 pub fn seasonal(self, P: usize, D: usize, Q: usize, period: usize) -> SARIMABuilder {
141 SARIMABuilder::new(
142 self.config.order.p,
143 self.config.order.d,
144 self.config.order.q,
145 )
146 .seasonal(P, D, Q, period)
147 }
148
149 pub fn with_constant(mut self, include: bool) -> Self {
151 self.config.with_constant = include;
152 self
153 }
154
155 pub fn method(mut self, method: EstimationMethod) -> Self {
157 self.config.method = method;
158 self
159 }
160
161 pub fn max_iter(mut self, max_iter: usize) -> Self {
163 self.config.max_iter = max_iter;
164 self
165 }
166
167 pub fn tol(mut self, tol: f64) -> Self {
169 self.config.tol = tol;
170 self
171 }
172
173 pub fn fit(self, ts: &TimeSeries) -> Result<ARIMAResults> {
175 let mut arima = ARIMA::new(self.config);
176 arima.fit(ts)
177 }
178}
179
180pub struct SARIMABuilder {
182 order: SARIMAOrder,
183 with_constant: bool,
184 method: EstimationMethod,
185 max_iter: usize,
186 tol: f64,
187}
188
189impl SARIMABuilder {
190 pub fn new(p: usize, d: usize, q: usize) -> Self {
192 Self {
193 order: SARIMAOrder {
194 order: ARIMAOrder { p, d, q },
195 seasonal_p: 0,
196 seasonal_d: 0,
197 seasonal_q: 0,
198 seasonal_period: 1,
199 },
200 with_constant: true,
201 method: EstimationMethod::CSS,
202 max_iter: 100,
203 tol: 1e-6,
204 }
205 }
206
207 pub fn seasonal(mut self, P: usize, D: usize, Q: usize, period: usize) -> Self {
209 self.order.seasonal_p = P;
210 self.order.seasonal_d = D;
211 self.order.seasonal_q = Q;
212 self.order.seasonal_period = period;
213 self
214 }
215
216 pub fn with_constant(mut self, include: bool) -> Self {
218 self.with_constant = include;
219 self
220 }
221
222 pub fn method(mut self, method: EstimationMethod) -> Self {
224 self.method = method;
225 self
226 }
227
228 pub fn max_iter(mut self, max_iter: usize) -> Self {
230 self.max_iter = max_iter;
231 self
232 }
233
234 pub fn tol(mut self, tol: f64) -> Self {
236 self.tol = tol;
237 self
238 }
239
240 pub fn fit(self, ts: &TimeSeries) -> Result<ARIMAResults> {
242 let total_p = self.order.order.p + self.order.seasonal_p * self.order.seasonal_period;
244 let total_q = self.order.order.q + self.order.seasonal_q * self.order.seasonal_period;
245
246 let mut arima = ARIMA::new(ARIMAConfig {
247 order: ARIMAOrder {
248 p: total_p,
249 d: self.order.order.d + self.order.seasonal_d * self.order.seasonal_period,
250 q: total_q,
251 },
252 with_constant: self.with_constant,
253 method: self.method,
254 max_iter: self.max_iter,
255 tol: self.tol,
256 });
257
258 arima.fit(ts)
259 }
260}
261
262pub struct ARIMA {
264 config: ARIMAConfig,
265}
266
267impl ARIMA {
268 pub fn new(config: ARIMAConfig) -> Self {
270 Self { config }
271 }
272
273 pub fn builder(p: usize, d: usize, q: usize) -> ARIMABuilder {
275 ARIMABuilder::new(p, d, q)
276 }
277
278 pub fn fit(&mut self, ts: &TimeSeries) -> Result<ARIMAResults> {
280 let n = ts.len();
281 let order = self.config.order;
282
283 if n < order.p + order.q + 10 {
284 return Err(Error::DataError(format!(
285 "Not enough observations for ARIMA({},{},{}), need at least {}, got {}",
286 order.p,
287 order.d,
288 order.q,
289 order.p + order.q + 10,
290 n
291 )));
292 }
293
294 let (diffed_ts, _diff_timestamps) = self.difference(ts)?;
296 let y = diffed_ts.values();
297
298 match self.config.method {
299 EstimationMethod::CSS => self.fit_css(y, n),
300 EstimationMethod::ML => self.fit_ml(y, n),
301 EstimationMethod::ExactML => self.fit_exact_ml(y, n),
302 }
303 }
304
305 fn difference(&self, ts: &TimeSeries) -> Result<(TimeSeries, Vec<i64>)> {
307 if self.config.order.d == 0 {
308 return Ok((ts.clone(), ts.timestamps().to_vec()));
309 }
310
311 let diffed = ts.diff(1, self.config.order.d)?;
312 let timestamps = diffed.timestamps().to_vec();
313 Ok((diffed, timestamps))
314 }
315
316 fn fit_css(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
318 let order = self.config.order;
319 let n = y.len();
320
321 let mut X = Array2::zeros((n - order.p, order.p + order.q + 1));
323 let mut y_reg = Array1::zeros(n - order.p);
324
325 let mut residuals = Array1::zeros(n);
326 let mut fitted = Array1::zeros(n);
327
328 for i in 0..n {
330 residuals[i] = y[i];
331 }
332
333 let mut converged = false;
335 let mut iteration = 0;
336
337 let mut ar_coef = if order.p > 0 {
339 Some(Array1::zeros(order.p))
340 } else {
341 None
342 };
343
344 let mut ma_coef = if order.q > 0 {
345 Some(Array1::zeros(order.q))
346 } else {
347 None
348 };
349
350 let mut constant = if self.config.with_constant {
351 Some(0.0)
352 } else {
353 None
354 };
355
356 while iteration < self.config.max_iter && !converged {
357 for t in order.p..n {
359 let mut row_idx = 0;
360
361 for lag in 1..=order.p {
363 X[(t - order.p, row_idx)] = y[t - lag];
364 row_idx += 1;
365 }
366
367 for lag in 1..=order.q {
369 if t - lag < residuals.len() {
370 X[(t - order.p, row_idx)] = residuals[t - lag];
371 }
372 row_idx += 1;
373 }
374
375 if self.config.with_constant {
377 X[(t - order.p, row_idx)] = 1.0;
378 }
379
380 y_reg[t - order.p] = y[t];
381 }
382
383 let XtX = X.t().dot(&X);
385 let Xty = X.t().dot(&y_reg);
386
387 let coef = so_linalg::solve(&XtX, &Xty)
388 .map_err(|e| Error::LinearAlgebraError(format!("ARIMA CSS solve failed: {}", e)))?;
389
390 let mut idx = 0;
392
393 if let Some(ref mut ar) = ar_coef {
394 for i in 0..order.p {
395 ar[i] = coef[idx];
396 idx += 1;
397 }
398 }
399
400 if let Some(ref mut ma) = ma_coef {
401 for i in 0..order.q {
402 ma[i] = coef[idx];
403 idx += 1;
404 }
405 }
406
407 if let Some(ref mut c) = constant {
408 *c = coef[idx];
409 }
410
411 let mut prev_change = 0.0;
413 for t in 0..n {
414 let mut prediction = 0.0;
415
416 if let Some(ref ar) = ar_coef {
418 for lag in 1..=order.p {
419 if t >= lag {
420 prediction += ar[lag - 1] * y[t - lag];
421 }
422 }
423 }
424
425 if let Some(ref ma) = ma_coef {
427 for lag in 1..=order.q {
428 if t >= lag {
429 prediction += ma[lag - 1] * residuals[t - lag];
430 }
431 }
432 }
433
434 if let Some(c) = constant {
436 prediction += c;
437 }
438
439 if t >= order.p {
440 fitted[t] = prediction;
441 }
442
443 let new_residual = y[t] - prediction;
444 prev_change += (new_residual - residuals[t]).abs();
445 residuals[t] = new_residual;
446 }
447
448 if prev_change / (n as f64) < self.config.tol {
450 converged = true;
451 }
452
453 iteration += 1;
454 }
455
456 if !converged {
457 return Err(Error::DataError(format!(
458 "ARIMA CSS did not converge after {} iterations",
459 self.config.max_iter
460 )));
461 }
462
463 let rss: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
465 let sigma2 =
466 rss / (n - order.p - order.q - if self.config.with_constant { 1 } else { 0 }) as f64;
467
468 let log_likelihood = self.calculate_log_likelihood(&residuals, sigma2, n);
469 let (aic, bic) = self.calculate_information_criteria(
470 log_likelihood,
471 order.p + order.q + if self.config.with_constant { 1 } else { 0 },
472 n_orig,
473 );
474
475 Ok(ARIMAResults {
476 ar_coef,
477 ma_coef,
478 constant,
479 sigma2,
480 log_likelihood,
481 aic,
482 bic,
483 n_obs: n_orig,
484 residuals,
485 fitted,
486 })
487 }
488
489 fn fit_ml(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
491 self.fit_css(y, n_orig)
493 }
494
495 fn fit_exact_ml(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
497 self.fit_ml(y, n_orig)
500 }
501
502 fn calculate_log_likelihood(&self, residuals: &Array1<f64>, sigma2: f64, n: usize) -> f64 {
504 -0.5 * n as f64 * (2.0 * std::f64::consts::PI * sigma2).ln()
505 - 0.5 * residuals.iter().map(|&r| r.powi(2)).sum::<f64>() / sigma2
506 }
507
508 fn calculate_information_criteria(&self, log_lik: f64, k: usize, n: usize) -> (f64, f64) {
510 let aic = 2.0 * k as f64 - 2.0 * log_lik;
511 let bic = (n as f64).ln() * k as f64 - 2.0 * log_lik;
512 (aic, bic)
513 }
514
515 pub fn forecast(&self, results: &ARIMAResults, steps: usize) -> Array1<f64> {
517 let order = self.config.order;
518 let n = results.residuals.len();
519
520 let mut forecasts = Array1::zeros(steps);
521 let mut y_extended = results.fitted.clone();
522 let mut residuals_extended = results.residuals.clone();
523
524 for h in 0..steps {
525 let mut prediction = 0.0;
526
527 if let Some(ref ar) = results.ar_coef {
529 for lag in 1..=order.p {
530 let idx = n + h - lag;
531 if idx < y_extended.len() {
532 prediction += ar[lag - 1] * y_extended[idx];
533 }
534 }
535 }
536
537 if let Some(ref ma) = results.ma_coef {
539 for lag in 1..=order.q {
540 let idx = n + h - lag;
541 if idx < residuals_extended.len() {
542 prediction += ma[lag - 1] * residuals_extended[idx];
543 }
544 }
545 }
546
547 if let Some(c) = results.constant {
549 prediction += c;
550 }
551
552 forecasts[h] = prediction;
553
554 y_extended = ndarray::concatenate(
556 ndarray::Axis(0),
557 &[y_extended.view(), ndarray::array![prediction].view()],
558 )
559 .unwrap();
560
561 residuals_extended = ndarray::concatenate(
563 ndarray::Axis(0),
564 &[residuals_extended.view(), ndarray::array![0.0].view()],
565 )
566 .unwrap();
567 }
568
569 forecasts
570 }
571
572 pub fn prediction_intervals(
574 &self,
575 results: &ARIMAResults,
576 forecasts: &Array1<f64>,
577 alpha: f64,
578 ) -> (Array1<f64>, Array1<f64>) {
579 let sigma = results.sigma2.sqrt();
580 let _z = 1.0 - alpha / 2.0;
581 let z_value = 1.96; let lower = forecasts.mapv(|f| f - z_value * sigma);
584 let upper = forecasts.mapv(|f| f + z_value * sigma);
585
586 (lower, upper)
587 }
588}
589
590pub trait ARIMAExt {
592 fn arima(&self, p: usize, d: usize, q: usize) -> Result<ARIMAResults>;
594}
595
596impl ARIMAExt for TimeSeries {
597 fn arima(&self, p: usize, d: usize, q: usize) -> Result<ARIMAResults> {
598 ARIMA::builder(p, d, q).fit(self)
599 }
600}