1use crate::error::{Error, Result};
4use crate::indicators::atr::Atr;
5use crate::indicators::ema::Ema;
6use crate::ohlcv::Candle;
7use crate::traits::Indicator;
8
/// One Keltner Channel reading: a middle band plus an upper and a lower band around it.
///
/// When produced by `Keltner`, the outer bands sit symmetrically, `multiplier * ATR` away
/// from `middle`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct KeltnerOutput {
    /// Upper band: `middle + multiplier * ATR`.
    pub upper: f64,
    /// Middle band: the EMA of the candle typical price.
    pub middle: f64,
    /// Lower band: `middle - multiplier * ATR`.
    pub lower: f64,
}
19
20#[derive(Debug, Clone)]
38pub struct Keltner {
39 ema: Ema,
40 atr: Atr,
41 multiplier: f64,
42 ema_period: usize,
43 atr_period: usize,
44}
45
46impl Keltner {
47 pub fn new(ema_period: usize, atr_period: usize, multiplier: f64) -> Result<Self> {
50 if !multiplier.is_finite() || multiplier <= 0.0 {
51 return Err(Error::NonPositiveMultiplier);
52 }
53 Ok(Self {
54 ema: Ema::new(ema_period)?,
55 atr: Atr::new(atr_period)?,
56 multiplier,
57 ema_period,
58 atr_period,
59 })
60 }
61
62 pub fn classic() -> Self {
64 Self::new(20, 10, 2.0).expect("classic Keltner parameters are valid")
65 }
66
67 pub const fn periods(&self) -> (usize, usize, f64) {
69 (self.ema_period, self.atr_period, self.multiplier)
70 }
71}
72
73impl Indicator for Keltner {
74 type Input = Candle;
75 type Output = KeltnerOutput;
76
77 fn update(&mut self, candle: Candle) -> Option<KeltnerOutput> {
78 let mid = self.ema.update(candle.typical_price());
84 let atr = self.atr.update(candle);
85 let (mid, atr) = (mid?, atr?);
86 Some(KeltnerOutput {
87 upper: mid + self.multiplier * atr,
88 middle: mid,
89 lower: mid - self.multiplier * atr,
90 })
91 }
92
93 fn reset(&mut self) {
94 self.ema.reset();
95 self.atr.reset();
96 }
97
98 fn warmup_period(&self) -> usize {
99 self.ema_period.max(self.atr_period)
100 }
101
102 fn is_ready(&self) -> bool {
103 self.ema.is_ready() && self.atr.is_ready()
104 }
105
106 fn name(&self) -> &'static str {
107 "KeltnerChannels"
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Builds a test candle from high/low/close. The close is also passed as the first
    /// argument; this assumes `Candle::new` takes (open, high, low, close, volume, time).
    fn c(h: f64, l: f64, cl: f64) -> Candle {
        Candle::new(cl, h, l, cl, 1.0, 0).unwrap()
    }

    /// With zero range and constant prices, the ATR is zero and the bands meet the middle.
    #[test]
    fn flat_market_collapses_bands() {
        let candles: Vec<Candle> = (0..50).map(|_| c(10.0, 10.0, 10.0)).collect();
        let mut k = Keltner::new(20, 10, 2.0).unwrap();
        let last = k.batch(&candles).into_iter().flatten().last().unwrap();
        assert_relative_eq!(last.upper, last.middle, epsilon = 1e-9);
        assert_relative_eq!(last.lower, last.middle, epsilon = 1e-9);
    }

    #[test]
    fn upper_above_middle_above_lower() {
        let candles: Vec<Candle> = (0..100)
            .map(|i| {
                let m = 100.0 + (f64::from(i) * 0.2).sin() * 5.0;
                c(m + 1.0, m - 1.0, m)
            })
            .collect();
        let mut k = Keltner::classic();
        for o in k.batch(&candles).into_iter().flatten() {
            assert!(o.upper >= o.middle);
            assert!(o.middle >= o.lower);
        }
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..50)
            .map(|i| c(f64::from(i) + 1.0, f64::from(i) - 1.0, f64::from(i)))
            .collect();
        let mut a = Keltner::classic();
        let mut b = Keltner::classic();
        assert_eq!(
            a.batch(&candles),
            candles.iter().map(|x| b.update(*x)).collect::<Vec<_>>()
        );
    }

    #[test]
    fn rejects_invalid_input() {
        // A zero EMA period is rejected by the EMA constructor.
        assert!(Keltner::new(0, 10, 2.0).is_err());
        // The multiplier must be strictly positive...
        assert!(matches!(
            Keltner::new(20, 10, 0.0),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            Keltner::new(20, 10, -1.0),
            Err(Error::NonPositiveMultiplier)
        ));
        // ...and finite: NaN and both infinities must hit the `is_finite` guard.
        assert!(matches!(
            Keltner::new(20, 10, f64::NAN),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            Keltner::new(20, 10, f64::INFINITY),
            Err(Error::NonPositiveMultiplier)
        ));
        assert!(matches!(
            Keltner::new(20, 10, f64::NEG_INFINITY),
            Err(Error::NonPositiveMultiplier)
        ));
    }

    #[test]
    fn accessors_and_metadata() {
        let k = Keltner::new(20, 10, 2.0).unwrap();
        let (ema, atr, mult) = k.periods();
        assert_eq!(ema, 20);
        assert_eq!(atr, 10);
        assert_relative_eq!(mult, 2.0, epsilon = 1e-12);
        assert_eq!(k.name(), "KeltnerChannels");
    }

    #[test]
    fn reset_clears_state() {
        let candles: Vec<Candle> = (0..50)
            .map(|i| c(f64::from(i) + 1.0, f64::from(i) - 1.0, f64::from(i)))
            .collect();
        let mut k = Keltner::classic();
        k.batch(&candles);
        assert!(k.is_ready());
        k.reset();
        assert!(!k.is_ready());
        assert_eq!(k.update(candles[0]), None);
    }

    /// After `reset`, replaying the same candles must reproduce the first pass exactly, i.e.
    /// no EMA/ATR state survives the reset.
    #[test]
    fn reset_then_replay_matches_first_pass() {
        let candles: Vec<Candle> = (0..50)
            .map(|i| c(f64::from(i) + 1.0, f64::from(i) - 1.0, f64::from(i)))
            .collect();
        let mut k = Keltner::classic();
        let first_pass = k.batch(&candles);
        k.reset();
        assert_eq!(k.batch(&candles), first_pass);
    }

    #[test]
    fn first_emission_matches_warmup_period() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let base = 100.0 + f64::from(i);
                c(base + 1.0, base - 1.0, base)
            })
            .collect();
        let mut k = Keltner::classic();
        let out = k.batch(&candles);
        let warmup = k.warmup_period();
        assert_eq!(warmup, 20);
        for (i, v) in out.iter().enumerate().take(warmup - 1) {
            assert!(v.is_none(), "index {i} must be None during warmup");
        }
        assert!(
            out[warmup - 1].is_some(),
            "first KeltnerOutput must land at warmup_period - 1"
        );
    }

    /// Cross-checks every emitted band against standalone EMA(20)/ATR(10) instances fed the
    /// same inputs.
    #[test]
    fn matches_independent_ema_and_atr() {
        let candles: Vec<Candle> = (0..60)
            .map(|i| {
                let m = 100.0 + (f64::from(i) * 0.2).sin() * 5.0;
                c(m + 1.5, m - 1.5, m)
            })
            .collect();
        let mut k = Keltner::classic();
        let mut ema = Ema::new(20).unwrap();
        let mut atr = Atr::new(10).unwrap();
        for (i, candle) in candles.iter().enumerate() {
            let got = k.update(*candle);
            let mid = ema.update(candle.typical_price());
            let a = atr.update(*candle);
            match (mid, a) {
                (Some(m), Some(av)) => {
                    let o = got.expect("Keltner emits once EMA and ATR are both ready");
                    assert_relative_eq!(o.middle, m, epsilon = 1e-9);
                    assert_relative_eq!(o.upper, m + 2.0 * av, epsilon = 1e-9);
                    assert_relative_eq!(o.lower, m - 2.0 * av, epsilon = 1e-9);
                }
                _ => assert!(
                    got.is_none(),
                    "Keltner must be None until both ready (i={i})"
                ),
            }
        }
    }
}