wickra_core/indicators/
kendall_tau.rs1use std::collections::VecDeque;
4
5use crate::error::{Error, Result};
6use crate::traits::Indicator;
7
/// Direction of `a` relative to `b`: `1` if `a > b`, `-1` if `a < b`,
/// and `0` otherwise.
///
/// "Otherwise" covers equal values (including `0.0` vs `-0.0`) and any
/// comparison involving NaN, so a NaN observation is treated as tied.
fn sign(a: f64, b: f64) -> i32 {
    // Each comparison contributes 0 or 1, and at most one of them can hold.
    i32::from(a > b) - i32::from(a < b)
}
18
/// Rolling Kendall rank correlation (tau-b) between two paired series.
///
/// Each update consumes one `(x, y)` observation. Once `period`
/// observations have been buffered, every update emits tau-b over the most
/// recent `period` pairs. The result lies in `[-1.0, 1.0]`:
/// - `1.0` when both series are ranked identically,
/// - `-1.0` when they are ranked in exactly opposite order,
/// - `0.0` when either channel is constant across the window.
#[derive(Debug, Clone)]
pub struct KendallTau {
    /// Window length; the constructor rejects values below 2.
    period: usize,
    /// Buffered `(x, y)` observations, oldest at the front.
    window: VecDeque<(f64, f64)>,
    /// Last emitted tau-b, or `None` if nothing has been emitted yet.
    last: Option<f64>,
}
63
64impl KendallTau {
65 pub fn new(period: usize) -> Result<Self> {
72 if period < 2 {
73 return Err(Error::InvalidPeriod {
74 message: "Kendall tau needs period >= 2",
75 });
76 }
77 Ok(Self {
78 period,
79 window: VecDeque::with_capacity(period),
80 last: None,
81 })
82 }
83
84 pub const fn period(&self) -> usize {
86 self.period
87 }
88
89 pub const fn value(&self) -> Option<f64> {
91 self.last
92 }
93
94 fn compute(&self) -> f64 {
95 let pairs: Vec<(f64, f64)> = self.window.iter().copied().collect();
96 let len = pairs.len();
97 let mut concordant: i64 = 0;
98 let mut discordant: i64 = 0;
99 let mut tie_x: i64 = 0;
100 let mut tie_y: i64 = 0;
101 for i in 0..len {
102 for j in (i + 1)..len {
103 let sx = sign(pairs[j].0, pairs[i].0);
104 let sy = sign(pairs[j].1, pairs[i].1);
105 if sx == 0 {
106 tie_x += 1;
107 }
108 if sy == 0 {
109 tie_y += 1;
110 }
111 let prod = sx * sy;
112 if prod > 0 {
113 concordant += 1;
114 } else if prod < 0 {
115 discordant += 1;
116 }
117 }
118 }
119 let n0 = (len * (len - 1) / 2) as f64;
120 let denom = ((n0 - tie_x as f64) * (n0 - tie_y as f64)).sqrt();
121 if denom == 0.0 {
122 return 0.0;
123 }
124 ((concordant - discordant) as f64 / denom).clamp(-1.0, 1.0)
125 }
126}
127
128impl Indicator for KendallTau {
129 type Input = (f64, f64);
130 type Output = f64;
131
132 fn update(&mut self, input: (f64, f64)) -> Option<f64> {
133 if self.window.len() == self.period {
134 self.window.pop_front();
135 }
136 self.window.push_back(input);
137 if self.window.len() < self.period {
138 return None;
139 }
140 let out = self.compute();
141 self.last = Some(out);
142 Some(out)
143 }
144
145 fn reset(&mut self) {
146 self.window.clear();
147 self.last = None;
148 }
149
150 fn warmup_period(&self) -> usize {
151 self.period
152 }
153
154 fn is_ready(&self) -> bool {
155 self.last.is_some()
156 }
157
158 fn name(&self) -> &'static str {
159 "KendallTau"
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_period_below_two() {
        assert!(matches!(
            KendallTau::new(0),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(matches!(
            KendallTau::new(1),
            Err(Error::InvalidPeriod { .. })
        ));
        assert!(KendallTau::new(2).is_ok());
    }

    #[test]
    fn accessors_and_metadata() {
        let k = KendallTau::new(20).unwrap();
        assert_eq!(k.period(), 20);
        assert_eq!(k.warmup_period(), 20);
        assert_eq!(k.name(), "KendallTau");
        assert!(!k.is_ready());
        assert_eq!(k.value(), None);
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut k = KendallTau::new(4).unwrap();
        let out = k.batch(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0), (5.0, 5.0)]);
        for v in out.iter().take(3) {
            assert!(v.is_none());
        }
        assert!(out[3].is_some());
    }

    #[test]
    fn monotone_increasing_is_one() {
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|i| (f64::from(i), 2.0 * f64::from(i) + 1.0))
            .collect();
        let last = KendallTau::new(10)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, 1.0, epsilon = 1e-9);
    }

    #[test]
    fn monotone_decreasing_is_minus_one() {
        let pairs: Vec<(f64, f64)> = (0..20)
            .map(|i| (f64::from(i), -3.0 * f64::from(i)))
            .collect();
        let last = KendallTau::new(10)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, -1.0, epsilon = 1e-9);
    }

    #[test]
    fn constant_channel_yields_zero() {
        let pairs: Vec<(f64, f64)> = (0..20).map(|i| (f64::from(i), 7.0)).collect();
        // Every pair is tied in y, so the tau-b denominator vanishes.
        let last = KendallTau::new(8)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
            .last()
            .unwrap();
        assert_relative_eq!(last, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn output_in_range() {
        let pairs: Vec<(f64, f64)> = (0..80)
            .map(|i| {
                let t = f64::from(i);
                (100.0 + t.sin() * 5.0, 50.0 + (t * 0.3).cos() * 3.0)
            })
            .collect();
        for v in KendallTau::new(20)
            .unwrap()
            .batch(&pairs)
            .into_iter()
            .flatten()
        {
            assert!((-1.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut k = KendallTau::new(4).unwrap();
        k.batch(&[(1.0, 1.0), (2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]);
        assert!(k.is_ready());
        k.reset();
        assert!(!k.is_ready());
        assert_eq!(k.value(), None);
        assert_eq!(k.update((1.0, 1.0)), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let pairs: Vec<(f64, f64)> = (0..60)
            .map(|i| {
                let t = f64::from(i);
                (t.sin(), (t * 0.5).cos())
            })
            .collect();
        let batch = KendallTau::new(14).unwrap().batch(&pairs);
        let mut b = KendallTau::new(14).unwrap();
        let streamed: Vec<_> = pairs.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }

    #[test]
    fn ties_are_corrected() {
        // Window (1,1), (1,2), (2,2), (3,3):
        // - C = 4, D = 0, n0 = 6;
        // - one pair is tied in x and one pair is tied in y.
        // Tau-b = 4 / sqrt(5 * 5) = 0.8, while the uncorrected tau-a would
        // give 4 / 6, so this pins the tie correction.
        let mut k = KendallTau::new(4).unwrap();
        assert_eq!(k.update((1.0, 1.0)), None);
        assert_eq!(k.update((1.0, 2.0)), None);
        assert_eq!(k.update((2.0, 2.0)), None);
        let v = k.update((3.0, 3.0)).unwrap();
        assert_relative_eq!(v, 0.8, epsilon = 1e-12);
        assert_eq!(k.value(), Some(v));
    }

    #[test]
    fn window_only_sees_last_period_pairs() {
        let mut k = KendallTau::new(3).unwrap();
        let out = k.batch(&[
            (1.0, 1.0),
            (2.0, 2.0),
            (3.0, 3.0),
            (4.0, 0.0),
            (5.0, -1.0),
            (6.0, -2.0),
        ]);
        assert_relative_eq!(out[2].unwrap(), 1.0, epsilon = 1e-12);
        // Window (2,2), (3,3), (4,0): one concordant, two discordant pairs.
        assert_relative_eq!(out[3].unwrap(), -1.0 / 3.0, epsilon = 1e-12);
        // The rising prefix has fully left the window.
        assert_relative_eq!(out[5].unwrap(), -1.0, epsilon = 1e-12);
    }
}