Skip to main content

sphereql_layout/
clustered.rs

1use std::f64::consts::PI;
2
3use sphereql_core::{
4    CartesianPoint, SphericalPoint, angular_distance, cartesian_to_spherical,
5    spherical_to_cartesian,
6};
7
8use crate::traits::{DimensionMapper, LayoutStrategy};
9use crate::types::{LayoutEntry, LayoutQuality, LayoutResult};
10
/// Upper bound on the number of assign/update passes `kmeans_spherical` runs
/// before giving up on convergence.
const MAX_KMEANS_ITERATIONS: usize = 50;
/// Two placed points whose `angular_distance` is below this value count as
/// overlapping in `compute_quality` (same unit as `angular_distance`,
/// presumably radians — confirm against `sphereql_core`).
const OVERLAP_THRESHOLD: f64 = 0.01;
13
/// Layout strategy that groups items into clusters on a sphere and fans the
/// members of each cluster out around the cluster's center.
///
/// Construct with [`ClusteredLayout::new`] (or `Default`) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteredLayout {
    /// Requested number of clusters. `layout` clamps it to
    /// `1..=items.len()` before clustering.
    pub num_clusters: usize,
    /// Radius written into every output position.
    pub radius: f64,
    /// Scale of the tangent-plane offset applied to members around their
    /// cluster center; the outermost member sits roughly this far out.
    pub intra_cluster_spread: f64,
}
19
20impl Default for ClusteredLayout {
21    fn default() -> Self {
22        Self {
23            num_clusters: 4,
24            radius: 1.0,
25            intra_cluster_spread: 0.3,
26        }
27    }
28}
29
30impl ClusteredLayout {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    pub fn with_clusters(mut self, n: usize) -> Self {
36        self.num_clusters = n;
37        self
38    }
39
40    pub fn with_radius(mut self, r: f64) -> Self {
41        self.radius = r;
42        self
43    }
44
45    pub fn with_spread(mut self, s: f64) -> Self {
46        self.intra_cluster_spread = s;
47        self
48    }
49}
50
51fn evenly_spaced_centers(k: usize) -> Vec<CartesianPoint> {
52    let golden_ratio = (1.0 + 5.0_f64.sqrt()) / 2.0;
53    (0..k)
54        .map(|i| {
55            let phi = (1.0 - 2.0 * (i as f64 + 0.5) / k as f64)
56                .clamp(-1.0, 1.0)
57                .acos();
58            let theta = (2.0 * PI * (i as f64) / golden_ratio).rem_euclid(2.0 * PI);
59            let sp = SphericalPoint::new_unchecked(1.0, theta, phi);
60            spherical_to_cartesian(&sp)
61        })
62        .collect()
63}
64
65fn normalized_mean(points: &[CartesianPoint]) -> CartesianPoint {
66    if points.is_empty() {
67        return CartesianPoint::new(0.0, 0.0, 1.0);
68    }
69    let mut sx = 0.0;
70    let mut sy = 0.0;
71    let mut sz = 0.0;
72    for p in points {
73        sx += p.x;
74        sy += p.y;
75        sz += p.z;
76    }
77    let mean = CartesianPoint::new(sx, sy, sz);
78    let n = mean.normalize();
79    if n.magnitude() == 0.0 {
80        points[0].normalize()
81    } else {
82        n
83    }
84}
85
/// Output of `kmeans_spherical`.
struct KMeansResult {
    /// `assignments[i]` is the index into `centers` of the cluster chosen
    /// for input point `i`.
    assignments: Vec<usize>,
    /// One center per cluster, in cluster-index order.
    centers: Vec<CartesianPoint>,
}
90
91fn kmeans_spherical(mapped_cartesian: &[CartesianPoint], k: usize) -> KMeansResult {
92    let n = mapped_cartesian.len();
93
94    let mut centers: Vec<CartesianPoint> = if n >= k {
95        mapped_cartesian[..k]
96            .iter()
97            .map(|c| c.normalize())
98            .collect()
99    } else {
100        evenly_spaced_centers(k)
101    };
102
103    let mut assignments = vec![0usize; n];
104
105    // k-means inner loop in Cartesian. Both points and centers live on
106    // the unit sphere, so angular distance `acos(dot)` and
107    // `maximize dot` pick the same cluster for every point (acos is
108    // monotonic). The previous implementation converted every
109    // Cartesian center back to spherical and called `angular_distance`
110    // per (point, center) pair — ~50·n·k redundant trig calls per
111    // k-means run, which dominated layout cost.
112    #[inline]
113    fn dot(a: &CartesianPoint, b: &CartesianPoint) -> f64 {
114        a.x * b.x + a.y * b.y + a.z * b.z
115    }
116
117    for _ in 0..MAX_KMEANS_ITERATIONS {
118        let mut changed = false;
119
120        for (i, point) in mapped_cartesian.iter().enumerate() {
121            let mut best = 0;
122            let mut best_dot = f64::MIN;
123            for (j, center) in centers.iter().enumerate() {
124                let d = dot(point, center);
125                if d > best_dot {
126                    best_dot = d;
127                    best = j;
128                }
129            }
130            if assignments[i] != best {
131                assignments[i] = best;
132                changed = true;
133            }
134        }
135
136        if !changed {
137            break;
138        }
139
140        let mut cluster_points: Vec<Vec<CartesianPoint>> = vec![vec![]; k];
141        for (i, &a) in assignments.iter().enumerate() {
142            cluster_points[a].push(mapped_cartesian[i]);
143        }
144
145        for (j, cp) in cluster_points.iter().enumerate() {
146            if cp.is_empty() {
147                // Reseed from whichever point is farthest from its
148                // current cluster center — Cartesian dot gives the
149                // ordering (smaller dot = larger angular distance).
150                let mut farthest_idx = 0;
151                let mut farthest_dot = f64::MAX;
152                for (i, point) in mapped_cartesian.iter().enumerate() {
153                    let d = dot(point, &centers[assignments[i]]);
154                    if d < farthest_dot {
155                        farthest_dot = d;
156                        farthest_idx = i;
157                    }
158                }
159                centers[j] = mapped_cartesian[farthest_idx].normalize();
160            } else {
161                centers[j] = normalized_mean(cp);
162            }
163        }
164    }
165
166    KMeansResult {
167        assignments,
168        centers,
169    }
170}
171
172fn fibonacci_sub_spiral(
173    center: &SphericalPoint,
174    count: usize,
175    spread: f64,
176    radius: f64,
177) -> Vec<SphericalPoint> {
178    if count == 0 {
179        return vec![];
180    }
181    if count == 1 {
182        return vec![SphericalPoint::new_unchecked(
183            radius,
184            center.theta,
185            center.phi,
186        )];
187    }
188
189    let golden_angle = PI * (3.0 - 5.0_f64.sqrt());
190    let center_cart = spherical_to_cartesian(&SphericalPoint::new_unchecked(
191        1.0,
192        center.theta,
193        center.phi,
194    ));
195
196    let (tangent_u, tangent_v) = local_frame(&center_cart);
197
198    (0..count)
199        .map(|i| {
200            let frac = i as f64 / count as f64;
201            let angular_r = spread * frac.sqrt();
202            let angle = golden_angle * i as f64;
203
204            let offset_u = angular_r * angle.cos();
205            let offset_v = angular_r * angle.sin();
206
207            let displaced = CartesianPoint::new(
208                center_cart.x + offset_u * tangent_u.x + offset_v * tangent_v.x,
209                center_cart.y + offset_u * tangent_u.y + offset_v * tangent_v.y,
210                center_cart.z + offset_u * tangent_u.z + offset_v * tangent_v.z,
211            )
212            .normalize();
213
214            let sp = cartesian_to_spherical(&displaced);
215            SphericalPoint::new_unchecked(radius, sp.theta, sp.phi)
216        })
217        .collect()
218}
219
220fn local_frame(center: &CartesianPoint) -> (CartesianPoint, CartesianPoint) {
221    let up = if center.z.abs() < 0.9 {
222        CartesianPoint::new(0.0, 0.0, 1.0)
223    } else {
224        CartesianPoint::new(1.0, 0.0, 0.0)
225    };
226
227    // u = normalize(up x center)
228    let ux = up.y * center.z - up.z * center.y;
229    let uy = up.z * center.x - up.x * center.z;
230    let uz = up.x * center.y - up.y * center.x;
231    let u = CartesianPoint::new(ux, uy, uz).normalize();
232
233    // v = center x u
234    let vx = center.y * u.z - center.z * u.y;
235    let vy = center.z * u.x - center.x * u.z;
236    let vz = center.x * u.y - center.y * u.x;
237    let v = CartesianPoint::new(vx, vy, vz).normalize();
238
239    (u, v)
240}
241
/// Largest number of points `compute_quality` scores directly; larger
/// inputs are sub-sampled first because the metrics scan all point pairs.
const MAX_QUALITY_N: usize = 5000;
243
244fn compute_quality(
245    positions: &[SphericalPoint],
246    assignments: &[usize],
247    num_clusters: usize,
248) -> LayoutQuality {
249    let n = positions.len();
250
251    if n <= 1 {
252        return LayoutQuality {
253            dispersion_score: if n == 0 { 0.0 } else { 1.0 },
254            overlap_score: 0.0,
255            silhouette_score: 0.0,
256        };
257    }
258
259    let (positions, assignments, n) = if n > MAX_QUALITY_N {
260        let step = n / MAX_QUALITY_N;
261        let sampled_pos: Vec<_> = positions
262            .iter()
263            .step_by(step)
264            .take(MAX_QUALITY_N)
265            .copied()
266            .collect();
267        let sampled_asgn: Vec<_> = assignments
268            .iter()
269            .step_by(step)
270            .take(MAX_QUALITY_N)
271            .copied()
272            .collect();
273        let len = sampled_pos.len();
274        (sampled_pos, sampled_asgn, len)
275    } else {
276        (positions.to_vec(), assignments.to_vec(), n)
277    };
278
279    // Dispersion: average inter-cluster center distance / PI
280    let mut cluster_point_sets: Vec<Vec<CartesianPoint>> = vec![vec![]; num_clusters];
281    for (i, &a) in assignments.iter().enumerate() {
282        cluster_point_sets[a].push(spherical_to_cartesian(&positions[i]));
283    }
284    let active_centers: Vec<SphericalPoint> = cluster_point_sets
285        .iter()
286        .filter(|cp| !cp.is_empty())
287        .map(|cp| cartesian_to_spherical(&normalized_mean(cp)))
288        .collect();
289
290    // Every pair-scan below is embarrassingly parallel: pure reads
291    // over `positions` / `active_centers` / `assignments`, per-point
292    // scalar reductions. `rayon::par_iter` over the outer index
293    // keeps the reduce trivial and skips the thread-pool overhead
294    // for small corpora.
295    use rayon::prelude::*;
296    const SERIAL_THRESHOLD: usize = 128;
297
298    let dispersion_score = if active_centers.len() >= 2 {
299        let len = active_centers.len();
300        let (sum, count) = if len < SERIAL_THRESHOLD {
301            let mut s = 0.0;
302            let mut c = 0u64;
303            for i in 0..len {
304                for j in (i + 1)..len {
305                    s += angular_distance(&active_centers[i], &active_centers[j]);
306                    c += 1;
307                }
308            }
309            (s, c)
310        } else {
311            (0..len)
312                .into_par_iter()
313                .map(|i| {
314                    let mut s = 0.0;
315                    let mut c = 0u64;
316                    for j in (i + 1)..len {
317                        s += angular_distance(&active_centers[i], &active_centers[j]);
318                        c += 1;
319                    }
320                    (s, c)
321                })
322                .reduce(|| (0.0, 0), |(sa, ca), (sb, cb)| (sa + sb, ca + cb))
323        };
324        (sum / count as f64 / PI).clamp(0.0, 1.0)
325    } else {
326        0.0
327    };
328
329    // Overlap: fraction of pairs within threshold.
330    let total_pairs = (n * (n - 1)) / 2;
331    let overlap_count: u64 = if n < SERIAL_THRESHOLD {
332        let mut c = 0u64;
333        for i in 0..n {
334            for j in (i + 1)..n {
335                if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
336                    c += 1;
337                }
338            }
339        }
340        c
341    } else {
342        (0..n)
343            .into_par_iter()
344            .map(|i| {
345                let mut c = 0u64;
346                for j in (i + 1)..n {
347                    if angular_distance(&positions[i], &positions[j]) < OVERLAP_THRESHOLD {
348                        c += 1;
349                    }
350                }
351                c
352            })
353            .sum()
354    };
355    let overlap_score = if total_pairs > 0 {
356        overlap_count as f64 / total_pairs as f64
357    } else {
358        0.0
359    };
360
361    // Silhouette coefficient — per-point independent scans. Parallel
362    // over the outer `i` is safe since each worker builds a local
363    // scalar; the reduction is a single f64 sum.
364    let silhouette_score = if num_clusters <= 1 || active_centers.len() <= 1 {
365        0.0
366    } else {
367        let per_point = |i: usize| -> f64 {
368            let ci = assignments[i];
369
370            // a(i) = mean distance to same-cluster members
371            let mut a_sum = 0.0;
372            let mut a_count = 0;
373            for j in 0..n {
374                if j != i && assignments[j] == ci {
375                    a_sum += angular_distance(&positions[i], &positions[j]);
376                    a_count += 1;
377                }
378            }
379            let a = if a_count > 0 {
380                a_sum / a_count as f64
381            } else {
382                0.0
383            };
384
385            // b(i) = min over other clusters of mean distance to that cluster
386            let mut b = f64::MAX;
387            for k in 0..num_clusters {
388                if k == ci {
389                    continue;
390                }
391                let mut b_sum = 0.0;
392                let mut b_count = 0;
393                for j in 0..n {
394                    if assignments[j] == k {
395                        b_sum += angular_distance(&positions[i], &positions[j]);
396                        b_count += 1;
397                    }
398                }
399                if b_count > 0 {
400                    let mean_dist = b_sum / b_count as f64;
401                    if mean_dist < b {
402                        b = mean_dist;
403                    }
404                }
405            }
406            if b == f64::MAX {
407                b = 0.0;
408            }
409
410            let denom = a.max(b);
411            if denom > 0.0 { (b - a) / denom } else { 0.0 }
412        };
413
414        let sil_sum: f64 = if n < SERIAL_THRESHOLD {
415            (0..n).map(per_point).sum()
416        } else {
417            (0..n).into_par_iter().map(per_point).sum()
418        };
419        sil_sum / n as f64
420    };
421
422    LayoutQuality {
423        dispersion_score,
424        overlap_score,
425        silhouette_score,
426    }
427}
428
429impl<T: Clone + Send + Sync> LayoutStrategy<T> for ClusteredLayout {
430    fn layout(&self, items: &[T], mapper: &dyn DimensionMapper<Item = T>) -> LayoutResult<T> {
431        if items.is_empty() {
432            return LayoutResult {
433                entries: vec![],
434                quality: LayoutQuality::default(),
435            };
436        }
437
438        let mapped: Vec<SphericalPoint> = items.iter().map(|item| mapper.map(item)).collect();
439        let mapped_cart: Vec<CartesianPoint> = mapped.iter().map(spherical_to_cartesian).collect();
440
441        let k = self.num_clusters.min(items.len()).max(1);
442        let km = kmeans_spherical(&mapped_cart, k);
443
444        let mut cluster_items: Vec<Vec<usize>> = vec![vec![]; k];
445        for (i, &a) in km.assignments.iter().enumerate() {
446            cluster_items[a].push(i);
447        }
448
449        let mut entries: Vec<(usize, LayoutEntry<T>)> = Vec::with_capacity(items.len());
450        let mut final_positions: Vec<(usize, SphericalPoint)> = Vec::with_capacity(items.len());
451        let mut final_assignments = vec![0usize; items.len()];
452
453        for (cluster_idx, member_indices) in cluster_items.iter().enumerate() {
454            let center_sp = cartesian_to_spherical(&km.centers[cluster_idx]);
455            let sub_positions = fibonacci_sub_spiral(
456                &center_sp,
457                member_indices.len(),
458                self.intra_cluster_spread,
459                self.radius,
460            );
461
462            for (sub_idx, &item_idx) in member_indices.iter().enumerate() {
463                let pos = sub_positions[sub_idx];
464                entries.push((
465                    item_idx,
466                    LayoutEntry {
467                        item: items[item_idx].clone(),
468                        position: pos,
469                    },
470                ));
471                final_positions.push((item_idx, pos));
472                final_assignments[item_idx] = cluster_idx;
473            }
474        }
475
476        entries.sort_by_key(|(idx, _)| *idx);
477        let entries: Vec<LayoutEntry<T>> = entries.into_iter().map(|(_, e)| e).collect();
478
479        final_positions.sort_by_key(|(idx, _)| *idx);
480        let positions: Vec<SphericalPoint> = final_positions.into_iter().map(|(_, p)| p).collect();
481
482        let quality = compute_quality(&positions, &final_assignments, k);
483
484        LayoutResult { entries, quality }
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Mapper that treats every item as an index into a fixed position table.
    struct FixedMapper {
        positions: Vec<SphericalPoint>,
    }

    impl DimensionMapper for FixedMapper {
        type Item = usize;
        fn map(&self, item: &usize) -> SphericalPoint {
            self.positions[*item]
        }
    }

    /// Unit-radius point at the given angles.
    fn unit(theta: f64, phi: f64) -> SphericalPoint {
        SphericalPoint::new_unchecked(1.0, theta, phi)
    }

    /// Lays out items `0..positions.len()` through a `FixedMapper` backed by
    /// `positions`.
    fn run(layout: &ClusteredLayout, positions: Vec<SphericalPoint>) -> LayoutResult<usize> {
        let items: Vec<usize> = (0..positions.len()).collect();
        let mapper = FixedMapper { positions };
        layout.layout(&items, &mapper)
    }

    #[test]
    fn empty_items_returns_empty_result() {
        let result = run(&ClusteredLayout::new(), Vec::new());
        assert!(result.entries.is_empty());
    }

    #[test]
    fn single_item_gets_placed() {
        let layout = ClusteredLayout::new().with_clusters(1);
        let result = run(&layout, vec![unit(0.5, 1.0)]);
        assert_eq!(result.entries.len(), 1);
        assert!((result.entries[0].position.r - 1.0).abs() < 1e-12);
    }

    #[test]
    fn correct_number_of_entries() {
        let layout = ClusteredLayout::new().with_clusters(3);
        let positions: Vec<SphericalPoint> = (0..20)
            .map(|i| unit((i as f64 * 0.3) % (2.0 * PI), 1.0))
            .collect();
        let result = run(&layout, positions);
        assert_eq!(result.entries.len(), 20);
    }

    #[test]
    fn items_in_same_cluster_are_angularly_close() {
        // Two tight groups, one near each pole.
        let near_north = (0..5).map(|i| unit(0.01 * i as f64, 0.1));
        let near_south = (0..5).map(|i| unit(0.01 * i as f64, PI - 0.1));
        let positions: Vec<SphericalPoint> = near_north.chain(near_south).collect();

        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = run(&layout, positions);

        let group_a = &result.entries[..5];
        for (i, first) in group_a.iter().enumerate() {
            for second in &group_a[i + 1..] {
                let d = angular_distance(&first.position, &second.position);
                assert!(d < 1.0, "Intra-cluster distance too large: {d}");
            }
        }
    }

    #[test]
    fn different_clusters_are_angularly_separated() {
        // Two tight groups on opposite sides of the equator.
        let front = (0..5).map(|i| unit(0.01 * i as f64, PI / 2.0));
        let back = (0..5).map(|i| unit(PI + 0.01 * i as f64, PI / 2.0));
        let positions: Vec<SphericalPoint> = front.chain(back).collect();

        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.2);
        let result = run(&layout, positions);

        let d = angular_distance(&result.entries[0].position, &result.entries[5].position);
        assert!(d > 1.0, "Inter-cluster distance too small: {d}");
    }

    #[test]
    fn silhouette_positive_for_well_separated_data() {
        let first = (0..10).map(|i| unit(0.01 * i as f64, 0.2));
        let second = (0..10).map(|i| unit(PI + 0.01 * i as f64, PI - 0.2));
        let positions: Vec<SphericalPoint> = first.chain(second).collect();

        let layout = ClusteredLayout::new().with_clusters(2).with_spread(0.15);
        let result = run(&layout, positions);
        assert!(
            result.quality.silhouette_score > 0.0,
            "Silhouette should be positive for well-separated clusters, got {}",
            result.quality.silhouette_score
        );
    }

    #[test]
    fn builder_methods_apply() {
        let layout = ClusteredLayout::new()
            .with_clusters(8)
            .with_radius(2.5)
            .with_spread(0.5);
        assert_eq!(layout.num_clusters, 8);
        assert!((layout.radius - 2.5).abs() < 1e-12);
        assert!((layout.intra_cluster_spread - 0.5).abs() < 1e-12);
    }

    #[test]
    fn output_radius_matches_configured() {
        let layout = ClusteredLayout::new().with_radius(3.0).with_clusters(2);
        let result = run(&layout, vec![unit(0.0, 0.5), unit(PI, 2.0)]);
        for entry in &result.entries {
            assert!(
                (entry.position.r - 3.0).abs() < 1e-12,
                "Expected radius 3.0, got {}",
                entry.position.r
            );
        }
    }
}
645}