1use std::f64::consts::PI;
2
3use sphereql_core::{
4 CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5 spherical_to_cartesian,
6};
7
8use crate::traits::{DimensionMapper, LayoutStrategy};
9use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
10
/// Upper bound on the number of assign/update rounds run by `kmeans_spherical`.
const MAX_KMEANS_ITERATIONS: usize = 50;
/// Two positions whose `angular_distance` is below this value are counted as
/// overlapping by `compute_quality`.
const OVERLAP_THRESHOLD: f64 = 0.01;
13
/// Layout strategy that groups items into clusters on a sphere and spreads
/// the members of each cluster around that cluster's center.
///
/// Configure it with the `with_*` builder methods:
/// `ClusteredLayout::new().with_clusters(8).with_radius(2.5)`.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters. `layout` clamps the value it actually
    /// uses to the range `1..=items.len()`.
    pub num_clusters: usize,
    /// Radius given to every output position.
    pub radius: f64,
    /// Scale of the spiral that fans cluster members out around their center.
    pub intra_cluster_spread: f64,
}

impl Default for ClusteredLayout {
    /// Four clusters, radius `1.0`, intra-cluster spread `0.3`.
    fn default() -> Self {
        Self {
            num_clusters: 4,
            radius: 1.0,
            intra_cluster_spread: 0.3,
        }
    }
}

impl ClusteredLayout {
    /// Creates a layout with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the layout with the requested cluster count set to `n`.
    #[must_use]
    pub fn with_clusters(mut self, n: usize) -> Self {
        self.num_clusters = n;
        self
    }

    /// Returns the layout with the output radius set to `r`.
    #[must_use]
    pub fn with_radius(mut self, r: f64) -> Self {
        self.radius = r;
        self
    }

    /// Returns the layout with the intra-cluster spread set to `s`.
    #[must_use]
    pub fn with_spread(mut self, s: f64) -> Self {
        self.intra_cluster_spread = s;
        self
    }
}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52 let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53 (0..k)
54 .map(|i| {
55 let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56 .clamp(-1.0, 1.0)
57 .acos();
58 let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59 let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60 spherical_to_cartesian(&sp)
61 })
62 .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66 if points.is_empty() {
67 return CartesianPoint::new(0.0, 0.0, 1.0);
68 }
69 let mut sx = 0.0;
70 let mut sy = 0.0;
71 let mut sz = 0.0;
72 for p in points {
73 sx += p.x;
74 sy += p.y;
75 sz += p.z;
76 }
77 let mean = CartesianPoint::new(sx, sy, sz);
78 let n = mean.normalize();
79 if n.magnitude() == 0.0 {
80 points[0].normalize()
81 } else {
82 n
83 }
84}
85
/// Output of `kmeans_spherical`.
struct KMeansResult {
    /// `assignments[i]` is the index into `centers` of the cluster chosen for
    /// input point `i`.
    assignments: Vec<usize>,
    /// One center per cluster, stored as Cartesian points produced by
    /// `normalize()` / `normalized_mean` (or `evenly_spaced_centers` seeds).
    centers: Vec<CartesianPoint>,
}
90
91fn kmeans_spherical(
92 mapped_cartesian: &[CartesianPoint],
93 mapped_spherical: &[SphericalPoint],
94 k: usize,
95) -> KMeansResult {
96 let n = mapped_cartesian.len();
97
98 let mut centers: Vec<CartesianPoint> = if n >= k {
99 mapped_cartesian[..k]
100 .iter()
101 .map(|c| c.normalize())
102 .collect()
103 } else {
104 evenly_spaced_centers(k)
105 };
106
107 let mut assignments = vec![0usize; n];
108
109 for _ in 0..MAX_KMEANS_ITERATIONS {
110 let mut changed = false;
111
112 for (i, sp) in mapped_spherical.iter().enumerate() {
113 let mut best = 0;
114 let mut best_dist = f64::MAX;
115 for (j, center) in centers.iter().enumerate() {
116 let center_sp = cartesian_to_spherical(center);
117 let d = angular_distance(sp, ¢er_sp);
118 if d < best_dist {
119 best_dist = d;
120 best = j;
121 }
122 }
123 if assignments[i] != best {
124 assignments[i] = best;
125 changed = true;
126 }
127 }
128
129 if !changed {
130 break;
131 }
132
133 let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
134 for (i, &a) in assignments.iter().enumerate() {
135 cluster_points[a].push(mapped_cartesian[i]);
136 }
137
138 for (j, cp) in cluster_points.iter().enumerate() {
139 if cp.is_empty() {
140 let mut farthest_idx = 0;
141 let mut farthest_dist = 0.0_f64;
142 for (i, sp) in mapped_spherical.iter().enumerate() {
143 let center_sp = cartesian_to_spherical(¢ers[assignments[i]]);
144 let d = angular_distance(sp, ¢er_sp);
145 if d > farthest_dist {
146 farthest_dist = d;
147 farthest_idx = i;
148 }
149 }
150 centers[j] = mapped_cartesian[farthest_idx].normalize();
151 } else {
152 centers[j] = normalized_mean(cp);
153 }
154 }
155 }
156
157 KMeansResult {
158 assignments,
159 centers,
160 }
161}
162
163fn fibonacci_sub_spiral(
164 center: &SphericalPoint,
165 count: usize,
166 spread: f64,
167 radius: f64,
168) -> Vec<SphericalPoint> {
169 if count == 0 {
170 return vec![];
171 }
172 if count == 1 {
173 return vec![SphericalPoint::new_unchecked(
174 radius,
175 center.theta,
176 center.phi,
177 )];
178 }
179
180 let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
181 let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
182 1.0,
183 center.theta,
184 center.phi,
185 ));
186
187 let (tangent_u, tangent_v) = local_frame(¢er_cart);
188
189 (0..count)
190 .map(|i| {
191 let frac = i as f64 / count as f64;
192 let angular_r = spread * frac.sqrt();
193 let angle = golden_angle * i as f64;
194
195 let offset_u = angular_r * angle.cos();
196 let offset_v = angular_r * angle.sin();
197
198 let displaced = CartesianPoint::new(
199 center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
200 center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
201 center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
202 )
203 .normalize();
204
205 let sp = cartesian_to_spherical(&displaced);
206 SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
207 })
208 .collect()
209}
210
211fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
212 let up = if center.z.abs() < 0.9 {
213 CartesianPoint::new(0.0, 0.0, 1.0)
214 } else {
215 CartesianPoint::new(1.0, 0.0, 0.0)
216 };
217
218 let ux = up.y * center.z - up.z * center.y;
220 let uy = up.z * center.x - up.x * center.z;
221 let uz = up.x * center.y - up.y * center.x;
222 let u = CartesianPoint::new(ux, uy, uz).normalize();
223
224 let vx = center.y * u.z - center.z * u.y;
226 let vy = center.z * u.x - center.x * u.z;
227 let vz = center.x * u.y - center.y * u.x;
228 let v = CartesianPoint::new(vx, vy, vz).normalize();
229
230 (u, v)
231}
232
233const MAX_QUALITY_N: usize = 5000;
234
235fn compute_quality(
236 positions: &[SphericalPoint],
237 assignments: &[usize],
238 num_clusters: usize,
239) -> LayoutQuality {
240 let n = positions.len();
241
242 if n <= 1 {
243 return LayoutQuality {
244 dispersion_score: if n == 0 { 0.0 } else { 1.0 },
245 overlap_score: 0.0,
246 silhouette_score: 0.0,
247 };
248 }
249
250 let (positions, assignments, n) = if n > MAX_QUALITY_N {
251 let step = n / MAX_QUALITY_N;
252 let sampled_pos: Vec<_> = positions
253 .iter()
254 .step_by(step)
255 .take(MAX_QUALITY_N)
256 .copied()
257 .collect();
258 let sampled_asgn: Vec<_> = assignments
259 .iter()
260 .step_by(step)
261 .take(MAX_QUALITY_N)
262 .copied()
263 .collect();
264 let len = sampled_pos.len();
265 (sampled_pos, sampled_asgn, len)
266 } else {
267 (positions.to_vec(), assignments.to_vec(), n)
268 };
269
270 let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
272 for (i, &a) in assignments.iter().enumerate() {
273 cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
274 }
275 let active_centers: Vec<SphericalPoint> = cluster_point_sets
276 .iter()
277 .filter(|cp| !cp.is_empty())
278 .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
279 .collect();
280
281 let dispersion_score = if active_centers.len() >= 2 {
282 let mut sum = 0.0;
283 let mut count = 0;
284 for i in 0..active_centers.len() {
285 for j in (i + 1)..active_centers.len() {
286 sum += angular_distance(&active_centers[i], &active_centers[j]);
287 count += 1;
288 }
289 }
290 (sum / count as f64 / PI).clamp(0.0, 1.0)
291 } else {
292 0.0
293 };
294
295 let mut overlap_count = 0u64;
297 let total_pairs = (n * (n - 1)) / 2;
298 for i in 0..n {
299 for j in (i + 1)..n {
300 if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
301 overlap_count += 1;
302 }
303 }
304 }
305 let overlap_score = if total_pairs > 0 {
306 overlap_count as f64 / total_pairs as f64
307 } else {
308 0.0
309 };
310
311 let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
313 0.0
314 } else {
315 let mut sil_sum = 0.0;
316 for i in 0..n {
317 let ci = assignments[i];
318
319 let mut a_sum = 0.0;
321 let mut a_count = 0;
322 for j in 0..n {
323 if j != i && assignments[j] == ci {
324 a_sum += angular_distance(&positions[i], &positions[j]);
325 a_count += 1;
326 }
327 }
328 let a = if a_count > 0 {
329 a_sum / a_count as f64
330 } else {
331 0.0
332 };
333
334 let mut b = f64::MAX;
336 for k in 0..num_clusters {
337 if k == ci {
338 continue;
339 }
340 let mut b_sum = 0.0;
341 let mut b_count = 0;
342 for j in 0..n {
343 if assignments[j] == k {
344 b_sum += angular_distance(&positions[i], &positions[j]);
345 b_count += 1;
346 }
347 }
348 if b_count > 0 {
349 let mean_dist = b_sum / b_count as f64;
350 if mean_dist < b {
351 b = mean_dist;
352 }
353 }
354 }
355 if b == f64::MAX {
356 b = 0.0;
357 }
358
359 let denom = a.max(b);
360 let s = if denom > 0.0 { (b - a) / denom } else { 0.0 };
361 sil_sum += s;
362 }
363 sil_sum / n as f64
364 };
365
366 LayoutQuality {
367 dispersion_score,
368 overlap_score,
369 silhouette_score,
370 }
371}
372
373impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
374 fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
375 if items.is_empty() {
376 return LayoutResult {
377 entries: vec![],
378 quality: LayoutQuality::default(),
379 };
380 }
381
382 let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
383 let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
384
385 let k = self.num_clusters.min(items.len()).max(1);
386 let km = kmeans_spherical(&mapped_cart, &mapped, k);
387
388 let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
389 for (i, &a) in km.assignments.iter().enumerate() {
390 cluster_items[a].push(i);
391 }
392
393 let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
394 let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
395 let mut final_assignments = vec![0usize; items.len()];
396
397 for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
398 let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
399 let sub_positions = fibonacci_sub_spiral(
400 ¢er_sp,
401 member_indices.len(),
402 self.intra_cluster_spread,
403 self.radius,
404 );
405
406 for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
407 let pos = sub_positions[sub_idx];
408 entries.push((
409 item_idx,
410 LayoutEntry {
411 item: items[item_idx].clone(),
412 position: pos,
413 },
414 ));
415 final_positions.push((item_idx, pos));
416 final_assignments[item_idx] = cluster_idx;
417 }
418 }
419
420 entries.sort_by_key(|(idx, _)| *idx);
421 let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
422
423 final_positions.sort_by_key(|(idx, _)| *idx);
424 let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
425
426 let quality = compute_quality(&positions, &final_assignments, k);
427
428 LayoutResult { entries, quality }
429 }
430}
431
// Unit tests for `ClusteredLayout`, driven through the `LayoutStrategy::layout`
// entry point with a mapper that returns pre-computed positions.
#[cfg(test)]
mod tests {
    use super::*;

    // Test mapper: the item is an index into `positions`, so each test fully
    // controls where its items are mapped.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        // Panics if `item` is out of range for `positions`.
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    // No items in, no entries out.
    #[test]
    fn empty_items_returns_empty_result() {
        let layout = ClusteredLayout::new();
        let mapper = FixedMapper { positions: vec![] };
        let result = layout.layout(&[], &mapper);
        assert!(result.entries.is_empty());
    }

    // A lone item is placed and carries the default radius of 1.0.
    #[test]
    fn single_item_gets_placed() {
        let layout = ClusteredLayout::new().with_clusters(1);
        let mapper = FixedMapper {
            positions: vec![SphericalPoint::new_unchecked(1.0, 0.5, 1.0)],
        };
        let result = layout.layout(&[0usize], &mapper);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    // Twenty items spread along theta at a fixed phi produce twenty entries.
    #[test]
    fn correct_number_of_entries() {
        let layout = ClusteredLayout::new().with_clusters(3);
        let positions: Vec<SphericalPoint> = (0..20)
            .map(|i| {
                let theta = (i as f64 * 0.3) % (2.0 * PI);
                SphericalPoint::new_unchecked(1.0, theta, 1.0)
            })
            .collect();
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let result = layout.layout(&items, &mapper);
        assert_eq!(result.entries.len(), 20);
    }

    // Two tight groups: five points at phi = 0.1 and five at phi = PI - 0.1.
    // Only the first group (entries 0..5) is checked: every pair must be less
    // than 1.0 apart in angular distance.
    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        let mut positions = Vec::new();
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.1));
        }
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                0.01 * i as f64,
                PI - 0.1,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let group_a: Vec<&SphericalPoint> =
            result.entries[..5].iter().map(|e| &e.position).collect();
        for i in 0..group_a.len() {
            for j in (i + 1)..group_a.len() {
                let d = angular_distance(group_a[i], group_a[j]);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    // Two groups at phi = PI / 2, one with theta near 0 and one with theta
    // near PI. The first entry of each group must end up more than 1.0 apart.
    #[test]
    fn different_clusters_are_angularly_separated() {
        let mut positions = Vec::new();
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                0.01 * i as f64,
                PI / 2.0,
            ));
        }
        for i in 0..5 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                PI + 0.01 * i as f64,
                PI / 2.0,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..10).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = layout.layout(&items, &mapper);

        let p_a = &result.entries[0].position;
        let p_b = &result.entries[5].position;
        let d = angular_distance(p_a, p_b);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    // Two groups of ten points in widely separated regions should produce a
    // strictly positive silhouette score.
    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let mut positions = Vec::new();
        for i in 0..10 {
            positions.push(SphericalPoint::new_unchecked(1.0, 0.01 * i as f64, 0.2));
        }
        for i in 0..10 {
            positions.push(SphericalPoint::new_unchecked(
                1.0,
                PI + 0.01 * i as f64,
                PI - 0.2,
            ));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.15);
        let result = layout.layout(&items, &mapper);
        assert!(
            result.quality.silhouette_score > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            result.quality.silhouette_score
        );
    }

    // Each builder method stores its argument in the matching field.
    #[test]
    fn builder_methods_apply() {
        let layout = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(layout.num_clusters, 8);
        assert!((layout.radius - 2.5).abs() < 1e-12);
        assert!((layout.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    // Output positions use the configured radius, not the radius of the
    // mapped input points (which is 1.0 here).
    #[test]
    fn output_radius_matches_configured() {
        let layout = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let positions = vec![
            SphericalPoint::new_unchecked(1.0, 0.0, 0.5),
            SphericalPoint::new_unchecked(1.0, PI, 2.0),
        ];
        let mapper = FixedMapper { positions };
        let result = layout.layout(&[0usize, 1], &mapper);
        for entry in &result.entries {
            assert!(
                (entry.position.r - 3.0).abs() < 1e-12,
                "Expected radius 3.0, got {}",
                entry.position.r
            );
        }
    }
}