rustalib/indicators/oscillators/
macd.rs

1use crate::indicators::moving_averages::calculate_ema;
2use crate::util::dataframe_utils::check_window_size;
3use polars::prelude::*;
4
5/// Calculates Moving Average Convergence Divergence (MACD)
6///
7/// # Arguments
8///
9/// * `df` - DataFrame containing the price data
10/// * `fast_period` - Fast EMA period (typically 12)
11/// * `slow_period` - Slow EMA period (typically 26)
12/// * `signal_period` - Signal line period (typically 9)
13/// * `column` - Column name to use for calculations (default "close")
14///
15/// # Returns
16///
17/// Returns a PolarsResult containing tuple of (MACD, Signal) Series
18pub fn calculate_macd(
19    df: &DataFrame,
20    fast_period: usize,
21    slow_period: usize,
22    signal_period: usize,
23    column: &str,
24) -> PolarsResult<(Series, Series)> {
25    // Check we have enough data for the longest period (slow_period)
26    check_window_size(df, slow_period, "MACD")?;
27
28    let ema_fast = calculate_ema(df, column, fast_period)?;
29    let ema_slow = calculate_ema(df, column, slow_period)?;
30
31    let macd = (&ema_fast - &ema_slow)?;
32
33    // Create a temporary DataFrame with MACD series for calculating the signal
34    let macd_series = macd.clone();
35    let temp_df = DataFrame::new(vec![macd_series.with_name(column.into()).into()])?;
36
37    // Calculate the signal line as an EMA of the MACD
38    let signal = calculate_ema(&temp_df, column, signal_period)?;
39
40    // Replace NaN values in signal with zeros at positions where MACD has values
41    let macd_ca = macd.f64()?;
42    let signal_ca = signal.f64()?;
43
44    let mut signal_vec: Vec<f64> = Vec::with_capacity(signal.len());
45
46    for i in 0..signal.len() {
47        if i < slow_period - 1 {
48            // Keep first slow_period-1 values as NaN to match MACD
49            signal_vec.push(f64::NAN);
50        } else if i < slow_period - 1 + signal_period {
51            // For index positions where signal might be NaN but MACD has values,
52            // use non-NaN values or 0.0
53            if let Some(macd_val) = macd_ca.get(i) {
54                if !macd_val.is_nan() {
55                    // Signal might be NaN here, use 0.0 as initial value
56                    signal_vec.push(0.0);
57                } else {
58                    signal_vec.push(f64::NAN);
59                }
60            } else {
61                signal_vec.push(f64::NAN);
62            }
63        } else {
64            // For positions where signal should have valid values
65            let val = signal_ca.get(i).unwrap_or(0.0);
66            if val.is_nan() && macd_ca.get(i).is_some_and(|v| !v.is_nan()) {
67                signal_vec.push(0.0);
68            } else {
69                signal_vec.push(val);
70            }
71        }
72    }
73
74    let macd_name = format!("macd_{0}_{1}", fast_period, slow_period);
75    let signal_name = format!(
76        "macd_signal_{0}_{1}_{2}",
77        fast_period, slow_period, signal_period
78    );
79
80    Ok((
81        macd.with_name(macd_name.into()),
82        Series::new(signal_name.into(), signal_vec),
83    ))
84}