1use std::f64::consts::PI;
2
3use sphereql_core::{
4 CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5 spherical_to_cartesian,
6};
7
8use crate::quality::{MAX_QUALITY_N, OVERLAP_THRESHOLD};
9use crate::traits::{DimensionMapper, LayoutStrategy};
10use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
11
/// Upper bound on the number of assign/update rounds `kmeans_spherical` runs
/// before it returns, whether or not the assignments have stopped changing.
const MAX_KMEANS_ITERATIONS: usize = 50;
13
/// Configuration for a layout that groups items into clusters on a sphere and
/// fans the members of each cluster out around the cluster center.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters. The layout clamps the effective value to
    /// at least 1 and at most the number of items being laid out.
    pub num_clusters: usize,
    /// Radius assigned to every output position.
    pub radius: f64,
    /// Scale of the offsets applied around a cluster center when spreading
    /// that cluster's members; larger values push members further apart.
    pub intra_cluster_spread: f64,
}
19
20impl Default for ClusteredLayout {
21 fn default() -> Self {
22 Self {
23 num_clusters: 4,
24 radius: 1.0,
25 intra_cluster_spread: 0.3,
26 }
27 }
28}
29
30impl ClusteredLayout {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn with_clusters(mut self, n: usize) -> Self {
36 self.num_clusters = n;
37 self
38 }
39
40 pub fn with_radius(mut self, r: f64) -> Self {
41 self.radius = r;
42 self
43 }
44
45 pub fn with_spread(mut self, s: f64) -> Self {
46 self.intra_cluster_spread = s;
47 self
48 }
49}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52 let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53 (0..k)
54 .map(|i| {
55 let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56 .clamp(-1.0, 1.0)
57 .acos();
58 let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59 let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60 spherical_to_cartesian(&sp)
61 })
62 .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66 if points.is_empty() {
67 return CartesianPoint::new(0.0, 0.0, 1.0);
68 }
69 let mut sx = 0.0;
70 let mut sy = 0.0;
71 let mut sz = 0.0;
72 for p in points {
73 sx += p.x;
74 sy += p.y;
75 sz += p.z;
76 }
77 let mean = CartesianPoint::new(sx, sy, sz);
78 let n = mean.normalize();
79 if n.magnitude() < 1e-12 {
84 for p in points {
85 if p.magnitude() >= 1e-12 {
86 return p.normalize();
87 }
88 }
89 CartesianPoint::new(0.0, 0.0, 1.0)
90 } else {
91 n
92 }
93}
94
/// Output of `kmeans_spherical`.
struct KMeansResult {
    /// Cluster index chosen for each input point, in input order.
    assignments: Vec<usize>,
    /// One center per cluster, indexed by cluster index.
    centers: Vec<CartesianPoint>,
}
99
100fn kmeans_spherical(mapped_cartesian: &[CartesianPoint], k: usize) -> KMeansResult {
101 let n = mapped_cartesian.len();
102
103 let mut centers: Vec<CartesianPoint> = if n >= k {
104 mapped_cartesian[..k]
105 .iter()
106 .map(|c| c.normalize())
107 .collect()
108 } else {
109 evenly_spaced_centers(k)
110 };
111
112 let mut assignments = vec![0usize; n];
113
114 #[inline]
122 fn dot(a: &CartesianPoint, b: &CartesianPoint) -> f64 {
123 a.x * b.x + a.y * b.y + a.z * b.z
124 }
125
126 for _ in 0..MAX_KMEANS_ITERATIONS {
127 let mut changed = false;
128
129 for (i, point) in mapped_cartesian.iter().enumerate() {
130 let mut best = 0;
131 let mut best_dot = f64::MIN;
132 for (j, center) in centers.iter().enumerate() {
133 let d = dot(point, center);
134 if d > best_dot {
135 best_dot = d;
136 best = j;
137 }
138 }
139 if assignments[i] != best {
140 assignments[i] = best;
141 changed = true;
142 }
143 }
144
145 if !changed {
146 break;
147 }
148
149 let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
150 for (i, &a) in assignments.iter().enumerate() {
151 cluster_points[a].push(mapped_cartesian[i]);
152 }
153
154 for (j, cp) in cluster_points.iter().enumerate() {
155 if cp.is_empty() {
156 let mut farthest_idx = 0;
160 let mut farthest_dot = f64::MAX;
161 for (i, point) in mapped_cartesian.iter().enumerate() {
162 let d = dot(point, ¢ers[assignments[i]]);
163 if d < farthest_dot {
164 farthest_dot = d;
165 farthest_idx = i;
166 }
167 }
168 centers[j] = mapped_cartesian[farthest_idx].normalize();
169 } else {
170 centers[j] = normalized_mean(cp);
171 }
172 }
173 }
174
175 KMeansResult {
176 assignments,
177 centers,
178 }
179}
180
181fn fibonacci_sub_spiral(
182 center: &SphericalPoint,
183 count: usize,
184 spread: f64,
185 radius: f64,
186) -> Vec<SphericalPoint> {
187 if count == 0 {
188 return vec![];
189 }
190 if count == 1 {
191 return vec![SphericalPoint::new_unchecked(
192 radius,
193 center.theta,
194 center.phi,
195 )];
196 }
197
198 let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
199 let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
200 1.0,
201 center.theta,
202 center.phi,
203 ));
204
205 let (tangent_u, tangent_v) = local_frame(¢er_cart);
206
207 (0..count)
208 .map(|i| {
209 let frac = i as f64 / count as f64;
210 let angular_r = spread * frac.sqrt();
211 let angle = golden_angle * i as f64;
212
213 let offset_u = angular_r * angle.cos();
214 let offset_v = angular_r * angle.sin();
215
216 let displaced = CartesianPoint::new(
217 center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
218 center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
219 center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
220 )
221 .normalize();
222
223 let sp = cartesian_to_spherical(&displaced);
224 SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
225 })
226 .collect()
227}
228
229fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
230 let up = if center.z.abs() < 0.9 {
231 CartesianPoint::new(0.0, 0.0, 1.0)
232 } else {
233 CartesianPoint::new(1.0, 0.0, 0.0)
234 };
235
236 let ux = up.y * center.z - up.z * center.y;
238 let uy = up.z * center.x - up.x * center.z;
239 let uz = up.x * center.y - up.y * center.x;
240 let u = CartesianPoint::new(ux, uy, uz).normalize();
241
242 let vx = center.y * u.z - center.z * u.y;
244 let vy = center.z * u.x - center.x * u.z;
245 let vz = center.x * u.y - center.y * u.x;
246 let v = CartesianPoint::new(vx, vy, vz).normalize();
247
248 (u, v)
249}
250
251fn compute_quality(
252 positions: &[SphericalPoint],
253 assignments: &[usize],
254 num_clusters: usize,
255) -> LayoutQuality {
256 let n = positions.len();
257
258 if n <= 1 {
259 return LayoutQuality {
260 dispersion_score: if n == 0 { 0.0 } else { 1.0 },
261 overlap_score: 0.0,
262 silhouette_score: 0.0,
263 };
264 }
265
266 let (positions, assignments, n) = if n > MAX_QUALITY_N {
267 let step = n / MAX_QUALITY_N;
268 let sampled_pos: Vec<_> = positions
269 .iter()
270 .step_by(step)
271 .take(MAX_QUALITY_N)
272 .copied()
273 .collect();
274 let sampled_asgn: Vec<_> = assignments
275 .iter()
276 .step_by(step)
277 .take(MAX_QUALITY_N)
278 .copied()
279 .collect();
280 let len = sampled_pos.len();
281 (sampled_pos, sampled_asgn, len)
282 } else {
283 (positions.to_vec(), assignments.to_vec(), n)
284 };
285
286 let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
288 for (i, &a) in assignments.iter().enumerate() {
289 cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
290 }
291 let active_centers: Vec<SphericalPoint> = cluster_point_sets
292 .iter()
293 .filter(|cp| !cp.is_empty())
294 .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
295 .collect();
296
297 use rayon::prelude::*;
303 const SERIAL_THRESHOLD: usize = 128;
304
305 let dispersion_score = if active_centers.len() >= 2 {
306 let len = active_centers.len();
307 let (sum, count) = if len < SERIAL_THRESHOLD {
308 let mut s = 0.0;
309 let mut c = 0u64;
310 for i in 0..len {
311 for j in (i + 1)..len {
312 s += angular_distance(&active_centers[i], &active_centers[j]);
313 c += 1;
314 }
315 }
316 (s, c)
317 } else {
318 (0..len)
319 .into_par_iter()
320 .map(|i| {
321 let mut s = 0.0;
322 let mut c = 0u64;
323 for j in (i + 1)..len {
324 s += angular_distance(&active_centers[i], &active_centers[j]);
325 c += 1;
326 }
327 (s, c)
328 })
329 .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb))
330 };
331 (sum / count as f64 / PI).clamp(0.0, 1.0)
332 } else {
333 0.0
334 };
335
336 let total_pairs = (n * (n - 1)) / 2;
338 let overlap_count: u64 = if n < SERIAL_THRESHOLD {
339 let mut c = 0u64;
340 for i in 0..n {
341 for j in (i + 1)..n {
342 if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
343 c += 1;
344 }
345 }
346 }
347 c
348 } else {
349 (0..n)
350 .into_par_iter()
351 .map(|i| {
352 let mut c = 0u64;
353 for j in (i + 1)..n {
354 if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
355 c += 1;
356 }
357 }
358 c
359 })
360 .sum()
361 };
362 let overlap_score = if total_pairs > 0 {
363 overlap_count as f64 / total_pairs as f64
364 } else {
365 0.0
366 };
367
368 let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
375 0.0
376 } else {
377 let cluster_members: Vec<Vec<usize>> = {
378 let mut buckets = vec![Vec::new(); num_clusters];
379 for (j, &cj) in assignments.iter().enumerate() {
380 if cj < num_clusters {
381 buckets[cj].push(j);
382 }
383 }
384 buckets
385 };
386
387 let per_point = |i: usize| -> f64 {
388 let ci = assignments[i];
389 let same = &cluster_members[ci];
390
391 let a = if same.len() <= 1 {
392 0.0
393 } else {
394 let s: f64 = same
395 .iter()
396 .filter(|&&j| j != i)
397 .map(|&j| angular_distance(&positions[i], &positions[j]))
398 .sum();
399 s / (same.len() - 1) as f64
400 };
401
402 let mut b = f64::MAX;
403 for (k, members) in cluster_members.iter().enumerate() {
404 if k == ci || members.is_empty() {
405 continue;
406 }
407 let s: f64 = members
408 .iter()
409 .map(|&j| angular_distance(&positions[i], &positions[j]))
410 .sum();
411 let mean_dist = s / members.len() as f64;
412 if mean_dist < b {
413 b = mean_dist;
414 }
415 }
416 if b == f64::MAX {
417 b = 0.0;
418 }
419
420 let denom = a.max(b);
421 if denom > 0.0 { (b - a) / denom } else { 0.0 }
422 };
423
424 let sil_sum: f64 = if n < SERIAL_THRESHOLD {
425 (0..n).map(per_point).sum()
426 } else {
427 (0..n).into_par_iter().map(per_point).sum()
428 };
429 sil_sum / n as f64
430 };
431
432 LayoutQuality {
433 dispersion_score,
434 overlap_score,
435 silhouette_score,
436 }
437}
438
439impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
440 fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
441 if items.is_empty() {
442 return LayoutResult {
443 entries: vec![],
444 quality: LayoutQuality::default(),
445 };
446 }
447
448 let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
449 let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
450
451 let k = self.num_clusters.min(items.len()).max(1);
452 let km = kmeans_spherical(&mapped_cart, k);
453
454 let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
455 for (i, &a) in km.assignments.iter().enumerate() {
456 cluster_items[a].push(i);
457 }
458
459 let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
460 let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
461 let mut final_assignments = vec![0usize; items.len()];
462
463 for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
464 let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
465 let sub_positions = fibonacci_sub_spiral(
466 ¢er_sp,
467 member_indices.len(),
468 self.intra_cluster_spread,
469 self.radius,
470 );
471
472 for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
473 let pos = sub_positions[sub_idx];
474 entries.push((
475 item_idx,
476 LayoutEntry {
477 item: items[item_idx].clone(),
478 position: pos,
479 },
480 ));
481 final_positions.push((item_idx, pos));
482 final_assignments[item_idx] = cluster_idx;
483 }
484 }
485
486 entries.sort_by_key(|(idx, _)| *idx);
487 let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
488
489 final_positions.sort_by_key(|(idx, _)| *idx);
490 let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
491
492 let quality = compute_quality(&positions, &final_assignments, k);
493
494 LayoutResult { entries, quality }
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper whose items are indices into a fixed list of positions.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;

        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// `count` unit-radius points at polar angle `phi` whose azimuths start
    /// at `theta_start` and advance by 0.01 rad per point.
    fn arc(count: usize, theta_start: f64, phi: f64) -> Vec<SphericalPoint> {
        (0..count)
            .map(|i| SphericalPoint::new_unchecked(1.0, theta_start + 0.01 * i as f64, phi))
            .collect()
    }

    /// Lays out the concatenation of two point groups with two clusters.
    fn layout_two_groups(
        first: Vec<SphericalPoint>,
        second: Vec<SphericalPoint>,
        spread: f64,
    ) -> LayoutResult<usize> {
        let mut positions = first;
        positions.extend(second);
        let items: Vec<usize> = (0..positions.len()).collect();
        let mapper = FixedMapper { positions };
        ClusteredLayout::new()
            .with_clusters(2)
            .with_spread(spread)
            .layout(&items, &mapper)
    }

    #[test]
    fn empty_items_returns_empty_result() {
        let mapper = FixedMapper {
            positions: Vec::new(),
        };
        let strategy = ClusteredLayout::new();
        let result = strategy.layout(&[], &mapper);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_gets_placed() {
        let mapper = FixedMapper {
            positions: vec![SphericalPoint::new_unchecked(1.0, 0.5, 1.0)],
        };
        let strategy = ClusteredLayout::new().with_clusters(1);
        let result = strategy.layout(&[0usize], &mapper);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    #[test]
    fn correct_number_of_entries() {
        let mut positions: Vec<SphericalPoint> = Vec::with_capacity(20);
        for i in 0..20 {
            let theta = (i as f64 * 0.3) % (2.0 * PI);
            positions.push(SphericalPoint::new_unchecked(1.0, theta, 1.0));
        }
        let mapper = FixedMapper { positions };
        let items: Vec<usize> = (0..20).collect();
        let strategy = ClusteredLayout::new().with_clusters(3);
        let result = strategy.layout(&items, &mapper);
        assert_eq!(result.entries.len(), 20);
    }

    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        let result = layout_two_groups(arc(5, 0.0, 0.1), arc(5, 0.0, PI - 0.1), 0.2);

        let group_a: Vec<&SphericalPoint> =
            result.entries.iter().take(5).map(|e| &e.position).collect();
        for (i, &first) in group_a.iter().enumerate() {
            for &second in &group_a[i + 1..] {
                let d = angular_distance(first, second);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    #[test]
    fn different_clusters_are_angularly_separated() {
        let result = layout_two_groups(arc(5, 0.0, PI / 2.0), arc(5, PI, PI / 2.0), 0.2);

        let d = angular_distance(&result.entries[0].position, &result.entries[5].position);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let result = layout_two_groups(arc(10, 0.0, 0.2), arc(10, PI, PI - 0.2), 0.15);
        assert!(
            result.quality.silhouette_score > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            result.quality.silhouette_score
        );
    }

    #[test]
    fn builder_methods_apply() {
        let configured = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(configured.num_clusters, 8);
        assert!((configured.radius - 2.5).abs() < 1e-12);
        assert!((configured.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    #[test]
    fn output_radius_matches_configured() {
        let mapper = FixedMapper {
            positions: vec![
                SphericalPoint::new_unchecked(1.0, 0.0, 0.5),
                SphericalPoint::new_unchecked(1.0, PI, 2.0),
            ],
        };
        let strategy = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let result = strategy.layout(&[0usize, 1], &mapper);
        for entry in result.entries.iter() {
            assert!(
                (entry.position.r - 3.0).abs() < 1e-12,
                "Expected radius 3.0, got {}",
                entry.position.r
            );
        }
    }
}