1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
8
9use crate::error::{SpatialError, SpatialResult};
10
/// Radially symmetric 2-D smoothing kernels supported by the KDE routines.
///
/// Every kernel is evaluated on the squared scaled distance
/// `u² = ((x - xi) / hx)² + ((y - yi) / hy)²`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KernelType {
    /// Unbounded support: `exp(-u² / 2) / (2π)`.
    Gaussian,
    /// Support limited to `u² <= 1`: `(2 / π) · (1 - u²)`.
    Epanechnikov,
    /// Support limited to `u² <= 1`: `(3 / π) · (1 - u²)²`.
    Quartic,
}
21
/// Rule of thumb used to derive a bandwidth for each axis from the data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BandwidthMethod {
    /// `0.9 · min(σ, IQR / 1.34) · n^(-1/5)`, falling back to `σ` when the
    /// robust spread is not positive.
    Silverman,
    /// `σ · n^(-1/6)`.
    Scott,
}
30
31#[derive(Debug, Clone)]
33pub struct SpatialKdeConfig {
34 pub kernel: KernelType,
36 pub bandwidth_x: Option<f64>,
38 pub bandwidth_y: Option<f64>,
40 pub bandwidth_method: BandwidthMethod,
42 pub grid_nx: usize,
44 pub grid_ny: usize,
46}
47
48impl Default for SpatialKdeConfig {
49 fn default() -> Self {
50 Self {
51 kernel: KernelType::Gaussian,
52 bandwidth_x: None,
53 bandwidth_y: None,
54 bandwidth_method: BandwidthMethod::Silverman,
55 grid_nx: 50,
56 grid_ny: 50,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
63pub struct KdeGrid {
64 pub density: Array2<f64>,
66 pub x_coords: Array1<f64>,
68 pub y_coords: Array1<f64>,
70 pub bandwidth_x: f64,
72 pub bandwidth_y: f64,
74}
75
76pub fn select_bandwidth(
84 coordinates: &ArrayView2<f64>,
85 method: BandwidthMethod,
86) -> SpatialResult<(f64, f64)> {
87 let n = coordinates.nrows();
88 if n < 2 {
89 return Err(SpatialError::ValueError(
90 "Need at least 2 points for bandwidth selection".to_string(),
91 ));
92 }
93 if coordinates.ncols() < 2 {
94 return Err(SpatialError::DimensionError(
95 "Coordinates must have at least 2 columns (x, y)".to_string(),
96 ));
97 }
98
99 let nf = n as f64;
100
101 let col_x: Vec<f64> = coordinates.column(0).iter().copied().collect();
102 let col_y: Vec<f64> = coordinates.column(1).iter().copied().collect();
103 let hx = bandwidth_1d(&col_x, nf, method);
104 let hy = bandwidth_1d(&col_y, nf, method);
105
106 if hx <= 0.0 || hy <= 0.0 {
107 return Err(SpatialError::ValueError(
108 "Computed bandwidth is non-positive; data may have zero variance".to_string(),
109 ));
110 }
111
112 Ok((hx, hy))
113}
114
115fn bandwidth_1d(data: &[f64], nf: f64, method: BandwidthMethod) -> f64 {
116 let mean: f64 = data.iter().sum::<f64>() / nf;
117 let var: f64 = data.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / nf;
118 let std = var.sqrt();
119
120 match method {
121 BandwidthMethod::Silverman => {
122 let iqr = interquartile_range(data);
123 let spread = std.min(iqr / 1.34);
124 let spread = if spread > 0.0 { spread } else { std };
125 0.9 * spread * nf.powf(-0.2)
126 }
127 BandwidthMethod::Scott => {
128 std * nf.powf(-1.0 / 6.0)
130 }
131 }
132}
133
/// Approximate interquartile range of `data`.
///
/// * Empty input yields `0.0`.
/// * With fewer than 4 values there are not enough points for quartiles, so
///   the full range `max - min` is returned instead.
/// * Otherwise the values are sorted and the quartiles are taken as the
///   elements at indices `n / 4` and `3 * n / 4` (no interpolation).
fn interquartile_range(data: &[f64]) -> f64 {
    if data.is_empty() {
        // Folding over nothing would give `-inf - inf`; report "no spread".
        return 0.0;
    }
    if data.len() < 4 {
        let min = data.iter().copied().fold(f64::INFINITY, f64::min);
        let max = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        return max - min;
    }

    let mut sorted = data.to_vec();
    // `total_cmp` is a genuine total order even when NaN is present. A
    // comparator built from `partial_cmp(..).unwrap_or(Equal)` is not, and the
    // standard sorts are allowed to panic when handed a non-total order.
    sorted.sort_unstable_by(f64::total_cmp);

    let n = sorted.len();
    sorted[3 * n / 4] - sorted[n / 4]
}
148
149fn kernel_eval(u_sq: f64, kernel: KernelType) -> f64 {
154 match kernel {
155 KernelType::Gaussian => {
156 (1.0 / (2.0 * std::f64::consts::PI)) * (-0.5 * u_sq).exp()
158 }
159 KernelType::Epanechnikov => {
160 if u_sq <= 1.0 {
161 (2.0 / std::f64::consts::PI) * (1.0 - u_sq)
163 } else {
164 0.0
165 }
166 }
167 KernelType::Quartic => {
168 if u_sq <= 1.0 {
169 let t = 1.0 - u_sq;
171 (3.0 / std::f64::consts::PI) * t * t
172 } else {
173 0.0
174 }
175 }
176 }
177}
178
179pub fn kde_on_grid(
187 coordinates: &ArrayView2<f64>,
188 config: &SpatialKdeConfig,
189) -> SpatialResult<KdeGrid> {
190 let n = coordinates.nrows();
191 if n == 0 {
192 return Err(SpatialError::ValueError("No data points".to_string()));
193 }
194 if coordinates.ncols() < 2 {
195 return Err(SpatialError::DimensionError(
196 "Coordinates must be 2D".to_string(),
197 ));
198 }
199
200 let (hx, hy) = match (config.bandwidth_x, config.bandwidth_y) {
202 (Some(bx), Some(by)) => {
203 if bx <= 0.0 || by <= 0.0 {
204 return Err(SpatialError::ValueError(
205 "Bandwidths must be positive".to_string(),
206 ));
207 }
208 (bx, by)
209 }
210 _ => {
211 let (auto_hx, auto_hy) = select_bandwidth(coordinates, config.bandwidth_method)?;
212 (
213 config.bandwidth_x.unwrap_or(auto_hx),
214 config.bandwidth_y.unwrap_or(auto_hy),
215 )
216 }
217 };
218
219 let mut xmin = f64::INFINITY;
221 let mut xmax = f64::NEG_INFINITY;
222 let mut ymin = f64::INFINITY;
223 let mut ymax = f64::NEG_INFINITY;
224
225 for i in 0..n {
226 let x = coordinates[[i, 0]];
227 let y = coordinates[[i, 1]];
228 if x < xmin {
229 xmin = x;
230 }
231 if x > xmax {
232 xmax = x;
233 }
234 if y < ymin {
235 ymin = y;
236 }
237 if y > ymax {
238 ymax = y;
239 }
240 }
241
242 let margin_x = 3.0 * hx;
243 let margin_y = 3.0 * hy;
244 xmin -= margin_x;
245 xmax += margin_x;
246 ymin -= margin_y;
247 ymax += margin_y;
248
249 let nx = config.grid_nx.max(2);
250 let ny = config.grid_ny.max(2);
251
252 let dx = (xmax - xmin) / (nx as f64 - 1.0);
253 let dy = (ymax - ymin) / (ny as f64 - 1.0);
254
255 let x_coords = Array1::from_shape_fn(nx, |i| xmin + i as f64 * dx);
256 let y_coords = Array1::from_shape_fn(ny, |j| ymin + j as f64 * dy);
257
258 let nf = n as f64;
259 let mut density = Array2::zeros((ny, nx));
260
261 for j in 0..ny {
262 let gy = y_coords[j];
263 for i in 0..nx {
264 let gx = x_coords[i];
265
266 let mut sum = 0.0;
267 for k in 0..n {
268 let ux = (gx - coordinates[[k, 0]]) / hx;
269 let uy = (gy - coordinates[[k, 1]]) / hy;
270 let u_sq = ux * ux + uy * uy;
271 sum += kernel_eval(u_sq, config.kernel);
272 }
273
274 density[[j, i]] = sum / (nf * hx * hy);
275 }
276 }
277
278 Ok(KdeGrid {
279 density,
280 x_coords,
281 y_coords,
282 bandwidth_x: hx,
283 bandwidth_y: hy,
284 })
285}
286
287pub fn kde_at_point(
293 coordinates: &ArrayView2<f64>,
294 query: &[f64; 2],
295 hx: f64,
296 hy: f64,
297 kernel: KernelType,
298) -> SpatialResult<f64> {
299 let n = coordinates.nrows();
300 if n == 0 {
301 return Err(SpatialError::ValueError("No data points".to_string()));
302 }
303 if hx <= 0.0 || hy <= 0.0 {
304 return Err(SpatialError::ValueError(
305 "Bandwidths must be positive".to_string(),
306 ));
307 }
308
309 let nf = n as f64;
310 let mut sum = 0.0;
311 for k in 0..n {
312 let ux = (query[0] - coordinates[[k, 0]]) / hx;
313 let uy = (query[1] - coordinates[[k, 1]]) / hy;
314 let u_sq = ux * ux + uy * uy;
315 sum += kernel_eval(u_sq, kernel);
316 }
317
318 Ok(sum / (nf * hx * hy))
319}
320
/// Unit tests for bandwidth selection and the grid / point KDE evaluators.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A tight cluster must produce a higher density than a lone outlier,
    /// both for point queries and on the evaluated grid.
    #[test]
    fn test_kde_peak_at_concentration() {
        let coords = array![
            [0.0, 0.0],
            [0.1, 0.0],
            [0.0, 0.1],
            [-0.1, 0.0],
            [0.0, -0.1],
            [5.0, 5.0],
        ];

        let config = SpatialKdeConfig {
            kernel: KernelType::Gaussian,
            bandwidth_x: Some(0.5),
            bandwidth_y: Some(0.5),
            grid_nx: 20,
            grid_ny: 20,
            ..Default::default()
        };

        let grid = kde_on_grid(&coords.view(), &config).expect("kde_on_grid");

        // The grid honours the requested resolution and bandwidths.
        assert_eq!(grid.density.dim(), (20, 20));
        assert_eq!(grid.x_coords.len(), 20);
        assert_eq!(grid.y_coords.len(), 20);
        assert_eq!(grid.bandwidth_x, 0.5);
        assert_eq!(grid.bandwidth_y, 0.5);

        // The densest grid node has to sit next to the cluster at the origin.
        let ((peak_j, peak_i), _) = grid.density.indexed_iter().fold(
            ((0usize, 0usize), f64::NEG_INFINITY),
            |best, (idx, &value)| if value > best.1 { (idx, value) } else { best },
        );
        let peak_x = grid.x_coords[peak_i];
        let peak_y = grid.y_coords[peak_j];
        assert!(
            peak_x.abs() < 1.0 && peak_y.abs() < 1.0,
            "Grid peak at ({}, {}) should be near the origin",
            peak_x,
            peak_y
        );

        let d_origin = kde_at_point(&coords.view(), &[0.0, 0.0], 0.5, 0.5, KernelType::Gaussian)
            .expect("point kde");
        let d_far = kde_at_point(&coords.view(), &[5.0, 5.0], 0.5, 0.5, KernelType::Gaussian)
            .expect("point kde");

        assert!(
            d_origin > d_far,
            "Density at concentration ({}) should exceed outlier density ({})",
            d_origin,
            d_far
        );
    }

    /// A Riemann sum over the grid should be close to 1 for a proper density.
    #[test]
    fn test_kde_integrates_approximately_to_one() {
        let coords = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],];

        let config = SpatialKdeConfig {
            kernel: KernelType::Gaussian,
            bandwidth_x: Some(0.5),
            bandwidth_y: Some(0.5),
            grid_nx: 80,
            grid_ny: 80,
            ..Default::default()
        };

        let grid = kde_on_grid(&coords.view(), &config).expect("kde_on_grid");

        // Recover the grid spacing from the returned coordinates.
        let dx = (grid.x_coords[grid.x_coords.len() - 1] - grid.x_coords[0])
            / (grid.x_coords.len() as f64 - 1.0);
        let dy = (grid.y_coords[grid.y_coords.len() - 1] - grid.y_coords[0])
            / (grid.y_coords.len() as f64 - 1.0);

        let integral: f64 = grid.density.sum() * dx * dy;

        assert!(
            (integral - 1.0).abs() < 0.15,
            "KDE integral = {}, expected ~1.0",
            integral
        );
    }

    #[test]
    fn test_bandwidth_selection_silverman() {
        let coords = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.2, 0.8],
            [0.8, 0.2],
            [0.3, 0.7],
        ];

        let (hx, hy) =
            select_bandwidth(&coords.view(), BandwidthMethod::Silverman).expect("bandwidth");
        assert!(hx > 0.0, "hx should be positive");
        assert!(hy > 0.0, "hy should be positive");
    }

    #[test]
    fn test_bandwidth_selection_scott() {
        let coords = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5],];

        let (hx, hy) = select_bandwidth(&coords.view(), BandwidthMethod::Scott).expect("bw");
        assert!(hx > 0.0);
        assert!(hy > 0.0);
    }

    #[test]
    fn test_epanechnikov_kernel() {
        let coords = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0],];

        let config = SpatialKdeConfig {
            kernel: KernelType::Epanechnikov,
            bandwidth_x: Some(2.0),
            bandwidth_y: Some(2.0),
            grid_nx: 10,
            grid_ny: 10,
            ..Default::default()
        };

        let grid = kde_on_grid(&coords.view(), &config).expect("epanechnikov kde");

        for &d in grid.density.iter() {
            assert!(d >= 0.0, "density should be non-negative");
        }
    }

    #[test]
    fn test_quartic_kernel() {
        let coords = array![[0.0, 0.0], [0.5, 0.5],];

        let d = kde_at_point(&coords.view(), &[0.25, 0.25], 1.0, 1.0, KernelType::Quartic)
            .expect("quartic");
        assert!(d > 0.0, "quartic density should be positive near data");

        // The quartic kernel has compact support, so far away it is exactly zero.
        let d_far = kde_at_point(
            &coords.view(),
            &[100.0, 100.0],
            1.0,
            1.0,
            KernelType::Quartic,
        )
        .expect("quartic far");
        assert!(d_far < 1e-15, "quartic density should be ~0 far from data");
    }

    #[test]
    fn test_kde_errors() {
        let empty: Array2<f64> = Array2::zeros((0, 2));
        let config = SpatialKdeConfig::default();
        assert!(kde_on_grid(&empty.view(), &config).is_err());

        let single = array![[0.0, 0.0]];
        assert!(
            kde_at_point(&single.view(), &[0.0, 0.0], -1.0, 1.0, KernelType::Gaussian).is_err()
        );
    }
}