//! `wickra_core/indicators/kendall_tau.rs`: rolling Kendall rank correlation (tau-b).

use std::collections::VecDeque;

use crate::error::{Error, Result};
use crate::traits::Indicator;

/// Three-way comparison of two floats.
///
/// Returns `1` when `a > b`, `-1` when `a < b`, and `0` otherwise. "Otherwise" covers
/// equal values and any comparison involving NaN, because both `>` and `<` are false then.
fn sign(a: f64, b: f64) -> i32 {
    // Each comparison contributes 0 or 1, and at most one of them can be true.
    i32::from(a > b) - i32::from(a < b)
}

/// Rolling Kendall rank correlation (tau-b) between two input series.
///
/// Each update pushes an `(x, y)` observation into a window of at most `period` entries.
/// Once the window is full, every unordered pair of observations is classified as
/// concordant, discordant, or tied (in `x` and/or `y`). The output is
/// `(concordant - discordant) / sqrt((n0 - ties_x) * (n0 - ties_y))`, where
/// `n0 = period * (period - 1) / 2`, clamped to `[-1, 1]`. If either series is constant
/// over the window the denominator is zero and the output is `0.0`.
///
/// Inputs containing NaN or infinity are rejected: they are not stored and produce no
/// output. Each emission compares all pairs, so the cost is quadratic in `period`.
#[derive(Debug, Clone)]
pub struct KendallTau {
    // Window length; `KendallTau::new` rejects values below 2.
    period: usize,
    // Buffered `(x, y)` observations, oldest at the front; never longer than `period`.
    window: VecDeque<(f64, f64)>,
    // Last emitted correlation; `None` until the window first fills, and again after `reset`.
    last: Option<f64>,
}

64impl KendallTau {
65 pub fn new(period: usize) -> Result<Self> {
72 if period < 2 {
73 return Err(Error::InvalidPeriod {
74 message: "Kendall tau needs period >= 2",
75 });
76 }
77 Ok(Self {
78 period,
79 window: VecDeque::with_capacity(period),
80 last: None,
81 })
82 }
83
84 pub const fn period(&self) -> usize {
86 self.period
87 }
88
89 pub const fn value(&self) -> Option<f64> {
91 self.last
92 }
93
94 fn compute(&self) -> f64 {
95 let pairs: Vec<(f64, f64)> = self.window.iter().copied().collect();
96 let len = pairs.len();
97 let mut concordant: i64 = 0;
98 let mut discordant: i64 = 0;
99 let mut tie_x: i64 = 0;
100 let mut tie_y: i64 = 0;
101 for i in 0..len {
102 for j in (i + 1)..len {
103 let sx = sign(pairs[j].0, pairs[i].0);
104 let sy = sign(pairs[j].1, pairs[i].1);
105 if sx == 0 {
106 tie_x += 1;
107 }
108 if sy == 0 {
109 tie_y += 1;
110 }
111 let prod = sx * sy;
112 if prod > 0 {
113 concordant += 1;
114 } else if prod < 0 {
115 discordant += 1;
116 }
117 }
118 }
119 let n0 = (len * (len - 1) / 2) as f64;
120 let denom = ((n0 - tie_x as f64) * (n0 - tie_y as f64)).sqrt();
121 if denom == 0.0 {
122 return 0.0;
123 }
124 ((concordant - discordant) as f64 / denom).clamp(-1.0, 1.0)
125 }
126}
127
128impl Indicator for KendallTau {
129 type Input = (f64, f64);
130 type Output = f64;
131
132 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
133 if !input.0.is_finite() || !input.1.is_finite() {
134 return None;
135 }
136 if self.window.len() == self.period {
137 self.window.pop_front();
138 }
139 self.window.push_back(input);
140 if self.window.len() < self.period {
141 return None;
142 }
143 let out = self.compute();
144 self.last = Some(out);
145 Some(out)
146 }
147
148 fn reset(&mut self) {
149 self.window.clear();
150 self.last = None;
151 }
152
153 fn warmup_period(&self) -> usize {
154 self.period
155 }
156
157 fn is_ready(&self) -> bool {
158 self.last.is_some()
159 }
160
161 fn name(&self) -> &'static str {
162 "KendallTau"
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_two() {
        assert!(matches!(
            KendallTau::new(0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            KendallTau::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(KendallTau::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let k = KendallTau::new(20).unwrap();
        assert_eq!(k.period(), 20);
        assert_eq!(k.warmup_period(), 20);
        assert_eq!(k.name(), "KendallTau");
        assert!(!k.is_ready());
        assert_eq!(k.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut k = KendallTau::new(4).unwrap();
        let out = k.batch(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0), (5.0, 5.0)]);
        for v in out.iter().take(3) {
            assert!(v.is_none());
        }
        assert!(out[3].is_some());
        assert!(out[4].is_some());
    }

    #[test]
    fn monotone_increasing_is_one() {
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|i| (f64::from(i), 2.0 * f64::from(i) + 1.0))
            .collect();
        let last = KendallTau::new(10)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, 1.0, epsilon = 1e-9);
    }

    #[test]
    fn monotone_decreasing_is_minus_one() {
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|i| (f64::from(i), -3.0 * f64::from(i)))
            .collect();
        let last = KendallTau::new(10)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, -1.0, epsilon = 1e-9);
    }

    #[test]
    fn constant_channel_yields_zero() {
        // y never changes, so every pair is tied in y and the tau-b denominator is zero.
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        let last = KendallTau::new(8)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|i| {
                let t = f64::from(i);
                (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0)
            })
            .collect();
        for v in KendallTau::new(20)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
        {
            assert!((-1.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn mixed_pairs_known_value() {
        // (1,2),(2,1),(3,3): two concordant pairs, one discordant, no ties -> tau = 1/3.
        let mut k = KendallTau::new(3).unwrap();
        assert_eq!(k.update((1.0, 2.0)), None);
        assert_eq!(k.update((2.0, 1.0)), None);
        let v = k.update((3.0, 3.0)).unwrap();
        assert_relative_eq!(v, 1.0 / 3.0, epsilon = 1e-12);
    }

    #[test]
    fn window_slides_and_evicts_oldest() {
        let mut k = KendallTau::new(3).unwrap();
        let out = k.batch(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 0.0)]);
        assert_relative_eq!(out[2].unwrap(), 1.0, epsilon = 1e-12);
        // The window is now (2,2),(3,3),(4,0): one concordant and two discordant pairs.
        assert_relative_eq!(out[3].unwrap(), -1.0 / 3.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut k = KendallTau::new(4).unwrap();
        k.batch(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]);
        assert!(k.is_ready());
        k.reset();
        assert!(!k.is_ready());
        assert_eq!(k.value(), None);
        assert_eq!(k.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(|i| {
                let t = f64::from(i);
                (t.sin(), (t * 0.5).cos())
            })
            .collect();
        let batch = KendallTau::new(14).unwrap().batch(&pairs);
        let mut b = KendallTau::new(14).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn ties_are_corrected() {
        // (1,1),(1,2),(2,2),(3,3): 4 concordant, 0 discordant, one pair tied in x and one in y.
        // tau-b = 4 / sqrt((6 - 1) * (6 - 1)) = 0.8, whereas uncorrected tau-a would be 4/6.
        let mut k = KendallTau::new(4).unwrap();
        assert_eq!(k.update((1.0, 1.0)), None);
        assert_eq!(k.update((1.0, 2.0)), None);
        assert_eq!(k.update((2.0, 2.0)), None);
        let v = k.update((3.0, 3.0)).unwrap();
        assert_relative_eq!(v, 0.8, epsilon = 1e-12);
    }

    #[test]
    fn non_finite_input_returns_none() {
        let mut k = KendallTau::new(2).unwrap();
        assert_eq!(k.update((f64::NAN, 1.0)), None);
        assert_eq!(k.update((1.0, f64::INFINITY)), None);
        assert_eq!(k.update((1.0, 2.0)), None);
        // Rejected inputs were never buffered, so this is only the second stored pair.
        assert!(k.update((2.0, 5.0)).is_some());
    }

    #[test]
    fn non_finite_after_warmup_keeps_last_value() {
        let mut k = KendallTau::new(2).unwrap();
        assert_eq!(k.update((1.0, 1.0)), None);
        let v = k.update((2.0, 2.0)).unwrap();
        // A rejected input emits nothing and leaves both the window and the cached value alone.
        assert_eq!(k.update((f64::NAN, 3.0)), None);
        assert!(k.is_ready());
        assert_eq!(k.value(), Some(v));
        assert_relative_eq!(k.update((3.0, 3.0)).unwrap(), 1.0, epsilon = 1e-12);
    }
}